remove last bits of ParseModelPath

This commit is contained in:
Michael Yang 2024-05-08 18:02:07 -07:00
parent af11838245
commit 980070dce6
10 changed files with 146 additions and 404 deletions

View File

@ -4,14 +4,11 @@ import (
"bytes"
"cmp"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"log/slog"
"net/http"
"net/url"
@ -42,9 +39,8 @@ type registryOptions struct {
}
type Model struct {
Name string `json:"name"`
Name model.Name
Config ConfigV2
ShortName string
ModelPath string
ParentModel string
AdapterPaths []string
@ -161,46 +157,17 @@ type RootFS struct {
DiffIDs []string `json:"diff_ids"`
}
// GetManifest reads the manifest file that mp.GetManifestPath() resolves to
// and decodes it as JSON.
//
// It returns the decoded manifest together with the hex-encoded SHA-256 sum of
// the raw file bytes. On any failure the manifest is nil and the digest is the
// empty string.
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
	fp, err := mp.GetManifestPath()
	if err != nil {
		return nil, "", err
	}

	// The stat error is returned as-is (not wrapped) on purpose: helpers such
	// as os.IsNotExist do not look through fmt.Errorf wrappers, so wrapping
	// here could change how a "manifest does not exist" failure is classified.
	if _, err = os.Stat(fp); err != nil {
		return nil, "", err
	}

	bts, err := os.ReadFile(fp)
	if err != nil {
		// Keep the underlying cause in the chain so errors.Is/errors.As work.
		return nil, "", fmt.Errorf("couldn't open file '%s': %w", fp, err)
	}

	var manifest *ManifestV2
	if err := json.Unmarshal(bts, &manifest); err != nil {
		return nil, "", err
	}

	// A file containing only JSON `null` decodes successfully into a nil
	// pointer; report that instead of handing back a nil manifest with a nil
	// error.
	if manifest == nil {
		return nil, "", fmt.Errorf("manifest '%s' is empty", fp)
	}

	shaSum := sha256.Sum256(bts)
	return manifest, hex.EncodeToString(shaSum[:]), nil
}
func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
manifest, digest, err := GetManifest(mp)
func GetModel(name model.Name) (*Model, error) {
manifest, err := ParseNamedManifest(name)
if err != nil {
return nil, err
}
model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
Template: "{{ .Prompt }}",
License: []string{},
Name: name,
Digest: manifest.digest,
Template: "{{ .Prompt }}",
License: []string{},
}
filename, err := GetBlobsPath(manifest.Config.Digest)
@ -688,18 +655,8 @@ func PullModel(ctx context.Context, name model.Name, opts registryOptions, fn fu
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
if err := layer.Verify(); err != nil {
_ = layer.Remove()
return err
}
}
@ -737,17 +694,6 @@ func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, o
return m, err
}
// GetSHA256Digest drains r and reports the SHA-256 of everything it read,
// formatted as "sha256:<lowercase hex>", along with the number of bytes read.
//
// A read failure is passed to log.Fatal rather than returned.
func GetSHA256Digest(r io.Reader) (string, int64) {
	hasher := sha256.New()

	size, copyErr := io.Copy(hasher, r)
	if copyErr != nil {
		log.Fatal(copyErr)
	}

	sum := hasher.Sum(nil)
	digest := fmt.Sprintf("sha256:%x", sum)
	return digest, size
}
var errUnauthorized = fmt.Errorf("unauthorized: access denied")
// getTokenSubject returns the subject of a JWT token, it does not validate the token
@ -907,25 +853,3 @@ func parseRegistryChallenge(authStr string) registryChallenge {
Scope: getValue(authStr, "scope"),
}
}
// errDigestMismatch is wrapped by verifyBlob when the digest computed from a
// blob's contents differs from the digest that was requested.
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")

// verifyBlob opens the blob that GetBlobsPath resolves for digest, hashes its
// contents with GetSHA256Digest, and compares the result against digest.
//
// It returns nil when the two match, an error wrapping errDigestMismatch when
// they differ, or the path-lookup / open error unchanged.
func verifyBlob(digest string) error {
	blobPath, err := GetBlobsPath(digest)
	if err != nil {
		return err
	}

	blob, err := os.Open(blobPath)
	if err != nil {
		return err
	}
	defer blob.Close()

	// The byte count reported by GetSHA256Digest is not needed here.
	if actual, _ := GetSHA256Digest(blob); actual != digest {
		return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, actual)
	}

	return nil
}

