diff --git a/client/registry/registry.go b/client/registry/registry.go index c542a840..957583f8 100644 --- a/client/registry/registry.go +++ b/client/registry/registry.go @@ -31,9 +31,19 @@ type Client struct { BaseURL string Logger *slog.Logger + + // NameFill is a string that is used to fill in the missing parts of + // a name when it is not fully qualified. It is used to make a name + // fully qualified before pushing or pulling it. The default is + // "registry.ollama.ai/library/_:latest". + // + // Most users can ignore this field. It is intended for use by + // clients that need to push or pull names to registries other than + // registry.ollama.ai, and for testing. + NameFill string } -func (c *Client) logger() *slog.Logger { +func (c *Client) log() *slog.Logger { return cmp.Or(c.Logger, slog.Default()) } @@ -92,12 +102,12 @@ type Cache interface { // layers that are not already in the cache. It returns an error if any part // of the process fails, specifically: func (c *Client) Pull(ctx context.Context, cache Cache, name string) error { - mn := model.ParseName(name) + mn := parseNameFill(name, c.NameFill) if !mn.IsFullyQualified() { return fmt.Errorf("ollama: pull: invalid name: %s", name) } - log := c.logger().With("name", name) + log := c.log().With("name", name) pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil) if err != nil { @@ -211,6 +221,14 @@ func (nopSeeker) Seek(int64, int) (int64, error) { return 0, nil } +func parseNameFill(name, fill string) model.Name { + f := model.ParseNameBare(fill) + if !f.IsFullyQualified() { + panic(fmt.Errorf("invalid fill: %q", fill)) + } + return model.Merge(model.ParseNameBare(name), f) +} + // Push pushes a manifest to the server and responds to the server's // requests for layer uploads, if any, and finally commits the manifest for // name. 
// It returns an error if any part of the process fails, specifically: @@ -218,7 +236,7 @@ func (nopSeeker) Seek(int64, int) (int64, error) { // If the server requests layers not found in the cache, ErrLayerNotFound is // returned. func (c *Client) Push(ctx context.Context, cache Cache, name string) error { - mn := model.ParseName(name) + mn := parseNameFill(name, c.NameFill) if !mn.IsFullyQualified() { return fmt.Errorf("ollama: push: invalid name: %s", name) } @@ -259,6 +277,7 @@ func (c *Client) Push(ctx context.Context, cache Cache, name string) error { } defer f.Close() + c.log().Info("pushing layer", "digest", need.Digest, "start", need.Start, "end", need.End) cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End) if err != nil { return fmt.Errorf("PushLayer: %w: %s", err, need.Digest) diff --git a/server/cache.go b/server/cache.go new file mode 100644 index 00000000..8838d18e --- /dev/null +++ b/server/cache.go @@ -0,0 +1,75 @@ +package server + +import ( + "cmp" + "fmt" + "os" + "path/filepath" + + "github.com/ollama/ollama/client/registry" + "github.com/ollama/ollama/types/model" +) + +// cache is a simple demo disk cache. 
+// It does not validate anything +type cache struct { + dir string +} + +func defaultCache() registry.Cache { + homeDir, _ := os.UserHomeDir() + if homeDir == "" { + panic("could not determine home directory") + } + modelsDir := cmp.Or( + os.Getenv("OLLAMA_MODELS"), + filepath.Join(homeDir, ".ollama", "models"), + ) + return &cache{modelsDir} +} + +func invalidDigest(digest string) error { + return fmt.Errorf("invalid digest: %s", digest) +} + +func (c *cache) OpenLayer(d model.Digest) (registry.ReadAtSeekCloser, error) { + return os.Open(c.LayerFile(d)) +} + +func (c *cache) LayerFile(d model.Digest) string { + return filepath.Join(c.dir, "blobs", d.String()) +} + +func (c *cache) PutLayerFile(d model.Digest, fromPath string) error { + if !d.IsValid() { + return invalidDigest(d.String()) + } + bfile := c.LayerFile(d) + dir, _ := filepath.Split(bfile) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + return os.Rename(fromPath, bfile) +} + +func (c *cache) ManifestData(name model.Name) []byte { + if !name.IsFullyQualified() { + return nil + } + data, err := os.ReadFile(filepath.Join(c.dir, "manifests", name.Filepath())) + if err != nil { + return nil + } + return data +} + +func (c *cache) SetManifestData(name model.Name, data []byte) error { + if !name.IsFullyQualified() { + return fmt.Errorf("invalid name: %s", name) + } + filep := filepath.Join(c.dir, "manifests", name.Filepath()) + dir, _ := filepath.Split(filep) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + return os.WriteFile(filep, data, 0644) +} diff --git a/server/routes.go b/server/routes.go index b1962d23..db0a9e93 100644 --- a/server/routes.go +++ b/server/routes.go @@ -17,6 +17,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "syscall" "time" @@ -25,6 +26,7 @@ import ( "golang.org/x/exp/slices" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/client/registry" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" 
"github.com/ollama/ollama/openai" @@ -33,6 +35,14 @@ import ( "github.com/ollama/ollama/version" ) +var experiments = sync.OnceValue(func() []string { + return strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ",") +}) + +func useExperimental(flag string) bool { + return slices.Contains(experiments(), flag) +} + var mode string = gin.DebugMode type Server struct { @@ -444,6 +454,24 @@ func (s *Server) PullModelHandler(c *gin.Context) { return } + if useExperimental("pull") { + rc := &registry.Client{ + BaseURL: os.Getenv("OLLAMA_REGISTRY_BASE_URL"), + } + modelsDir, err := modelsDir() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + cache := &cache{dir: modelsDir} + // TODO(bmizerany): progress updates + if err := rc.Pull(c.Request.Context(), cache, model); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + return + } + ch := make(chan any) go func() { defer close(ch)