This commit is contained in:
Blake Mizerany 2024-04-30 16:53:47 -07:00
parent 7ba71c3989
commit 8afe873f17

View File

@ -17,6 +17,7 @@ import (
"github.com/ollama/ollama/client/ollama" "github.com/ollama/ollama/client/ollama"
"github.com/ollama/ollama/client/registry/apitype" "github.com/ollama/ollama/client/registry/apitype"
"github.com/ollama/ollama/types/model"
"golang.org/x/exp/constraints" "golang.org/x/exp/constraints"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@ -54,20 +55,23 @@ type Cache interface {
// //
// If the digest is invalid, or the layer does not exist, the empty // If the digest is invalid, or the layer does not exist, the empty
// string is returned. // string is returned.
LayerFile(digest string) string LayerFile(model.Digest) string
// OpenLayer opens the layer file for the given model digest and // OpenLayer opens the layer file for the given model digest and
// returns it, or an error if any. The caller is responsible for closing // returns it, or an error if any. The caller is responsible for closing
// the returned file. // the returned file.
OpenLayer(digest string) (ReadAtSeekCloser, error) OpenLayer(model.Digest) (ReadAtSeekCloser, error)
// PutLayerFile moves the layer file at fromPath to the cache for // PutLayerFile moves the layer file at fromPath to the cache for
// the given model digest. It is a hack intended to short circuit a // the given model digest. It is a hack intended to short circuit a
// file copy operation. // file copy operation.
// //
// The file returned is expected to exist for the lifetime of the
// cache.
//
// TODO(bmizerany): remove this; find a better way. Once we move // TODO(bmizerany): remove this; find a better way. Once we move
// this into a build package, we should be able to get rid of this. // this into a build package, we should be able to get rid of this.
PutLayerFile(digest, fromPath string) error PutLayerFile(_ model.Digest, fromPath string) error
// SetManifestData sets the provided manifest data for the given // SetManifestData sets the provided manifest data for the given
// model name. If the manifest data is empty, the manifest is // model name. If the manifest data is empty, the manifest is
@ -75,19 +79,24 @@ type Cache interface {
// //
// It is an error to call SetManifestData with a name that is not // It is an error to call SetManifestData with a name that is not
// complete. // complete.
SetManifestData(name string, data []byte) error SetManifestData(model.Name, []byte) error
// ManifestData returns the manifest data for the given model name. // ManifestData returns the manifest data for the given model name.
// //
// If the name is incomplete, or the manifest does not exist, the empty // If the name is incomplete, or the manifest does not exist, the empty
// string is returned. // string is returned.
ManifestData(name string) []byte ManifestData(name model.Name) []byte
} }
// Pull pulls the manifest for name, and downloads any of its required // Pull pulls the manifest for name, and downloads any of its required
// layers that are not already in the cache. It returns an error if any part // layers that are not already in the cache. It returns an error if any part
// of the process fails, specifically: // of the process fails, specifically:
func (c *Client) Pull(ctx context.Context, cache Cache, name string) error { func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
mn := model.ParseName(name)
if !mn.IsFullyQualified() {
return fmt.Errorf("ollama: pull: invalid name: %s", name)
}
log := c.logger().With("name", name) log := c.logger().With("name", name)
pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil) pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
@ -101,10 +110,14 @@ func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
// download required layers we do not already have // download required layers we do not already have
for _, l := range pr.Manifest.Layers { for _, l := range pr.Manifest.Layers {
if cache.LayerFile(l.Digest) != "" { d, err := model.ParseDigest(l.Digest)
if err != nil {
return fmt.Errorf("ollama: reading manifest: %w: %s", err, l.Digest)
}
if cache.LayerFile(d) != "" {
continue continue
} }
err := func() error { err = func() error {
log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size) log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size)
log.Debug("starting download") log.Debug("starting download")
@ -170,7 +183,7 @@ func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
} }
tmpFile.Close() // release our hold on the file before moving it tmpFile.Close() // release our hold on the file before moving it
return cache.PutLayerFile(l.Digest, tmpFile.Name()) return cache.PutLayerFile(d, tmpFile.Name())
}() }()
if err != nil { if err != nil {
return fmt.Errorf("ollama: pull: %w", err) return fmt.Errorf("ollama: pull: %w", err)
@ -187,7 +200,7 @@ func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
} }
// TODO(bmizerany): remove dep on model.Name // TODO(bmizerany): remove dep on model.Name
return cache.SetManifestData(name, data) return cache.SetManifestData(mn, data)
} }
type nopSeeker struct { type nopSeeker struct {
@ -205,7 +218,11 @@ func (nopSeeker) Seek(int64, int) (int64, error) {
// If the server requests layers not found in the cache, ErrLayerNotFound is // If the server requests layers not found in the cache, ErrLayerNotFound is
// returned. // returned.
func (c *Client) Push(ctx context.Context, cache Cache, name string) error { func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
manifest := cache.ManifestData(name) mn := model.ParseName(name)
if !mn.IsFullyQualified() {
return fmt.Errorf("ollama: push: invalid name: %s", name)
}
manifest := cache.ManifestData(mn)
if len(manifest) == 0 { if len(manifest) == 0 {
return fmt.Errorf("manifest not found: %s", name) return fmt.Errorf("manifest not found: %s", name)
} }
@ -232,7 +249,11 @@ func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
var g errgroup.Group var g errgroup.Group
for _, need := range pr.Needs { for _, need := range pr.Needs {
g.Go(func() error { g.Go(func() error {
f, err := cache.OpenLayer(need.Digest) nd, err := model.ParseDigest(need.Digest)
if err != nil {
return fmt.Errorf("ParseDigest: %w: %s", err, need.Digest)
}
f, err := cache.OpenLayer(nd)
if err != nil { if err != nil {
return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest) return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest)
} }
@ -266,7 +287,7 @@ func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
} }
} }
return cache.SetManifestData(name, manifest) return cache.SetManifestData(mn, manifest)
} }
func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) { func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) {