From 3e24edd9ed63c6808101e8f6551369d2ab4fd862 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 8 May 2024 17:34:54 -0700 Subject: [PATCH] update push to use model.Name --- server/images.go | 35 +++++++++++++++++------------------ server/routes.go | 26 ++++++++------------------ server/upload.go | 41 ++++++++++++++++++++++++----------------- 3 files changed, 49 insertions(+), 53 deletions(-) diff --git a/server/images.go b/server/images.go index b5bf7ad6..0e4599a2 100644 --- a/server/images.go +++ b/server/images.go @@ -16,6 +16,7 @@ import ( "net/http" "net/url" "os" + "path" "path/filepath" "runtime" "slices" @@ -795,45 +796,42 @@ func PruneDirectory(path string) error { return nil } -func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { - mp := ParseModelPath(name) +func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error { fn(api.ProgressResponse{Status: "retrieving manifest"}) - if mp.ProtocolScheme == "http" && !regOpts.Insecure { - return errors.New("insecure protocol http") - } - - manifest, _, err := GetManifest(mp) + m, err := ParseNamedManifest(name) if err != nil { - fn(api.ProgressResponse{Status: "couldn't retrieve manifest"}) return err } - var layers []Layer - layers = append(layers, manifest.Layers...) - if manifest.Config.Digest != "" { - layers = append(layers, manifest.Config) + scheme := "https" + if opts.Insecure { + scheme = "http" } - for _, layer := range layers { - if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { + baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model))) + if err != nil { + return err + } + + for _, layer := range append(m.Layers, m.Config) { + if err := uploadBlob(ctx, uploadOptions{name: name, baseURL: baseURL, layer: layer, regOpts: &opts, fn: fn}); err != nil { slog.Info(fmt.Sprintf("error uploading blob: %v", err)) return err } } fn(api.ProgressResponse{Status: "pushing manifest"}) - requestURL := mp.BaseURL() - requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) + requestURL := baseURL.JoinPath("manifests", name.Tag) - manifestJSON, err := json.Marshal(manifest) + manifestJSON, err := json.Marshal(m) if err != nil { return err } headers := make(http.Header) headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json") - resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts) + resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), &opts) if err != nil { return err } @@ -1105,6 +1103,7 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header return nil, err } + slog.Debug("request upstream", "method", method, "request", requestURL.Redacted(), "status", resp.StatusCode) return resp, nil } diff --git a/server/routes.go b/server/routes.go index 5e9f51e1..8624645c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -514,24 +514,18 @@ func (s *Server) PullHandler(c *gin.Context) { } func (s *Server) PushHandler(c *gin.Context) { - var req api.PushRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + var r api.PushRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - var model string - if req.Model != "" { - model = req.Model - } else if req.Name != "" { - model = req.Name - } else { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + n := model.ParseName(cmp.Or(r.Model, r.Name)) + if !n.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) return } @@ -542,19 +536,15 @@ func (s *Server) PushHandler(c *gin.Context) { ch <- r } - regOpts := ®istryOptions{ - Insecure: req.Insecure, - } - ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := PushModel(ctx, model, regOpts, fn); err != nil { + if err := PushModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() - if req.Stream != nil && !*req.Stream { + if r.Stream != nil && !*r.Stream { waitForStream(c, ch) return } diff --git a/server/upload.go b/server/upload.go index 020e8955..3e084451 100644 --- a/server/upload.go +++ b/server/upload.go @@ -21,6 +21,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/types/model" ) var blobUploadManager sync.Map @@ -108,7 +109,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg offset += size } - slog.Info(fmt.Sprintf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))) + slog.Info("uploading blob", "digest", b.Digest, "size", format.HumanBytes(b.Total), "parts", len(b.Parts), "size per part", format.HumanBytes(b.Parts[0].Size)) requestURL, err = url.Parse(location) if err != nil { @@ -362,40 +363,46 @@ func (p *progressWriter) Rollback() { p.written = 0 } -func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error { - requestURL := mp.BaseURL() - requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest) +type uploadOptions struct { + name model.Name + baseURL *url.URL + layer Layer + regOpts *registryOptions + fn func(api.ProgressResponse) +} - resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts) +func uploadBlob(ctx context.Context, opts uploadOptions) error { + requestURL := opts.baseURL.JoinPath("blobs", opts.layer.Digest) + + resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts.regOpts) switch { case errors.Is(err, os.ErrNotExist): case err != nil: return err default: defer resp.Body.Close() - fn(api.ProgressResponse{ - Status: fmt.Sprintf("pushing %s", layer.Digest[7:19]), - Digest: layer.Digest, - Total: layer.Size, - Completed: layer.Size, + opts.fn(api.ProgressResponse{ + Status: fmt.Sprintf("pushing %s", opts.layer.Digest[7:19]), + Digest: opts.layer.Digest, + Total: opts.layer.Size, + Completed: opts.layer.Size, }) return nil } - data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer}) + data, ok := blobUploadManager.LoadOrStore(opts.layer.Digest, &blobUpload{Layer: opts.layer}) upload := data.(*blobUpload) if !ok { - requestURL := mp.BaseURL() - requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/") - if err := upload.Prepare(ctx, requestURL, opts); err != nil { - blobUploadManager.Delete(layer.Digest) + requestURL := opts.baseURL.JoinPath("blobs", "uploads") + if err := upload.Prepare(ctx, requestURL, opts.regOpts); err != nil { + blobUploadManager.Delete(opts.layer.Digest) return err } //nolint:contextcheck - go upload.Run(context.Background(), opts) + go upload.Run(context.Background(), opts.regOpts) } - return upload.Wait(ctx, fn) + return upload.Wait(ctx, opts.fn) }