add --upgrade-all flag to refresh any stale models

This commit is contained in:
Patrick Devine 2024-01-24 14:16:03 -08:00
parent b5cf31b460
commit 021b1bdc4a
5 changed files with 99 additions and 14 deletions

View File

@ -183,11 +183,12 @@ type CopyRequest struct {
}
type PullRequest struct {
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"`
Password string `json:"password"`
Stream *bool `json:"stream,omitempty"`
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"`
Password string `json:"password"`
Stream *bool `json:"stream,omitempty"`
CurrentDigest string `json:"current_digest,omitempty"`
// Name is deprecated, see Model
Name string `json:"name"`
@ -241,6 +242,7 @@ type GenerateResponse struct {
type ModelDetails struct {
ParentModel string `json:"parent_model"`
Digest string `json:"digest"`
Format string `json:"format"`
Family string `json:"family"`
Families []string `json:"families"`

View File

@ -11,6 +11,7 @@ import (
"fmt"
"io"
"log"
"log/slog"
"net"
"net/http"
"os"
@ -357,6 +358,62 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
}
func PullHandler(cmd *cobra.Command, args []string) error {
upgradeAll, err := cmd.Flags().GetBool("upgrade-all")
if err != nil {
return err
}
if !upgradeAll {
if len(args) == 0 {
return fmt.Errorf("no model specified to pull")
}
return pull(cmd, args[0], "")
}
fp, err := server.GetManifestPath()
if err != nil {
return err
}
type modelInfo struct {
Name string
Digest string
}
var modelList []modelInfo
walkFunc := func(path string, info os.FileInfo, _ error) error {
if info.IsDir() {
return nil
}
dir, file := filepath.Split(path)
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
tag := strings.Join([]string{dir, file}, ":")
model, err := server.GetModel(tag)
if err != nil {
return nil
}
modelList = append(modelList, modelInfo{tag, "sha256:" + model.Digest})
return nil
}
if err = filepath.Walk(fp, walkFunc); err != nil {
return err
}
for _, m := range modelList {
err = pull(cmd, m.Name, m.Digest)
if err != nil {
slog.Warn(fmt.Sprintf("couldn't pull model '%s'", m.Name))
}
}
return nil
}
func pull(cmd *cobra.Command, name string, currentDigest string) error {
insecure, err := cmd.Flags().GetBool("insecure")
if err != nil {
return err
@ -368,7 +425,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
defer p.StopWithoutClear()
bars := make(map[string]*progress.Bar)
@ -402,7 +459,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
return nil
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
request := api.PullRequest{Name: name, Insecure: insecure, CurrentDigest: currentDigest}
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
return err
}
@ -884,12 +941,13 @@ func NewCLI() *cobra.Command {
pullCmd := &cobra.Command{
Use: "pull MODEL",
Short: "Pull a model from a registry",
Args: cobra.ExactArgs(1),
Args: cobra.RangeArgs(0, 1),
PreRunE: checkServerHeartbeat,
RunE: PullHandler,
}
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
pullCmd.Flags().Bool("upgrade-all", false, "Upgrade all models if they're out of date")
pushCmd := &cobra.Command{
Use: "push MODEL",

View File

@ -52,6 +52,10 @@ func (p *Progress) Stop() bool {
return stopped
}
func (p *Progress) StopWithoutClear() bool {
return p.stop()
}
func (p *Progress) StopAndClear() bool {
fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")

View File

@ -471,7 +471,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
switch {
case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
if err := PullModel(ctx, c.Args, "", &RegistryOptions{}, fn); err != nil {
return err
}
@ -1041,7 +1041,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return nil
}
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
func PullModel(ctx context.Context, name, currentDigest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
var manifest *ManifestV2
@ -1069,13 +1069,23 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return fmt.Errorf("insecure protocol http")
}
fn(api.ProgressResponse{Status: "pulling manifest"})
if currentDigest == "" {
fn(api.ProgressResponse{Status: "pulling manifest"})
}
manifest, err = pullModelManifest(ctx, mp, regOpts)
manifest, err = pullModelManifest(ctx, mp, currentDigest, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}
if currentDigest != "" {
if manifest == nil {
// we already have the model
return nil
}
fn(api.ProgressResponse{Status: "upgrading " + mp.GetShortTagname()})
}
var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)
@ -1147,17 +1157,27 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return nil
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
func pullModelManifest(ctx context.Context, mp ModelPath, currentDigest string, regOpts *RegistryOptions) (*ManifestV2, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
if currentDigest != "" {
headers.Set("If-None-Match", currentDigest)
}
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// todo we can potentially read the manifest locally and return it here
if resp.StatusCode == http.StatusNotModified {
return nil, nil
}
var m *ManifestV2
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err

View File

@ -451,7 +451,7 @@ func PullModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil {
if err := PullModel(ctx, model, req.CurrentDigest, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@ -673,6 +673,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
modelDetails := api.ModelDetails{
ParentModel: model.ParentModel,
Digest: "sha256:" + model.Digest,
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,