diff --git a/server/download.go b/server/download.go index faa06dd2..6b1cfd56 100644 --- a/server/download.go +++ b/server/download.go @@ -23,6 +23,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/types/model" ) const maxRetries = 6 @@ -332,15 +333,16 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) } } -type downloadOpts struct { - mp ModelPath +type downloadOptions struct { + name model.Name + baseURL *url.URL digest string regOpts *registryOptions fn func(api.ProgressResponse) } // downloadBlob downloads a blob from the registry and stores it in the blobs directory -func downloadBlob(ctx context.Context, opts downloadOpts) error { +func downloadBlob(ctx context.Context, opts downloadOptions) error { fp, err := GetBlobsPath(opts.digest) if err != nil { return err @@ -365,8 +367,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error { data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest}) download := data.(*blobDownload) if !ok { - requestURL := opts.mp.BaseURL() - requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest) + requestURL := opts.baseURL.JoinPath("blobs", opts.digest) if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil { blobDownloadManager.Delete(opts.digest) return err diff --git a/server/images.go b/server/images.go index 67253ec7..96966f6f 100644 --- a/server/images.go +++ b/server/images.go @@ -16,6 +16,7 @@ import ( "net/http" "net/url" "os" + "path" "path/filepath" "runtime" "strconv" @@ -781,59 +782,43 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu return nil } -func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { - mp := ParseModelPath(name) +func PullModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error { + old, _ := ParseNamedManifest(name) - var manifest *ManifestV2 - var err error - var noprune string - - // build deleteMap to prune unused layers - deleteMap := make(map[string]struct{}) - - if !envconfig.NoPrune { - manifest, _, err = GetManifest(mp) - if err != nil && !errors.Is(err, os.ErrNotExist) { - return err - } - - if manifest != nil { - for _, l := range manifest.Layers { - deleteMap[l.Digest] = struct{}{} - } - deleteMap[manifest.Config.Digest] = struct{}{} - } + if !name.IsFullyQualified() { + return model.Unqualified(name) } - if mp.ProtocolScheme == "http" && !regOpts.Insecure { - return fmt.Errorf("insecure protocol http") + scheme := "https" + if opts.Insecure { + scheme = "http" + } + + baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model))) + if err != nil { + return err } fn(api.ProgressResponse{Status: "pulling manifest"}) - - manifest, err = pullModelManifest(ctx, mp, regOpts) + m, err := pullModelManifest(ctx, name, baseURL, &opts) if err != nil { return fmt.Errorf("pull model manifest: %s", err) } - var layers []*Layer - layers = append(layers, manifest.Layers...) - layers = append(layers, manifest.Config) - + layers := append(m.Layers, m.Config) for _, layer := range layers { if err := downloadBlob( ctx, - downloadOpts{ - mp: mp, + downloadOptions{ + name: name, + baseURL: baseURL, digest: layer.Digest, - regOpts: regOpts, + regOpts: &opts, fn: fn, }); err != nil { return err } - delete(deleteMap, layer.Digest) } - delete(deleteMap, manifest.Config.Digest) fn(api.ProgressResponse{Status: "verifying sha256 digest"}) for _, layer := range layers { @@ -854,45 +839,25 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu } fn(api.ProgressResponse{Status: "writing manifest"}) - - manifestJSON, err := json.Marshal(manifest) - if err != nil { + if err := WriteManifest(name, m.Config, m.Layers); err != nil { return err } - fp, err := mp.GetManifestPath() - if err != nil { - return err - } - if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil { - return err - } - - err = os.WriteFile(fp, manifestJSON, 0o644) - if err != nil { - slog.Info(fmt.Sprintf("couldn't write to %s", fp)) - return err - } - - if noprune == "" { + if !envconfig.NoPrune && old != nil { fn(api.ProgressResponse{Status: "removing any unused layers"}) - err = deleteUnusedLayers(nil, deleteMap) - if err != nil { - return err - } + _ = old.RemoveLayers() } fn(api.ProgressResponse{Status: "success"}) - return nil } -func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) { - requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) +func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, opts *registryOptions) (*ManifestV2, error) { + requestURL := baseURL.JoinPath("manifests", name.Tag) headers := make(http.Header) headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json") - resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts) + resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts) if err != nil { return nil, err } diff --git a/server/model.go b/server/model.go index d1cacfe1..3e8f86ae 100644 --- a/server/model.go +++ b/server/model.go @@ -26,7 +26,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe m, err := ParseNamedManifest(name) switch { case errors.Is(err, os.ErrNotExist): - if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil { + if err := PullModel(ctx, name, registryOptions{}, fn); err != nil { return nil, err } diff --git a/server/routes.go b/server/routes.go index bf15079c..55c49970 100644 --- a/server/routes.go +++ b/server/routes.go @@ -408,24 +408,18 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { } func (s *Server) PullModelHandler(c *gin.Context) { - var req api.PullRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + var r api.PullRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - var model string - if req.Model != "" { - model = req.Model - } else if req.Name != "" { - model = req.Name - } else { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + n := model.ParseName(cmp.Or(r.Model, r.Name)) + if !n.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) return } @@ -436,19 +430,15 @@ func (s *Server) PullModelHandler(c *gin.Context) { ch <- r } - regOpts := ®istryOptions{ - Insecure: req.Insecure, - } - ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := PullModel(ctx, model, regOpts, fn); err != nil { + if err := PullModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() - if req.Stream != nil && !*req.Stream { + if r.Stream != nil && !*r.Stream { waitForStream(c, ch) return }