From 980070dce6e6d98aa883f8d519e76ce602066897 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 8 May 2024 18:02:07 -0700 Subject: [PATCH] remove last bits of ParseModelPath --- server/images.go | 94 ++------------------- server/layer.go | 20 +++++ server/manifest.go | 14 ++++ server/modelpath.go | 120 --------------------------- server/modelpath_test.go | 85 ------------------- server/routes.go | 174 ++++++++++++++++++--------------------- server/routes_test.go | 8 +- server/sched.go | 2 +- server/sched_test.go | 29 +++---- server/upload.go | 4 +- 10 files changed, 146 insertions(+), 404 deletions(-) diff --git a/server/images.go b/server/images.go index 68d23ab2..c3529e9a 100644 --- a/server/images.go +++ b/server/images.go @@ -4,14 +4,11 @@ import ( "bytes" "cmp" "context" - "crypto/sha256" "encoding/base64" - "encoding/hex" "encoding/json" "errors" "fmt" "io" - "log" "log/slog" "net/http" "net/url" @@ -42,9 +39,8 @@ type registryOptions struct { } type Model struct { - Name string `json:"name"` + Name model.Name Config ConfigV2 - ShortName string ModelPath string ParentModel string AdapterPaths []string @@ -161,46 +157,17 @@ type RootFS struct { DiffIDs []string `json:"diff_ids"` } -func GetManifest(mp ModelPath) (*ManifestV2, string, error) { - fp, err := mp.GetManifestPath() - if err != nil { - return nil, "", err - } - - if _, err = os.Stat(fp); err != nil { - return nil, "", err - } - - var manifest *ManifestV2 - - bts, err := os.ReadFile(fp) - if err != nil { - return nil, "", fmt.Errorf("couldn't open file '%s'", fp) - } - - shaSum := sha256.Sum256(bts) - shaStr := hex.EncodeToString(shaSum[:]) - - if err := json.Unmarshal(bts, &manifest); err != nil { - return nil, "", err - } - - return manifest, shaStr, nil -} - -func GetModel(name string) (*Model, error) { - mp := ParseModelPath(name) - manifest, digest, err := GetManifest(mp) +func GetModel(name model.Name) (*Model, error) { + manifest, err := ParseNamedManifest(name) if err != nil { return nil, err } model := &Model{ - Name: mp.GetFullTagname(), - ShortName: mp.GetShortTagname(), - Digest: digest, - Template: "{{ .Prompt }}", - License: []string{}, + Name: name, + Digest: manifest.digest, + Template: "{{ .Prompt }}", + License: []string{}, } filename, err := GetBlobsPath(manifest.Config.Digest) @@ -688,18 +655,8 @@ func PullModel(ctx context.Context, name model.Name, opts registryOptions, fn fu fn(api.ProgressResponse{Status: "verifying sha256 digest"}) for _, layer := range layers { - if err := verifyBlob(layer.Digest); err != nil { - if errors.Is(err, errDigestMismatch) { - // something went wrong, delete the blob - fp, err := GetBlobsPath(layer.Digest) - if err != nil { - return err - } - if err := os.Remove(fp); err != nil { - // log this, but return the original error - slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err)) - } - } + if err := layer.Verify(); err != nil { + _ = layer.Remove() return err } } @@ -737,17 +694,6 @@ func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, o return m, err } -// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer -func GetSHA256Digest(r io.Reader) (string, int64) { - h := sha256.New() - n, err := io.Copy(h, r) - if err != nil { - log.Fatal(err) - } - - return fmt.Sprintf("sha256:%x", h.Sum(nil)), n -} - var errUnauthorized = fmt.Errorf("unauthorized: access denied") // getTokenSubject returns the subject of a JWT token, it does not validate the token @@ -907,25 +853,3 @@ func parseRegistryChallenge(authStr string) registryChallenge { Scope: getValue(authStr, "scope"), } } - -var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again") - -func verifyBlob(digest string) error { - fp, err := GetBlobsPath(digest) - if err != nil { - return err - } - - f, err := os.Open(fp) - if err != nil { - return err - } - defer f.Close() - - fileDigest, _ := GetSHA256Digest(f) - if digest != fileDigest { - return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest) - } - - return nil -} diff --git a/server/layer.go b/server/layer.go index ad8c2069..70c47709 100644 --- a/server/layer.go +++ b/server/layer.go @@ -115,6 +115,26 @@ func (l *Layer) Remove() error { return os.Remove(filepath.Join(p, l.Digest)) } +func (l *Layer) Verify() error { + rc, err := l.Open() + if err != nil { + return err + } + defer rc.Close() + + sha256sum := sha256.New() + if _, err := io.Copy(sha256sum, rc); err != nil { + return err + } + + digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)) + if digest != l.Digest { + return fmt.Errorf("digest mismatch: %s != %s", digest, l.Digest) + } + + return nil +} + func Layers() (map[string]*Layer, error) { blobs, err := GetBlobsPath("") if err != nil { diff --git a/server/manifest.go b/server/manifest.go index 2856db9b..b06a51d5 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -194,3 +194,17 @@ func Manifests() (map[model.Name]*Manifest, error) { return ms, nil } + +func GetManifestPath() (string, error) { + dir, err := modelsDir() + if err != nil { + return "", err + } + + path := filepath.Join(dir, "manifests") + if err := os.MkdirAll(path, 0o755); err != nil { + return "", err + } + + return path, nil +} diff --git a/server/modelpath.go b/server/modelpath.go index 25a817ca..135dc240 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -2,105 +2,16 @@ package server import ( "errors" - "fmt" - "net/url" "os" "path/filepath" "regexp" "strings" ) -type ModelPath struct { - ProtocolScheme string - Registry string - Namespace string - Repository string - Tag string -} - -const ( - DefaultRegistry = "registry.ollama.ai" - DefaultNamespace = "library" - DefaultTag = "latest" - DefaultProtocolScheme = "https" -) - var ( - ErrInvalidImageFormat = errors.New("invalid image format") - ErrInvalidProtocol = errors.New("invalid protocol scheme") - ErrInsecureProtocol = errors.New("insecure protocol http") ErrInvalidDigestFormat = errors.New("invalid digest format") ) -func ParseModelPath(name string) ModelPath { - mp := ModelPath{ - ProtocolScheme: DefaultProtocolScheme, - Registry: DefaultRegistry, - Namespace: DefaultNamespace, - Repository: "", - Tag: DefaultTag, - } - - before, after, found := strings.Cut(name, "://") - if found { - mp.ProtocolScheme = before - name = after - } - - name = strings.ReplaceAll(name, string(os.PathSeparator), "/") - parts := strings.Split(name, "/") - switch len(parts) { - case 3: - mp.Registry = parts[0] - mp.Namespace = parts[1] - mp.Repository = parts[2] - case 2: - mp.Namespace = parts[0] - mp.Repository = parts[1] - case 1: - mp.Repository = parts[0] - } - - if repo, tag, found := strings.Cut(mp.Repository, ":"); found { - mp.Repository = repo - mp.Tag = tag - } - - return mp -} - -var errModelPathInvalid = errors.New("invalid model path") - -func (mp ModelPath) Validate() error { - if mp.Repository == "" { - return fmt.Errorf("%w: model repository name is required", errModelPathInvalid) - } - - if strings.Contains(mp.Tag, ":") { - return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid) - } - - return nil -} - -func (mp ModelPath) GetNamespaceRepository() string { - return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository) -} - -func (mp ModelPath) GetFullTagname() string { - return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag) -} - -func (mp ModelPath) GetShortTagname() string { - if mp.Registry == DefaultRegistry { - if mp.Namespace == DefaultNamespace { - return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag) - } - return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag) - } - return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag) -} - // modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set. // The models directory is where Ollama stores its model files and manifests. func modelsDir() (string, error) { @@ -114,37 +25,6 @@ func modelsDir() (string, error) { return filepath.Join(home, ".ollama", "models"), nil } -// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist. -func (mp ModelPath) GetManifestPath() (string, error) { - dir, err := modelsDir() - if err != nil { - return "", err - } - - return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil -} - -func (mp ModelPath) BaseURL() *url.URL { - return &url.URL{ - Scheme: mp.ProtocolScheme, - Host: mp.Registry, - } -} - -func GetManifestPath() (string, error) { - dir, err := modelsDir() - if err != nil { - return "", err - } - - path := filepath.Join(dir, "manifests") - if err := os.MkdirAll(path, 0o755); err != nil { - return "", err - } - - return path, nil -} - func GetBlobsPath(digest string) (string, error) { dir, err := modelsDir() if err != nil { diff --git a/server/modelpath_test.go b/server/modelpath_test.go index 30741d87..8445a6d1 100644 --- a/server/modelpath_test.go +++ b/server/modelpath_test.go @@ -68,88 +68,3 @@ func TestGetBlobsPath(t *testing.T) { }) } } - -func TestParseModelPath(t *testing.T) { - tests := []struct { - name string - arg string - want ModelPath - }{ - { - "full path https", - "https://example.com/ns/repo:tag", - ModelPath{ - ProtocolScheme: "https", - Registry: "example.com", - Namespace: "ns", - Repository: "repo", - Tag: "tag", - }, - }, - { - "full path http", - "http://example.com/ns/repo:tag", - ModelPath{ - ProtocolScheme: "http", - Registry: "example.com", - Namespace: "ns", - Repository: "repo", - Tag: "tag", - }, - }, - { - "no protocol", - "example.com/ns/repo:tag", - ModelPath{ - ProtocolScheme: "https", - Registry: "example.com", - Namespace: "ns", - Repository: "repo", - Tag: "tag", - }, - }, - { - "no registry", - "ns/repo:tag", - ModelPath{ - ProtocolScheme: "https", - Registry: DefaultRegistry, - Namespace: "ns", - Repository: "repo", - Tag: "tag", - }, - }, - { - "no namespace", - "repo:tag", - ModelPath{ - ProtocolScheme: "https", - Registry: DefaultRegistry, - Namespace: DefaultNamespace, - Repository: "repo", - Tag: "tag", - }, - }, - { - "no tag", - "repo", - ModelPath{ - ProtocolScheme: "https", - Registry: DefaultRegistry, - Namespace: DefaultNamespace, - Repository: "repo", - Tag: DefaultTag, - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - got := ParseModelPath(tc.arg) - - if got != tc.want { - t.Errorf("got: %q want: %q", got, tc.want) - } - }) - } -} diff --git a/server/routes.go b/server/routes.go index ea1067f2..bb2ca681 100644 --- a/server/routes.go +++ b/server/routes.go @@ -75,45 +75,43 @@ func isSupportedImageType(image []byte) bool { } func (s *Server) GenerateHandler(c *gin.Context) { - checkpointStart := time.Now() - var req api.GenerateRequest - err := c.ShouldBindJSON(&req) - - switch { - case errors.Is(err, io.EOF): + var r api.GenerateRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + n := model.ParseName(r.Model) + // validate the request switch { - case req.Model == "": + case !n.IsValid(): c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) return - case len(req.Format) > 0 && req.Format != "json": + case len(r.Format) > 0 && r.Format != "json": c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"}) return - case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0): + case r.Raw && (r.Template != "" || r.System != "" || len(r.Context) > 0): c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"}) return } - for _, img := range req.Images { + for _, img := range r.Images { if !isSupportedImageType(img) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) return } } - model, err := GetModel(req.Model) + model, err := GetModel(n) if err != nil { var pErr *fs.PathError if errors.As(err, &pErr) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", r.Model)}) return } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -125,17 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - opts, err := modelOptions(model, req.Options) + opts, err := modelOptions(model, r.Options) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } var sessionDuration time.Duration - if req.KeepAlive == nil { + if r.KeepAlive == nil { sessionDuration = getDefaultSessionDuration() } else { - sessionDuration = req.KeepAlive.Duration + sessionDuration = r.KeepAlive.Duration } rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) @@ -150,10 +148,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { // an empty request loads the model // note: for a short while template was used in lieu // of `raw` mode so we need to check for it too - if req.Prompt == "" && req.Template == "" && req.System == "" { + if r.Prompt == "" && r.Template == "" && r.System == "" { c.JSON(http.StatusOK, api.GenerateResponse{ CreatedAt: time.Now().UTC(), - Model: req.Model, + Model: r.Model, Done: true, DoneReason: "load", }) @@ -164,37 +162,37 @@ func (s *Server) GenerateHandler(c *gin.Context) { var prompt string switch { - case req.Raw: - prompt = req.Prompt - case req.Prompt != "": - if req.Template == "" { - req.Template = model.Template + case r.Raw: + prompt = r.Prompt + case r.Prompt != "": + if r.Template == "" { + r.Template = model.Template } - if req.System == "" { - req.System = model.System + if r.System == "" { + r.System = model.System } - slog.Debug("generate handler", "prompt", req.Prompt) - slog.Debug("generate handler", "template", req.Template) - slog.Debug("generate handler", "system", req.System) + slog.Debug("generate handler", "prompt", r.Prompt) + slog.Debug("generate handler", "template", r.Template) + slog.Debug("generate handler", "system", r.System) var sb strings.Builder - for i := range req.Images { + for i := range r.Images { fmt.Fprintf(&sb, "[img-%d] ", i) } - sb.WriteString(req.Prompt) + sb.WriteString(r.Prompt) - p, err := Prompt(req.Template, req.System, sb.String(), "", true) + p, err := Prompt(r.Template, r.System, sb.String(), "", true) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } sb.Reset() - if req.Context != nil { - prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context) + if r.Context != nil { + prev, err := runner.llama.Detokenize(c.Request.Context(), r.Context) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -215,33 +213,33 @@ func (s *Server) GenerateHandler(c *gin.Context) { go func() { defer close(ch) - fn := func(r llm.CompletionResponse) { + fn := func(comp llm.CompletionResponse) { // Build up the full response - if _, err := generated.WriteString(r.Content); err != nil { + if _, err := generated.WriteString(comp.Content); err != nil { ch <- gin.H{"error": err.Error()} return } resp := api.GenerateResponse{ - Model: req.Model, + Model: r.Model, CreatedAt: time.Now().UTC(), - Done: r.Done, - Response: r.Content, - DoneReason: r.DoneReason, + Done: comp.Done, + DoneReason: comp.DoneReason, + Response: comp.Content, Metrics: api.Metrics{ - PromptEvalCount: r.PromptEvalCount, - PromptEvalDuration: r.PromptEvalDuration, - EvalCount: r.EvalCount, - EvalDuration: r.EvalDuration, + PromptEvalCount: comp.PromptEvalCount, + PromptEvalDuration: comp.PromptEvalDuration, + EvalCount: comp.EvalCount, + EvalDuration: comp.EvalDuration, }, } - if r.Done { + if comp.Done { resp.TotalDuration = time.Since(checkpointStart) resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) - if !req.Raw { - p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false) + if !r.Raw { + p, err := Prompt(r.Template, r.System, r.Prompt, generated.String(), false) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -254,7 +252,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - resp.Context = append(req.Context, tokens...) + resp.Context = append(r.Context, tokens...) } } @@ -262,17 +260,17 @@ func (s *Server) GenerateHandler(c *gin.Context) { } var images []llm.ImageData - for i := range req.Images { + for i := range r.Images { images = append(images, llm.ImageData{ ID: i, - Data: req.Images[i], + Data: r.Images[i], }) } // Start prediction req := llm.CompletionRequest{ Prompt: prompt, - Format: req.Format, + Format: r.Format, Images: images, Options: opts, } @@ -281,7 +279,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { } }() - if req.Stream != nil && !*req.Stream { + if r.Stream != nil && !*r.Stream { // Accumulate responses into the final response var final api.GenerateResponse var sb strings.Builder @@ -339,44 +337,43 @@ func getDefaultSessionDuration() time.Duration { } func (s *Server) EmbeddingsHandler(c *gin.Context) { - var req api.EmbeddingRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + var r api.EmbeddingRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - if req.Model == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + n := model.ParseName(r.Model) + if !n.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", r.Model)}) return } - model, err := GetModel(req.Model) + model, err := GetModel(n) if err != nil { var pErr *fs.PathError if errors.As(err, &pErr) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", r.Model)}) return } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - opts, err := modelOptions(model, req.Options) + opts, err := modelOptions(model, r.Options) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } var sessionDuration time.Duration - if req.KeepAlive == nil { + if r.KeepAlive == nil { sessionDuration = getDefaultSessionDuration() } else { - sessionDuration = req.KeepAlive.Duration + sessionDuration = r.KeepAlive.Duration } rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) @@ -389,12 +386,12 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { } // an empty request loads the model - if req.Prompt == "" { + if r.Prompt == "" { c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}}) return } - embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt) + embedding, err := runner.llama.Embedding(c.Request.Context(), r.Prompt) if err != nil { slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) @@ -497,7 +494,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) { name := model.ParseName(cmp.Or(r.Model, r.Name)) if !name.IsValid() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"}) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) return } @@ -582,30 +579,25 @@ func (s *Server) DeleteModelHandler(c *gin.Context) { } func (s *Server) ShowModelHandler(c *gin.Context) { - var req api.ShowRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + var r api.ShowRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - if req.Model != "" { - // noop - } else if req.Name != "" { - req.Model = req.Name - } else { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + n := model.ParseName(cmp.Or(r.Model, r.Name)) + if !n.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) return } - resp, err := GetModelInfo(req) + resp, err := GetModelInfo(n, r) if err != nil { if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", r.Model)}) } else { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } @@ -615,8 +607,8 @@ func (s *Server) ShowModelHandler(c *gin.Context) { c.JSON(http.StatusOK, resp) } -func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { - model, err := GetModel(req.Model) +func GetModelInfo(name model.Name, req api.ShowRequest) (*api.ShowResponse, error) { + model, err := GetModel(name) if err != nil { return nil, err } @@ -674,7 +666,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { var sb strings.Builder fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"") fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:") - fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName) + fmt.Fprintf(&sb, "# FROM %s\n\n", name.DisplayShortest()) fmt.Fprint(&sb, model.String()) resp.Modelfile = sb.String() @@ -1087,27 +1079,21 @@ func (s *Server) ChatHandler(c *gin.Context) { checkpointStart := time.Now() var req api.ChatRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // validate the request - switch { - case req.Model == "": - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - case len(req.Format) > 0 && req.Format != "json": - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"}) + n := model.ParseName(req.Model) + if !n.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", req.Model)}) return } - model, err := GetModel(req.Model) + model, err := GetModel(n) if err != nil { var pErr *fs.PathError if errors.As(err, &pErr) { diff --git a/server/routes_test.go b/server/routes_test.go index a5e9da23..d7779e9a 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -146,9 +146,9 @@ func Test_Routes(t *testing.T) { assert.Nil(t, err) assert.Equal(t, resp.StatusCode, 200) - model, err := GetModel("t-bone") + m, err := GetModel(model.ParseName("t-bone")) assert.Nil(t, err) - assert.Equal(t, "t-bone:latest", model.ShortName) + assert.Equal(t, "t-bone:latest", m.Name.DisplayShortest()) }, }, { @@ -167,9 +167,9 @@ func Test_Routes(t *testing.T) { req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, Expected: func(t *testing.T, resp *http.Response) { - model, err := GetModel("beefsteak") + m, err := GetModel(model.ParseName("beefsteak")) assert.Nil(t, err) - assert.Equal(t, "beefsteak:latest", model.ShortName) + assert.Equal(t, "beefsteak:latest", m.Name.DisplayShortest()) }, }, { diff --git a/server/sched.go b/server/sched.go index eff2b117..415899b5 100644 --- a/server/sched.go +++ b/server/sched.go @@ -310,7 +310,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) // show a generalized compatibility error until there is a better way to // check for model compatibility if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") { - err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) + err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.Name.DisplayShortest()) } slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err) req.errCh <- err diff --git a/server/sched_test.go b/server/sched_test.go index 7e4faa61..c2ff247b 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -16,6 +16,7 @@ import ( "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/server/envconfig" + "github.com/ollama/ollama/types/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -107,12 +108,12 @@ func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm. return scenario.srv, nil } -func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle { +func newScenario(t *testing.T, ctx context.Context, name model.Name, estimatedVRAM uint64) *bundle { scenario := &bundle{} scenario.ctx, scenario.ctxDone = context.WithCancel(ctx) t.Helper() - f, err := os.CreateTemp(t.TempDir(), modelName) + f, err := os.CreateTemp(t.TempDir(), name.Model) assert.Nil(t, err) defer f.Close() @@ -134,7 +135,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV assert.Nil(t, err) fname := f.Name() - model := &Model{Name: modelName, ModelPath: fname} + model := &Model{Name: name, ModelPath: fname} scenario.ggml, err = llm.LoadModel(model.ModelPath) require.NoError(t, err) @@ -155,24 +156,24 @@ func TestRequests(t *testing.T) { defer done() // Same model, same request - scenario1a := newScenario(t, ctx, "ollama-model-1", 10) + scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1"), 10) scenario1a.req.sessionDuration = 0 - scenario1b := newScenario(t, ctx, "ollama-model-1", 11) + scenario1b := newScenario(t, ctx, model.ParseName("ollama-model-1"), 11) scenario1b.req.model = scenario1a.req.model scenario1b.ggml = scenario1a.ggml scenario1b.req.sessionDuration = 0 // simple reload of same model - scenario2a := newScenario(t, ctx, "ollama-model-1", 20) + scenario2a := newScenario(t, ctx, model.ParseName("ollama-model-1"), 20) scenario2a.req.model = scenario1a.req.model scenario2a.ggml = scenario1a.ggml // Multiple loaded models - scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte) - scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte) - scenario3c := newScenario(t, ctx, "ollama-model-4a", 30) + scenario3a := newScenario(t, ctx, model.ParseName("ollama-model-3a"), 1*format.GigaByte) + scenario3b := newScenario(t, ctx, model.ParseName("ollama-model-3b"), 24*format.GigaByte) + scenario3c := newScenario(t, ctx, model.ParseName("ollama-model-4a"), 30) scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed - scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded + scenario3d := newScenario(t, ctx, model.ParseName("ollama-model-3c"), 30) // Needs prior unloaded s := InitScheduler(ctx) s.getGpuFn = func() gpu.GpuInfoList { @@ -310,11 +311,11 @@ func TestGetRunner(t *testing.T) { defer done() // Same model, same request - scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) + scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1a"), 10) scenario1a.req.sessionDuration = 0 - scenario1b := newScenario(t, ctx, "ollama-model-1b", 10) + scenario1b := newScenario(t, ctx, model.ParseName("ollama-model-1b"), 10) scenario1b.req.sessionDuration = 0 - scenario1c := newScenario(t, ctx, "ollama-model-1c", 10) + scenario1c := newScenario(t, ctx, model.ParseName("ollama-model-1c"), 10) scenario1c.req.sessionDuration = 0 envconfig.MaxQueuedRequests = 1 s := InitScheduler(ctx) @@ -370,7 +371,7 @@ func TestPrematureExpired(t *testing.T) { defer done() // Same model, same request - scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) + scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1a"), 10) s := InitScheduler(ctx) s.getGpuFn = func() gpu.GpuInfoList { g := gpu.GpuInfo{Library: "metal"} diff --git a/server/upload.go b/server/upload.go index 17b6515b..ffddf94d 100644 --- a/server/upload.go +++ b/server/upload.go @@ -12,6 +12,7 @@ import ( "net/http" "net/url" "os" + "path" "sync" "sync/atomic" "time" @@ -56,9 +57,10 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg } if b.From != "" { + n := model.ParseName(b.From) values := requestURL.Query() values.Add("mount", b.Digest) - values.Add("from", ParseModelPath(b.From).GetNamespaceRepository()) + values.Add("from", path.Join(n.Namespace, n.Model)) requestURL.RawQuery = values.Encode() }