View File

@ -115,6 +115,26 @@ func (l *Layer) Remove() error {
return os.Remove(filepath.Join(p, l.Digest))
}
// Verify hashes the content returned by l.Open with SHA-256 and compares the
// result, formatted as "sha256:<hex>", against l.Digest.
//
// It returns nil on a match, a "digest mismatch" error when the values differ,
// or the error from opening or reading the content.
func (l *Layer) Verify() error {
	blob, err := l.Open()
	if err != nil {
		return err
	}
	defer blob.Close()

	hasher := sha256.New()
	if _, err = io.Copy(hasher, blob); err != nil {
		return err
	}

	if got := fmt.Sprintf("sha256:%x", hasher.Sum(nil)); got != l.Digest {
		return fmt.Errorf("digest mismatch: %s != %s", got, l.Digest)
	}

	return nil
}
func Layers() (map[string]*Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {

View File

@ -194,3 +194,17 @@ func Manifests() (map[model.Name]*Manifest, error) {
return ms, nil
}
// GetManifestPath returns the "manifests" directory located under the
// directory reported by modelsDir, creating it (and any missing parents) with
// mode 0o755 before returning.
func GetManifestPath() (string, error) {
	root, err := modelsDir()
	if err != nil {
		return "", err
	}

	manifests := filepath.Join(root, "manifests")
	if mkErr := os.MkdirAll(manifests, 0o755); mkErr != nil {
		return "", mkErr
	}

	return manifests, nil
}

View File

@ -2,105 +2,16 @@ package server
import (
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
)
// ModelPath holds the components of a model reference written as
// [scheme://][registry/][namespace/]repository[:tag].
type ModelPath struct {
	ProtocolScheme string
	Registry       string
	Namespace      string
	Repository     string
	Tag            string
}

// Values ParseModelPath falls back to for components missing from its input.
const (
	DefaultRegistry       = "registry.ollama.ai"
	DefaultNamespace      = "library"
	DefaultTag            = "latest"
	DefaultProtocolScheme = "https"
)

var (
	ErrInvalidImageFormat  = errors.New("invalid image format")
	ErrInvalidProtocol     = errors.New("invalid protocol scheme")
	ErrInsecureProtocol    = errors.New("insecure protocol http")
	ErrInvalidDigestFormat = errors.New("invalid digest format")
)

// ParseModelPath splits name into a ModelPath, filling in the Default*
// constants for any component that is absent.
//
// OS path separators are treated as "/". Inputs with one, two, or three
// slash-separated segments are recognised; any other segment count leaves
// Repository empty, which Validate reports as an error.
func ParseModelPath(name string) ModelPath {
	scheme := DefaultProtocolScheme
	if prefix, rest, ok := strings.Cut(name, "://"); ok {
		scheme, name = prefix, rest
	}

	registry, namespace, repository := DefaultRegistry, DefaultNamespace, ""

	normalized := strings.ReplaceAll(name, string(os.PathSeparator), "/")
	switch segments := strings.Split(normalized, "/"); len(segments) {
	case 1:
		repository = segments[0]
	case 2:
		namespace, repository = segments[0], segments[1]
	case 3:
		registry, namespace, repository = segments[0], segments[1], segments[2]
	}

	// Everything after the first ':' in the last segment is the tag.
	tag := DefaultTag
	if i := strings.Index(repository, ":"); i >= 0 {
		repository, tag = repository[:i], repository[i+1:]
	}

	return ModelPath{
		ProtocolScheme: scheme,
		Registry:       registry,
		Namespace:      namespace,
		Repository:     repository,
		Tag:            tag,
	}
}

// errModelPathInvalid is wrapped by every error Validate returns.
var errModelPathInvalid = errors.New("invalid model path")

// Validate reports an error wrapping errModelPathInvalid when the repository
// is empty or the tag contains a colon, and nil otherwise.
func (mp ModelPath) Validate() error {
	switch {
	case mp.Repository == "":
		return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
	case strings.Contains(mp.Tag, ":"):
		return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
	}

	return nil
}

// GetNamespaceRepository returns "namespace/repository".
func (mp ModelPath) GetNamespaceRepository() string {
	return mp.Namespace + "/" + mp.Repository
}

// GetFullTagname returns "registry/namespace/repository:tag".
func (mp ModelPath) GetFullTagname() string {
	return mp.Registry + "/" + mp.GetNamespaceRepository() + ":" + mp.Tag
}

// GetShortTagname returns the reference with default components omitted: the
// registry is dropped when it equals DefaultRegistry, and the namespace is
// additionally dropped when it equals DefaultNamespace.
func (mp ModelPath) GetShortTagname() string {
	switch {
	case mp.Registry != DefaultRegistry:
		return mp.GetFullTagname()
	case mp.Namespace != DefaultNamespace:
		return mp.GetNamespaceRepository() + ":" + mp.Tag
	default:
		return mp.Repository + ":" + mp.Tag
	}
}
// modelsDir returns the value of the OLLAMA_MODELS environment variable or, if OLLAMA_MODELS is not set, the ".ollama/models" directory under the user's home directory.
// The models directory is where Ollama stores its model files and manifests.
func modelsDir() (string, error) {
@ -114,37 +25,6 @@ func modelsDir() (string, error) {
return filepath.Join(home, ".ollama", "models"), nil
}
// GetManifestPath builds the manifest file path for mp:
// <models dir>/manifests/<registry>/<namespace>/<repository>/<tag>.
// Nothing is created on disk; the caller must create the directory if it does
// not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
	root, err := modelsDir()
	if err != nil {
		return "", err
	}

	elems := []string{root, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag}
	return filepath.Join(elems...), nil
}
// BaseURL returns a new URL whose scheme is mp.ProtocolScheme and whose host
// is mp.Registry; all other URL fields are left at their zero values.
func (mp ModelPath) BaseURL() *url.URL {
	base := new(url.URL)
	base.Scheme = mp.ProtocolScheme
	base.Host = mp.Registry
	return base
}
// GetManifestPath returns the path of the "manifests" directory beneath the
// directory reported by modelsDir. The directory, including any missing
// parents, is created with mode 0o755 before the path is returned.
func GetManifestPath() (string, error) {
	base, err := modelsDir()
	if err != nil {
		return "", err
	}

	manifestsDir := filepath.Join(base, "manifests")
	err = os.MkdirAll(manifestsDir, 0o755)
	if err != nil {
		return "", err
	}

	return manifestsDir, nil
}
func GetBlobsPath(digest string) (string, error) {
dir, err := modelsDir()
if err != nil {

View File

@ -68,88 +68,3 @@ func TestGetBlobsPath(t *testing.T) {
})
}
}
// TestParseModelPath checks that ParseModelPath splits a reference into its
// components and substitutes the Default* values for any that are omitted.
func TestParseModelPath(t *testing.T) {
	type testCase struct {
		name string
		arg  string
		want ModelPath
	}

	cases := []testCase{
		{
			name: "full path https",
			arg:  "https://example.com/ns/repo:tag",
			want: ModelPath{ProtocolScheme: "https", Registry: "example.com", Namespace: "ns", Repository: "repo", Tag: "tag"},
		},
		{
			name: "full path http",
			arg:  "http://example.com/ns/repo:tag",
			want: ModelPath{ProtocolScheme: "http", Registry: "example.com", Namespace: "ns", Repository: "repo", Tag: "tag"},
		},
		{
			name: "no protocol",
			arg:  "example.com/ns/repo:tag",
			want: ModelPath{ProtocolScheme: "https", Registry: "example.com", Namespace: "ns", Repository: "repo", Tag: "tag"},
		},
		{
			name: "no registry",
			arg:  "ns/repo:tag",
			want: ModelPath{ProtocolScheme: "https", Registry: DefaultRegistry, Namespace: "ns", Repository: "repo", Tag: "tag"},
		},
		{
			name: "no namespace",
			arg:  "repo:tag",
			want: ModelPath{ProtocolScheme: "https", Registry: DefaultRegistry, Namespace: DefaultNamespace, Repository: "repo", Tag: "tag"},
		},
		{
			name: "no tag",
			arg:  "repo",
			want: ModelPath{ProtocolScheme: "https", Registry: DefaultRegistry, Namespace: DefaultNamespace, Repository: "repo", Tag: DefaultTag},
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if parsed := ParseModelPath(c.arg); parsed != c.want {
				t.Errorf("got: %q want: %q", parsed, c.want)
			}
		})
	}
}

