From 8afe873f1728f8b022980c479fc46ee890e7c7a3 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Tue, 30 Apr 2024 16:53:47 -0700 Subject: [PATCH] ... --- client/registry/registry.go | 45 +++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/client/registry/registry.go b/client/registry/registry.go index aac4230d..c542a840 100644 --- a/client/registry/registry.go +++ b/client/registry/registry.go @@ -17,6 +17,7 @@ import ( "github.com/ollama/ollama/client/ollama" "github.com/ollama/ollama/client/registry/apitype" + "github.com/ollama/ollama/types/model" "golang.org/x/exp/constraints" "golang.org/x/sync/errgroup" ) @@ -54,20 +55,23 @@ type Cache interface { // // If the digest is invalid, or the layer does not exist, the empty // string is returned. - LayerFile(digest string) string + LayerFile(model.Digest) string // OpenLayer opens the layer file for the given model digest and // returns it, or an if any. The caller is responsible for closing // the returned file. - OpenLayer(digest string) (ReadAtSeekCloser, error) + OpenLayer(model.Digest) (ReadAtSeekCloser, error) // PutLayerFile moves the layer file at fromPath to the cache for // the given model digest. It is a hack intended to short circuit a // file copy operation. // + // The file returned is expected to exist for the lifetime of the + // cache. + // // TODO(bmizerany): remove this; find a better way. Once we move // this into a build package, we should be able to get rid of this. - PutLayerFile(digest, fromPath string) error + PutLayerFile(_ model.Digest, fromPath string) error // SetManifestData sets the provided manifest data for the given // model name. If the manifest data is empty, the manifest is @@ -75,19 +79,24 @@ type Cache interface { // // It is an error to call SetManifestData with a name that is not // complete. - SetManifestData(name string, data []byte) error + SetManifestData(model.Name, []byte) error // ManifestData returns the manifest data for the given model name. // // If the name incomplete, or the manifest does not exist, the empty // string is returned. - ManifestData(name string) []byte + ManifestData(name model.Name) []byte } // Pull pulls the manifest for name, and downloads any of its required // layers that are not already in the cache. It returns an error if any part // of the process fails, specifically: func (c *Client) Pull(ctx context.Context, cache Cache, name string) error { + mn := model.ParseName(name) + if !mn.IsFullyQualified() { + return fmt.Errorf("ollama: pull: invalid name: %s", name) + } + log := c.logger().With("name", name) pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil) @@ -101,10 +110,14 @@ func (c *Client) Pull(ctx context.Context, cache Cache, name string) error { // download required layers we do not already have for _, l := range pr.Manifest.Layers { - if cache.LayerFile(l.Digest) != "" { + d, err := model.ParseDigest(l.Digest) + if err != nil { + return fmt.Errorf("ollama: reading manifest: %w: %s", err, l.Digest) + } + if cache.LayerFile(d) != "" { continue } - err := func() error { + err = func() error { log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size) log.Debug("starting download") @@ -170,7 +183,7 @@ func (c *Client) Pull(ctx context.Context, cache Cache, name string) error { } tmpFile.Close() // release our hold on the file before moving it - return cache.PutLayerFile(l.Digest, tmpFile.Name()) + return cache.PutLayerFile(d, tmpFile.Name()) }() if err != nil { return fmt.Errorf("ollama: pull: %w", err) @@ -187,7 +200,7 @@ func (c *Client) Pull(ctx context.Context, cache Cache, name string) error { } // TODO(bmizerany): remove dep on model.Name - return cache.SetManifestData(name, data) + return cache.SetManifestData(mn, data) } type nopSeeker struct { @@ -205,7 +218,11 @@ func (nopSeeker) Seek(int64, int) (int64, error) { // If the server requests layers not found in the cache, ErrLayerNotFound is // returned. func (c *Client) Push(ctx context.Context, cache Cache, name string) error { - manifest := cache.ManifestData(name) + mn := model.ParseName(name) + if !mn.IsFullyQualified() { + return fmt.Errorf("ollama: push: invalid name: %s", name) + } + manifest := cache.ManifestData(mn) if len(manifest) == 0 { return fmt.Errorf("manifest not found: %s", name) } @@ -232,7 +249,11 @@ func (c *Client) Push(ctx context.Context, cache Cache, name string) error { var g errgroup.Group for _, need := range pr.Needs { g.Go(func() error { - f, err := cache.OpenLayer(need.Digest) + nd, err := model.ParseDigest(need.Digest) + if err != nil { + return fmt.Errorf("ParseDigest: %w: %s", err, need.Digest) + } + f, err := cache.OpenLayer(nd) if err != nil { return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest) } @@ -266,7 +287,7 @@ func (c *Client) Push(ctx context.Context, cache Cache, name string) error { } } - return cache.SetManifestData(name, manifest) + return cache.SetManifestData(mn, manifest) } func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) {