diff --git a/api/types.go b/api/types.go index 609c4a8a..4c445209 100644 --- a/api/types.go +++ b/api/types.go @@ -183,11 +183,12 @@ type CopyRequest struct { } type PullRequest struct { - Model string `json:"model"` - Insecure bool `json:"insecure,omitempty"` - Username string `json:"username"` - Password string `json:"password"` - Stream *bool `json:"stream,omitempty"` + Model string `json:"model"` + Insecure bool `json:"insecure,omitempty"` + Username string `json:"username"` + Password string `json:"password"` + Stream *bool `json:"stream,omitempty"` + CurrentDigest string `json:"current_digest,omitempty"` // Name is deprecated, see Model Name string `json:"name"` @@ -241,6 +242,7 @@ type GenerateResponse struct { type ModelDetails struct { ParentModel string `json:"parent_model"` + Digest string `json:"digest"` Format string `json:"format"` Family string `json:"family"` Families []string `json:"families"` diff --git a/cmd/cmd.go b/cmd/cmd.go index 915fa993..61ee9986 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "log" + "log/slog" "net" "net/http" "os" @@ -357,6 +358,62 @@ func CopyHandler(cmd *cobra.Command, args []string) error { } func PullHandler(cmd *cobra.Command, args []string) error { + upgradeAll, err := cmd.Flags().GetBool("upgrade-all") + if err != nil { + return err + } + + if !upgradeAll { + if len(args) == 0 { + return fmt.Errorf("no model specified to pull") + } + return pull(cmd, args[0], "") + } + + fp, err := server.GetManifestPath() + if err != nil { + return err + } + + type modelInfo struct { + Name string + Digest string + } + + var modelList []modelInfo + + walkFunc := func(path string, info os.FileInfo, _ error) error { + if info.IsDir() { + return nil + } + + dir, file := filepath.Split(path) + dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator)) + tag := strings.Join([]string{dir, file}, ":") + + model, err := server.GetModel(tag) + if err != nil { + return nil + } + + modelList = append(modelList, modelInfo{tag, "sha256:" + model.Digest}) + return nil + } + + if err = filepath.Walk(fp, walkFunc); err != nil { + return err + } + + for _, m := range modelList { + err = pull(cmd, m.Name, m.Digest) + if err != nil { + slog.Warn(fmt.Sprintf("couldn't pull model '%s'", m.Name)) + } + } + return nil +} + +func pull(cmd *cobra.Command, name string, currentDigest string) error { insecure, err := cmd.Flags().GetBool("insecure") if err != nil { return err @@ -368,7 +425,7 @@ func PullHandler(cmd *cobra.Command, args []string) error { } p := progress.NewProgress(os.Stderr) - defer p.Stop() + defer p.StopWithoutClear() bars := make(map[string]*progress.Bar) @@ -402,7 +459,7 @@ func PullHandler(cmd *cobra.Command, args []string) error { return nil } - request := api.PullRequest{Name: args[0], Insecure: insecure} + request := api.PullRequest{Name: name, Insecure: insecure, CurrentDigest: currentDigest} if err := client.Pull(cmd.Context(), &request, fn); err != nil { return err } @@ -884,12 +941,13 @@ func NewCLI() *cobra.Command { pullCmd := &cobra.Command{ Use: "pull MODEL", Short: "Pull a model from a registry", - Args: cobra.ExactArgs(1), + Args: cobra.RangeArgs(0, 1), PreRunE: checkServerHeartbeat, RunE: PullHandler, } pullCmd.Flags().Bool("insecure", false, "Use an insecure registry") + pullCmd.Flags().Bool("upgrade-all", false, "Upgrade all models if they're out of date") pushCmd := &cobra.Command{ Use: "push MODEL", diff --git a/progress/progress.go b/progress/progress.go index 556ba00f..f615304d 100644 --- a/progress/progress.go +++ b/progress/progress.go @@ -52,6 +52,10 @@ func (p *Progress) Stop() bool { return stopped } +func (p *Progress) StopWithoutClear() bool { + return p.stop() +} + func (p *Progress) StopAndClear() bool { fmt.Fprint(p.w, "\033[?25l") defer fmt.Fprint(p.w, "\033[?25h") diff --git a/server/images.go b/server/images.go index ab3b4faa..9f2030f3 100644 --- a/server/images.go +++ b/server/images.go @@ -471,7 +471,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars switch { case errors.Is(err, os.ErrNotExist): fn(api.ProgressResponse{Status: "pulling model"}) - if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { + if err := PullModel(ctx, c.Args, "", &RegistryOptions{}, fn); err != nil { return err } @@ -1041,7 +1041,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu return nil } -func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { +func PullModel(ctx context.Context, name, currentDigest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) var manifest *ManifestV2 @@ -1069,13 +1069,23 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu return fmt.Errorf("insecure protocol http") } - fn(api.ProgressResponse{Status: "pulling manifest"}) + if currentDigest == "" { + fn(api.ProgressResponse{Status: "pulling manifest"}) + } - manifest, err = pullModelManifest(ctx, mp, regOpts) + manifest, err = pullModelManifest(ctx, mp, currentDigest, regOpts) if err != nil { return fmt.Errorf("pull model manifest: %s", err) } + if currentDigest != "" { + if manifest == nil { + // we already have the model + return nil + } + fn(api.ProgressResponse{Status: "upgrading " + mp.GetShortTagname()}) + } + var layers []*Layer layers = append(layers, manifest.Layers...) layers = append(layers, manifest.Config) @@ -1147,17 +1157,27 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu return nil } -func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) { +func pullModelManifest(ctx context.Context, mp ModelPath, currentDigest string, regOpts *RegistryOptions) (*ManifestV2, error) { requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) headers := make(http.Header) headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json") + + if currentDigest != "" { + headers.Set("If-None-Match", currentDigest) + } + resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts) if err != nil { return nil, err } defer resp.Body.Close() + // todo we can potentially read the manifest locally and return it here + if resp.StatusCode == http.StatusNotModified { + return nil, nil + } + var m *ManifestV2 if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { return nil, err diff --git a/server/routes.go b/server/routes.go index 56c275c9..0b1a79e9 100644 --- a/server/routes.go +++ b/server/routes.go @@ -451,7 +451,7 @@ func PullModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := PullModel(ctx, model, regOpts, fn); err != nil { + if err := PullModel(ctx, model, req.CurrentDigest, regOpts, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -673,6 +673,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { modelDetails := api.ModelDetails{ ParentModel: model.ParentModel, + Digest: "sha256:" + model.Digest, Format: model.Config.ModelFormat, Family: model.Config.ModelFamily, Families: model.Config.ModelFamilies,