weaving in

This commit is contained in:
Blake Mizerany 2024-05-01 15:32:24 -07:00
parent 8afe873f17
commit 844217bcf1
3 changed files with 126 additions and 4 deletions

View File

@ -31,9 +31,19 @@ type Client struct {
BaseURL string
Logger *slog.Logger
// NameFill is a string that is used to fill in the missing parts of
// a name when it is not fully qualified. It is used to make a name
// fully qualified before pushing or pulling it. The default is
// "registry.ollama.ai/library/_:latest".
//
// Most users can ignore this field. It is intended for use by
// clients that need to push or pull names to registries other than
// registry.ollama.ai, and for testing.
NameFill string
}
func (c *Client) logger() *slog.Logger {
// log returns the logger to use for this client: the configured Logger
// when one is set, otherwise the process-wide slog default.
func (c *Client) log() *slog.Logger {
	if l := c.Logger; l != nil {
		return l
	}
	return slog.Default()
}
@ -92,12 +102,12 @@ type Cache interface {
// layers that are not already in the cache. It returns an error if any part
// of the process fails, specifically:
func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
mn := model.ParseName(name)
mn := parseNameFill(name, c.NameFill)
if !mn.IsFullyQualified() {
return fmt.Errorf("ollama: pull: invalid name: %s", name)
}
log := c.logger().With("name", name)
log := c.log().With("name", name)
pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
if err != nil {
@ -211,6 +221,14 @@ func (nopSeeker) Seek(int64, int) (int64, error) {
return 0, nil
}
// parseNameFill parses name and fills any parts it is missing from fill,
// returning the merged name. Callers are expected to check the result with
// IsFullyQualified before using it.
//
// An empty fill selects the default documented on Client.NameFill, so a
// zero-value Client works without further configuration.
//
// It panics if a non-empty fill is not itself fully qualified; that is a
// configuration error rather than a problem with the name being parsed.
func parseNameFill(name, fill string) model.Name {
	// Default documented on Client.NameFill.
	const defaultFill = "registry.ollama.ai/library/_:latest"

	f := model.ParseNameBare(cmp.Or(fill, defaultFill))
	if !f.IsFullyQualified() {
		panic(fmt.Errorf("invalid fill: %q", fill))
	}
	return model.Merge(model.ParseNameBare(name), f)
}
// Push pushes a manifest to the server and responds to the server's
// requests for layer uploads, if any, and finally commits the manifest for
// name. It returns an error if any part of the process fails, specifically:
@ -218,7 +236,7 @@ func (nopSeeker) Seek(int64, int) (int64, error) {
// If the server requests layers not found in the cache, ErrLayerNotFound is
// returned.
func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
mn := model.ParseName(name)
mn := parseNameFill(name, c.NameFill)
if !mn.IsFullyQualified() {
return fmt.Errorf("ollama: push: invalid name: %s", name)
}
@ -259,6 +277,7 @@ func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
}
defer f.Close()
c.log().Info("pushing layer", "digest", need.Digest, "start", need.Start, "end", need.End)
cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End)
if err != nil {
return fmt.Errorf("PushLayer: %w: %s", err, need.Digest)

75
server/cache.go Normal file
View File

@ -0,0 +1,75 @@
package server
import (
"cmp"
"fmt"
"os"
"path/filepath"
"github.com/ollama/ollama/client/registry"
"github.com/ollama/ollama/types/model"
)
// cache is a simple demo disk cache rooted at dir. It does not validate
// the content it stores or returns.
type cache struct {
// dir is the root directory; layers are kept under dir/blobs and
// manifests under dir/manifests.
dir string
}
// defaultCache returns a cache rooted at the directory named by the
// OLLAMA_MODELS environment variable or, when that is unset or empty, at
// ~/.ollama/models.
//
// It panics if OLLAMA_MODELS is not set and the user's home directory
// cannot be determined.
func defaultCache() registry.Cache {
	// An explicit OLLAMA_MODELS wins, so the home directory is only looked
	// up (and only allowed to fail) when it is actually needed.
	if dir := os.Getenv("OLLAMA_MODELS"); dir != "" {
		return &cache{dir: dir}
	}

	homeDir, err := os.UserHomeDir()
	if err != nil {
		panic(fmt.Sprintf("could not determine home directory: %v", err))
	}
	if homeDir == "" {
		panic("could not determine home directory")
	}
	return &cache{dir: filepath.Join(homeDir, ".ollama", "models")}
}
// invalidDigest builds the error reported for a digest that fails
// validation. The offending digest is included verbatim in the message.
func invalidDigest(digest string) error {
	const format = "invalid digest: %s"
	return fmt.Errorf(format, digest)
}
// OpenLayer opens the cached layer file for digest d for reading.
//
// It returns an error if d is not a valid digest or if the file cannot be
// opened. On error the returned ReadAtSeekCloser is a nil interface.
func (c *cache) OpenLayer(d model.Digest) (registry.ReadAtSeekCloser, error) {
	// Validate first, as PutLayerFile does, so an invalid digest is never
	// turned into a path.
	if !d.IsValid() {
		return nil, invalidDigest(d.String())
	}
	f, err := os.Open(c.LayerFile(d))
	if err != nil {
		// Return a literal nil: passing the nil *os.File through would
		// produce a non-nil interface holding a nil pointer.
		return nil, err
	}
	return f, nil
}
// LayerFile returns the path of the file that holds the layer with
// digest d: the digest's string form inside the "blobs" directory under
// the cache root.
func (c *cache) LayerFile(d model.Digest) string {
	blobName := d.String()
	return filepath.Join(c.dir, "blobs", blobName)
}
// PutLayerFile moves the file at fromPath into the cache as the layer for
// digest d, creating the destination directory if needed.
//
// It returns an error if d is not valid, if the directory cannot be
// created, or if the rename fails.
func (c *cache) PutLayerFile(d model.Digest, fromPath string) error {
	if valid := d.IsValid(); !valid {
		return invalidDigest(d.String())
	}

	dst := c.LayerFile(d)
	parent, _ := filepath.Split(dst)
	mkErr := os.MkdirAll(parent, 0755)
	if mkErr != nil {
		return mkErr
	}

	return os.Rename(fromPath, dst)
}
// ManifestData returns the stored manifest bytes for name. It returns nil
// if name is not fully qualified or if the manifest file cannot be read.
func (c *cache) ManifestData(name model.Name) []byte {
	if name.IsFullyQualified() {
		path := filepath.Join(c.dir, "manifests", name.Filepath())
		if data, err := os.ReadFile(path); err == nil {
			return data
		}
	}
	return nil
}
// SetManifestData writes data as the manifest for name, creating any
// missing parent directories under the cache's "manifests" directory.
//
// It returns an error if name is not fully qualified, or if creating the
// directory or writing the file fails.
func (c *cache) SetManifestData(name model.Name, data []byte) error {
	if qualified := name.IsFullyQualified(); !qualified {
		return fmt.Errorf("invalid name: %s", name)
	}

	target := filepath.Join(c.dir, "manifests", name.Filepath())
	parent, _ := filepath.Split(target)
	if mkErr := os.MkdirAll(parent, 0755); mkErr != nil {
		return mkErr
	}

	return os.WriteFile(target, data, 0644)
}

View File

@ -17,6 +17,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"time"
@ -25,6 +26,7 @@ import (
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/client/registry"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
@ -33,6 +35,14 @@ import (
"github.com/ollama/ollama/version"
)
// experiments returns the experiment flags listed in the OLLAMA_EXPERIMENT
// environment variable, a comma-separated list. The variable is read once,
// on first call; later calls return the same result.
//
// Whitespace around each entry is ignored and empty entries are dropped,
// so an unset variable yields no flags and "pull, push" enables both.
var experiments = sync.OnceValue(func() []string {
	var enabled []string
	for _, flag := range strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ",") {
		if flag = strings.TrimSpace(flag); flag != "" {
			enabled = append(enabled, flag)
		}
	}
	return enabled
})
func useExperiemntal(flag string) bool {
return slices.Contains(experiments(), flag)
}
// mode is the gin mode for the server; it defaults to gin.DebugMode.
// NOTE(review): presumably passed to gin.SetMode elsewhere — confirm.
var mode string = gin.DebugMode
type Server struct {
@ -444,6 +454,24 @@ func (s *Server) PullModelHandler(c *gin.Context) {
return
}
if useExperiemntal("pull") {
rc := &registry.Client{
BaseURL: os.Getenv("OLLAMA_REGISTRY_BASE_URL"),
}
modelsDir, err := modelsDir()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
cache := &cache{dir: modelsDir}
// TODO(bmizerany): progress updates
if err := rc.Pull(c.Request.Context(), cache, model); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
return
}
ch := make(chan any)
go func() {
defer close(ch)