View File

@ -75,45 +75,43 @@ func isSupportedImageType(image []byte) bool {
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.GenerateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
n := model.ParseName(r.Model)
// validate the request
switch {
case req.Model == "":
case !n.IsValid():
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
case len(r.Format) > 0 && r.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
case r.Raw && (r.Template != "" || r.System != "" || len(r.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
for _, img := range req.Images {
for _, img := range r.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
}
model, err := GetModel(req.Model)
model, err := GetModel(n)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", r.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -125,17 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
opts, err := modelOptions(model, req.Options)
opts, err := modelOptions(model, r.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
if r.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
sessionDuration = r.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
@ -150,10 +148,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
if r.Prompt == "" && r.Template == "" && r.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Model: r.Model,
Done: true,
DoneReason: "load",
})
@ -164,37 +162,37 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var prompt string
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template == "" {
req.Template = model.Template
case r.Raw:
prompt = r.Prompt
case r.Prompt != "":
if r.Template == "" {
r.Template = model.Template
}
if req.System == "" {
req.System = model.System
if r.System == "" {
r.System = model.System
}
slog.Debug("generate handler", "prompt", req.Prompt)
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
slog.Debug("generate handler", "prompt", r.Prompt)
slog.Debug("generate handler", "template", r.Template)
slog.Debug("generate handler", "system", r.System)
var sb strings.Builder
for i := range req.Images {
for i := range r.Images {
fmt.Fprintf(&sb, "[img-%d] ", i)
}
sb.WriteString(req.Prompt)
sb.WriteString(r.Prompt)
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
p, err := Prompt(r.Template, r.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.Reset()
if req.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
if r.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), r.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@ -215,33 +213,33 @@ func (s *Server) GenerateHandler(c *gin.Context) {
go func() {
defer close(ch)
fn := func(r llm.CompletionResponse) {
fn := func(comp llm.CompletionResponse) {
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
if _, err := generated.WriteString(comp.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp := api.GenerateResponse{
Model: req.Model,
Model: r.Model,
CreatedAt: time.Now().UTC(),
Done: r.Done,
Response: r.Content,
DoneReason: r.DoneReason,
Done: comp.Done,
DoneReason: comp.DoneReason,
Response: comp.Content,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
PromptEvalCount: comp.PromptEvalCount,
PromptEvalDuration: comp.PromptEvalDuration,
EvalCount: comp.EvalCount,
EvalDuration: comp.EvalDuration,
},
}
if r.Done {
if comp.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
if !r.Raw {
p, err := Prompt(r.Template, r.System, r.Prompt, generated.String(), false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@ -254,7 +252,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
resp.Context = append(req.Context, tokens...)
resp.Context = append(r.Context, tokens...)
}
}
@ -262,17 +260,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
var images []llm.ImageData
for i := range req.Images {
for i := range r.Images {
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
Data: r.Images[i],
})
}
// Start prediction
req := llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Format: r.Format,
Images: images,
Options: opts,
}
@ -281,7 +279,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
}()
if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
// Accumulate responses into the final response
var final api.GenerateResponse
var sb strings.Builder
@ -339,44 +337,43 @@ func getDefaultSessionDuration() time.Duration {
}
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.EmbeddingRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(r.Model)
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", r.Model)})
return
}
model, err := GetModel(req.Model)
model, err := GetModel(n)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", r.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
opts, err := modelOptions(model, r.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
if r.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
sessionDuration = r.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
@ -389,12 +386,12 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
}
// an empty request loads the model
if req.Prompt == "" {
if r.Prompt == "" {
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
return
}
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
embedding, err := runner.llama.Embedding(c.Request.Context(), r.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@ -497,7 +494,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}
@ -582,30 +579,25 @@ func (s *Server) DeleteModelHandler(c *gin.Context) {
}
func (s *Server) ShowModelHandler(c *gin.Context) {
var req api.ShowRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.ShowRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model != "" {
// noop
} else if req.Name != "" {
req.Model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}
resp, err := GetModelInfo(req)
resp, err := GetModelInfo(n, r)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", r.Model)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
@ -615,8 +607,8 @@ func (s *Server) ShowModelHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
model, err := GetModel(req.Model)
func GetModelInfo(name model.Name, req api.ShowRequest) (*api.ShowResponse, error) {
model, err := GetModel(name)
if err != nil {
return nil, err
}
@ -674,7 +666,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
var sb strings.Builder
fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
fmt.Fprintf(&sb, "# FROM %s\n\n", name.DisplayShortest())
fmt.Fprint(&sb, model.String())
resp.Modelfile = sb.String()
@ -1087,27 +1079,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
n := model.ParseName(req.Model)
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", req.Model)})
return
}
model, err := GetModel(req.Model)
model, err := GetModel(n)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {

View File

@ -146,9 +146,9 @@ func Test_Routes(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, resp.StatusCode, 200)
model, err := GetModel("t-bone")
m, err := GetModel(model.ParseName("t-bone"))
assert.Nil(t, err)
assert.Equal(t, "t-bone:latest", model.ShortName)
assert.Equal(t, "t-bone:latest", m.Name.DisplayShortest())
},
},
{
@ -167,9 +167,9 @@ func Test_Routes(t *testing.T) {
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
model, err := GetModel("beefsteak")
m, err := GetModel(model.ParseName("beefsteak"))
assert.Nil(t, err)
assert.Equal(t, "beefsteak:latest", model.ShortName)
assert.Equal(t, "beefsteak:latest", m.Name.DisplayShortest())
},
},
{

View File

@ -310,7 +310,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.Name.DisplayShortest())
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
req.errCh <- err

View File

@ -16,6 +16,7 @@ import (
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -107,12 +108,12 @@ func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.
return scenario.srv, nil
}
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
func newScenario(t *testing.T, ctx context.Context, name model.Name, estimatedVRAM uint64) *bundle {
scenario := &bundle{}
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
f, err := os.CreateTemp(t.TempDir(), name.Model)
assert.Nil(t, err)
defer f.Close()
@ -134,7 +135,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
assert.Nil(t, err)
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
model := &Model{Name: name, ModelPath: fname}
scenario.ggml, err = llm.LoadModel(model.ModelPath)
require.NoError(t, err)
@ -155,24 +156,24 @@ func TestRequests(t *testing.T) {
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1"), 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b := newScenario(t, ctx, model.ParseName("ollama-model-1"), 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = 0
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
scenario2a := newScenario(t, ctx, model.ParseName("ollama-model-1"), 20)
scenario2a.req.model = scenario1a.req.model
scenario2a.ggml = scenario1a.ggml
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
scenario3a := newScenario(t, ctx, model.ParseName("ollama-model-3a"), 1*format.GigaByte)
scenario3b := newScenario(t, ctx, model.ParseName("ollama-model-3b"), 24*format.GigaByte)
scenario3c := newScenario(t, ctx, model.ParseName("ollama-model-4a"), 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
scenario3d := newScenario(t, ctx, model.ParseName("ollama-model-3c"), 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
@ -310,11 +311,11 @@ func TestGetRunner(t *testing.T) {
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1a"), 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b := newScenario(t, ctx, model.ParseName("ollama-model-1b"), 10)
scenario1b.req.sessionDuration = 0
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c := newScenario(t, ctx, model.ParseName("ollama-model-1c"), 10)
scenario1c.req.sessionDuration = 0
envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx)
@ -370,7 +371,7 @@ func TestPrematureExpired(t *testing.T) {
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1a"), 10)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}

View File

@ -12,6 +12,7 @@ import (
"net/http"
"net/url"
"os"
"path"
"sync"
"sync/atomic"
"time"
@ -56,9 +57,10 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
}
if b.From != "" {
n := model.ParseName(b.From)
values := requestURL.Query()
values.Add("mount", b.Digest)
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
values.Add("from", path.Join(n.Namespace, n.Model))
requestURL.RawQuery = values.Encode()
}