Compare commits

...

10 Commits

Author SHA1 Message Date
Michael Yang
980070dce6 remove last bits of ParseModelPath 2024-05-13 14:25:37 -07:00
Michael Yang
af11838245 update push to use model.Name 2024-05-13 14:25:27 -07:00
Michael Yang
7cb2fd3555 fix(server): prune files 2024-05-13 14:24:52 -07:00
Michael Yang
0aeaeaa058 update pull handler to use model.Name 2024-05-13 14:24:05 -07:00
Michael Yang
b91cf0893d update create handler to use model.Name 2024-05-13 13:41:13 -07:00
Michael Yang
8215545c6d more resilient Manifests 2024-05-13 13:40:50 -07:00
Michael Yang
34a1dbe6ec filepath.Join 2024-05-13 13:40:50 -07:00
Michael Yang
6f2a09abfd remove DeleteModel 2024-05-13 13:40:50 -07:00
Michael Yang
14f9dc4e6a routes: use Manifests for ListHandler 2024-05-13 13:40:50 -07:00
Michael Yang
eeba2cbae3 update delete handler to use model.Name 2024-05-13 13:40:50 -07:00
15 changed files with 608 additions and 864 deletions

View File

@ -24,7 +24,8 @@ func terminate(cmd *exec.Cmd) error {
if err != nil { if err != nil {
return err return err
} }
defer dll.Release() // nolint: errcheck //nolint:errcheck
defer dll.Release()
pid := cmd.Process.Pid pid := cmd.Process.Pid
@ -73,7 +74,8 @@ func isProcessExited(pid int) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("failed to open process: %v", err) return false, fmt.Errorf("failed to open process: %v", err)
} }
defer windows.CloseHandle(hProcess) // nolint: errcheck //nolint:errcheck
defer windows.CloseHandle(hProcess)
var exitCode uint32 var exitCode uint32
err = windows.GetExitCodeProcess(hProcess, &exitCode) err = windows.GetExitCodeProcess(hProcess, &exitCode)

View File

@ -81,7 +81,7 @@ func (i *Instance) Readline() (string, error) {
defer func() { defer func() {
fd := int(syscall.Stdin) fd := int(syscall.Stdin)
// nolint: errcheck //nolint:errcheck
UnsetRawMode(fd, i.Terminal.termios) UnsetRawMode(fd, i.Terminal.termios)
i.Terminal.rawmode = false i.Terminal.rawmode = false
}() }()

View File

@ -23,6 +23,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/types/model"
) )
const maxRetries = 6 const maxRetries = 6
@ -332,15 +333,16 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
} }
} }
type downloadOpts struct { type downloadOptions struct {
mp ModelPath name model.Name
baseURL *url.URL
digest string digest string
regOpts *registryOptions regOpts *registryOptions
fn func(api.ProgressResponse) fn func(api.ProgressResponse)
} }
// downloadBlob downloads a blob from the registry and stores it in the blobs directory // downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error { func downloadBlob(ctx context.Context, opts downloadOptions) error {
fp, err := GetBlobsPath(opts.digest) fp, err := GetBlobsPath(opts.digest)
if err != nil { if err != nil {
return err return err
@ -365,14 +367,13 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest}) data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload) download := data.(*blobDownload)
if !ok { if !ok {
requestURL := opts.mp.BaseURL() requestURL := opts.baseURL.JoinPath("blobs", opts.digest)
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil { if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest) blobDownloadManager.Delete(opts.digest)
return err return err
} }
// nolint: contextcheck //nolint:contextcheck
go download.Run(context.Background(), requestURL, opts.regOpts) go download.Run(context.Background(), requestURL, opts.regOpts)
} }

View File

@ -4,18 +4,16 @@ import (
"bytes" "bytes"
"cmp" "cmp"
"context" "context"
"crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"log/slog" "log/slog"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv" "strconv"
@ -41,9 +39,8 @@ type registryOptions struct {
} }
type Model struct { type Model struct {
Name string `json:"name"` Name model.Name
Config ConfigV2 Config ConfigV2
ShortName string
ModelPath string ModelPath string
ParentModel string ParentModel string
AdapterPaths []string AdapterPaths []string
@ -160,46 +157,17 @@ type RootFS struct {
DiffIDs []string `json:"diff_ids"` DiffIDs []string `json:"diff_ids"`
} }
func GetManifest(mp ModelPath) (*ManifestV2, string, error) { func GetModel(name model.Name) (*Model, error) {
fp, err := mp.GetManifestPath() manifest, err := ParseNamedManifest(name)
if err != nil {
return nil, "", err
}
if _, err = os.Stat(fp); err != nil {
return nil, "", err
}
var manifest *ManifestV2
bts, err := os.ReadFile(fp)
if err != nil {
return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
}
shaSum := sha256.Sum256(bts)
shaStr := hex.EncodeToString(shaSum[:])
if err := json.Unmarshal(bts, &manifest); err != nil {
return nil, "", err
}
return manifest, shaStr, nil
}
func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
manifest, digest, err := GetManifest(mp)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model := &Model{ model := &Model{
Name: mp.GetFullTagname(), Name: name,
ShortName: mp.GetShortTagname(), Digest: manifest.digest,
Digest: digest, Template: "{{ .Prompt }}",
Template: "{{ .Prompt }}", License: []string{},
License: []string{},
} }
filename, err := GetBlobsPath(manifest.Config.Digest) filename, err := GetBlobsPath(manifest.Config.Digest)
@ -314,7 +282,7 @@ func realpath(rel, from string) string {
return abspath return abspath
} }
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) { func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{ config := ConfigV2{
OS: "linux", OS: "linux",
Architecture: "amd64", Architecture: "amd64",
@ -546,16 +514,10 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
} }
} }
unref := make(map[string]struct{}) if !envconfig.NoPrune {
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil { if old, err := ParseNamedManifest(name); err == nil {
for _, layer := range manifest.Layers { //nolint:errcheck
if !slices.Contains(digests, layer.Digest) { defer old.RemoveLayers()
unref[layer.Digest] = struct{}{}
}
}
if manifest.Config.Digest != layer.Digest {
unref[manifest.Config.Digest] = struct{}{}
} }
} }
@ -564,12 +526,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return err return err
} }
if !envconfig.NoPrune {
if err := deleteUnusedLayers(nil, unref); err != nil {
return err
}
}
fn(api.ProgressResponse{Status: "success"}) fn(api.ProgressResponse{Status: "success"})
return nil return nil
} }
@ -613,207 +569,42 @@ func CopyModel(src, dst model.Name) error {
return err return err
} }
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}) error { func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
fp, err := GetManifestPath()
if err != nil {
return err
}
walkFunc := func(path string, info os.FileInfo, _ error) error {
if info.IsDir() {
return nil
}
dir, file := filepath.Split(path)
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
tag := strings.Join([]string{dir, file}, ":")
fmp := ParseModelPath(tag)
// skip the manifest we're trying to delete
if skipModelPath != nil && skipModelPath.GetFullTagname() == fmp.GetFullTagname() {
return nil
}
// save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp)
if err != nil {
// nolint: nilerr
return nil
}
for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
return nil
}
if err := filepath.Walk(fp, walkFunc); err != nil {
return err
}
// only delete the files which are still in the deleteMap
for k := range deleteMap {
fp, err := GetBlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
continue
}
if err := os.Remove(fp); err != nil {
slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err))
continue
}
}
return nil
}
func PruneLayers() error {
deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("")
if err != nil {
return err
}
blobs, err := os.ReadDir(p)
if err != nil {
slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", p, err))
return err
}
for _, blob := range blobs {
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
_, err := GetBlobsPath(name)
if err != nil {
if errors.Is(err, ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
}
}
continue
}
deleteMap[name] = struct{}{}
}
slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap)))
err = deleteUnusedLayers(nil, deleteMap)
if err != nil {
return err
}
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
return nil
}
func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}
entries, err = os.ReadDir(path)
if err != nil {
return err
}
if len(entries) > 0 {
return nil
}
return os.Remove(path)
}
return nil
}
func DeleteModel(name string) error {
mp := ParseModelPath(name)
manifest, _, err := GetManifest(mp)
if err != nil {
return err
}
deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
}
deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
err = os.Remove(fp)
if err != nil {
slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
return err
}
return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"}) fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure { m, err := ParseNamedManifest(name)
return fmt.Errorf("insecure protocol http")
}
manifest, _, err := GetManifest(mp)
if err != nil { if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err return err
} }
var layers []*Layer scheme := "https"
layers = append(layers, manifest.Layers...) if opts.Insecure {
layers = append(layers, manifest.Config) scheme = "http"
}
for _, layer := range layers { baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { if err != nil {
return err
}
for _, layer := range append(m.Layers, m.Config) {
if err := uploadBlob(ctx, uploadOptions{name: name, baseURL: baseURL, layer: layer, regOpts: &opts, fn: fn}); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err)) slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err return err
} }
} }
fn(api.ProgressResponse{Status: "pushing manifest"}) fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := mp.BaseURL() requestURL := baseURL.JoinPath("v2", name.Namespace, name.Model, "manifests", name.Tag)
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
manifestJSON, err := json.Marshal(manifest) manifestJSON, err := json.Marshal(m)
if err != nil { if err != nil {
return err return err
} }
headers := make(http.Header) headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json") headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts) resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), &opts)
if err != nil { if err != nil {
return err return err
} }
@ -824,118 +615,72 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil return nil
} }
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { func PullModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) old, _ := ParseNamedManifest(name)
var manifest *ManifestV2 if !name.IsFullyQualified() {
var err error return model.Unqualified(name)
var noprune string
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
if !envconfig.NoPrune {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if manifest != nil {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{}
}
deleteMap[manifest.Config.Digest] = struct{}{}
}
} }
if mp.ProtocolScheme == "http" && !regOpts.Insecure { scheme := "https"
return fmt.Errorf("insecure protocol http") if opts.Insecure {
scheme = "http"
}
baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
if err != nil {
return err
} }
fn(api.ProgressResponse{Status: "pulling manifest"}) fn(api.ProgressResponse{Status: "pulling manifest"})
m, err := pullModelManifest(ctx, name, baseURL, &opts)
manifest, err = pullModelManifest(ctx, mp, regOpts)
if err != nil { if err != nil {
return fmt.Errorf("pull model manifest: %s", err) return fmt.Errorf("pull model manifest: %s", err)
} }
var layers []*Layer layers := append(m.Layers, m.Config)
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)
for _, layer := range layers { for _, layer := range layers {
if err := downloadBlob( if err := downloadBlob(
ctx, ctx,
downloadOpts{ downloadOptions{
mp: mp, name: name,
baseURL: baseURL,
digest: layer.Digest, digest: layer.Digest,
regOpts: regOpts, regOpts: &opts,
fn: fn, fn: fn,
}); err != nil { }); err != nil {
return err return err
} }
delete(deleteMap, layer.Digest)
} }
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"}) fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers { for _, layer := range layers {
if err := verifyBlob(layer.Digest); err != nil { if err := layer.Verify(); err != nil {
if errors.Is(err, errDigestMismatch) { _ = layer.Remove()
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
return err return err
} }
} }
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, m.Config, m.Layers); err != nil {
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err return err
} }
fp, err := mp.GetManifestPath() if !envconfig.NoPrune && old != nil {
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
err = os.WriteFile(fp, manifestJSON, 0o644)
if err != nil {
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
return err
}
if noprune == "" {
fn(api.ProgressResponse{Status: "removing any unused layers"}) fn(api.ProgressResponse{Status: "removing any unused layers"})
err = deleteUnusedLayers(nil, deleteMap) _ = old.RemoveLayers()
if err != nil {
return err
}
} }
fn(api.ProgressResponse{Status: "success"}) fn(api.ProgressResponse{Status: "success"})
return nil return nil
} }
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) { func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, opts *registryOptions) (*ManifestV2, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) requestURL := baseURL.JoinPath("manifests", name.Tag)
headers := make(http.Header) headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json") headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts) resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -949,17 +694,6 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
return m, err return m, err
} }
// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
func GetSHA256Digest(r io.Reader) (string, int64) {
h := sha256.New()
n, err := io.Copy(h, r)
if err != nil {
log.Fatal(err)
}
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
var errUnauthorized = fmt.Errorf("unauthorized: access denied") var errUnauthorized = fmt.Errorf("unauthorized: access denied")
// getTokenSubject returns the subject of a JWT token, it does not validate the token // getTokenSubject returns the subject of a JWT token, it does not validate the token
@ -1119,25 +853,3 @@ func parseRegistryChallenge(authStr string) registryChallenge {
Scope: getValue(authStr, "scope"), Scope: getValue(authStr, "scope"),
} }
} }
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
if err != nil {
return err
}
f, err := os.Open(fp)
if err != nil {
return err
}
defer f.Close()
fileDigest, _ := GetSHA256Digest(f)
if digest != fileDigest {
return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
}
return nil
}

View File

@ -4,7 +4,10 @@ import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"path/filepath"
"strings"
) )
type Layer struct { type Layer struct {
@ -88,3 +91,81 @@ func (l *Layer) Open() (io.ReadCloser, error) {
return os.Open(blob) return os.Open(blob)
} }
// Remove deletes the blob file backing this layer from the blobs directory.
//
// If any manifest returned by Manifests still references the layer's digest,
// either as its config or as one of its layers, the blob is left in place and
// nil is returned. A nil receiver or an empty digest is a no-op: joining an
// empty digest onto the blobs directory would otherwise target the directory
// itself.
//
// It returns any error from Manifests, GetBlobsPath, or os.Remove.
func (l *Layer) Remove() error {
	if l == nil || l.Digest == "" {
		return nil
	}

	ms, err := Manifests()
	if err != nil {
		return err
	}

	for _, m := range ms {
		// A manifest may have no config (e.g. a zero-valued manifest on disk),
		// so guard the pointer before comparing.
		if m.Config != nil && m.Config.Digest == l.Digest {
			// something is using this layer
			return nil
		}

		for _, layer := range m.Layers {
			if layer != nil && layer.Digest == l.Digest {
				// something is using this layer
				return nil
			}
		}
	}

	p, err := GetBlobsPath("")
	if err != nil {
		return err
	}

	// Blob files are named "sha256-<hex>" on disk (Layers performs the inverse
	// rewrite when listing them), while manifests record "sha256:<hex>".
	// Convert before building the path; a digest that is already in file-name
	// form passes through unchanged.
	name := strings.Replace(l.Digest, "sha256:", "sha256-", 1)
	return os.Remove(filepath.Join(p, name))
}
// Verify re-hashes the layer's blob and compares the result with l.Digest.
//
// It returns an error if the blob cannot be opened or read, or if the computed
// "sha256:<hex>" digest differs from the recorded one.
func (l *Layer) Verify() error {
	blob, err := l.Open()
	if err != nil {
		return err
	}
	defer blob.Close()

	hasher := sha256.New()
	_, err = io.Copy(hasher, blob)
	if err != nil {
		return err
	}

	if got := fmt.Sprintf("sha256:%x", hasher.Sum(nil)); got != l.Digest {
		return fmt.Errorf("digest mismatch: %s != %s", got, l.Digest)
	}

	return nil
}
// Layers lists the entries of the blobs directory and returns one Layer per
// entry, keyed by the entry's file name with its first "sha256-" rewritten to
// "sha256:".
//
// Entries that NewLayerFromLayer rejects are logged and still included, as a
// bare Layer whose Digest is the raw file name.
func Layers() (map[string]*Layer, error) {
	dir, err := GetBlobsPath("")
	if err != nil {
		return nil, err
	}

	// TODO(mxyng): use something less brittle
	paths, err := filepath.Glob(filepath.Join(dir, "*"))
	if err != nil {
		return nil, err
	}

	layers := make(map[string]*Layer)
	for _, path := range paths {
		name, err := filepath.Rel(dir, path)
		if err != nil {
			slog.Warn("bad filepath", "path", path, "error", err)
			continue
		}

		// TODO(mxyng): this should ideally use model.Digest but
		// that's currently incompatible with the manifest digest
		digest := strings.Replace(name, "sha256-", "sha256:", 1)

		layer, err := NewLayerFromLayer(digest, "", "")
		if err != nil {
			slog.Warn("bad blob", "digest", digest, "error", err)
			layer = &Layer{Digest: name}
		}

		layers[digest] = layer
	}

	return layers, nil
}

View File

@ -1,11 +1,11 @@
package server package server
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"path/filepath" "path/filepath"
@ -14,7 +14,10 @@ import (
type Manifest struct { type Manifest struct {
ManifestV2 ManifestV2
Digest string `json:"-"`
filepath string
fi os.FileInfo
digest string
} }
func (m *Manifest) Size() (size int64) { func (m *Manifest) Size() (size int64) {
@ -25,9 +28,67 @@ func (m *Manifest) Size() (size int64) {
return return
} }
func ParseNamedManifest(name model.Name) (*Manifest, error) { func (m *Manifest) Remove() error {
if !name.IsFullyQualified() { if err := os.Remove(m.filepath); err != nil {
return nil, model.Unqualified(name) return err
}
manifests, err := GetManifestPath()
if err != nil {
return err
}
return pruneEmptyDirectory(manifests)
}
// pruneEmptyDirectory recursively prunes the subdirectories of p and then
// removes p itself if nothing is left inside it.
//
// A path that is a symbolic link is left untouched. Directory entries that are
// not directories are never removed, so any directory that (transitively)
// contains one survives. Errors from Lstat, ReadDir, and Remove are returned
// as-is.
func pruneEmptyDirectory(p string) error {
	info, err := os.Lstat(p)
	if err != nil {
		return err
	}

	// never follow or delete symlinks
	if info.Mode()&os.ModeSymlink != 0 {
		return nil
	}

	children, err := os.ReadDir(p)
	if err != nil {
		return err
	}

	for _, child := range children {
		if !child.IsDir() {
			continue
		}

		if err := pruneEmptyDirectory(filepath.Join(p, child.Name())); err != nil {
			return err
		}
	}

	// list again: pruning the children may have emptied this directory
	remaining, err := os.ReadDir(p)
	if err != nil {
		return err
	}

	if len(remaining) > 0 {
		return nil
	}

	return os.Remove(p)
}
// RemoveLayers calls Remove on each of the manifest's layers, in order, and
// then on its config layer. It stops at and returns the first error.
//
// Nil entries are skipped: a manifest decoded from disk may have no config
// (or a null layer), and calling Remove through a nil *Layer would panic.
// The layers and config are visited separately rather than via
// append(m.Layers, m.Config) so m.Layers' backing array is never written to.
func (m *Manifest) RemoveLayers() error {
	for _, layer := range m.Layers {
		if layer == nil {
			continue
		}

		if err := layer.Remove(); err != nil {
			return err
		}
	}

	if m.Config != nil {
		return m.Config.Remove()
	}

	return nil
}
func ParseNamedManifest(n model.Name) (*Manifest, error) {
if !n.IsFullyQualified() {
return nil, model.Unqualified(n)
} }
manifests, err := GetManifestPath() manifests, err := GetManifestPath()
@ -35,45 +96,115 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) {
return nil, err return nil, err
} }
var manifest ManifestV2 p := filepath.Join(manifests, n.Filepath())
manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
var m ManifestV2
f, err := os.Open(p)
if err != nil {
return nil, err
}
defer f.Close()
fi, err := f.Stat()
if err != nil { if err != nil {
return nil, err return nil, err
} }
sha256sum := sha256.New() sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil { if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
return nil, err return nil, err
} }
return &Manifest{ return &Manifest{
ManifestV2: manifest, ManifestV2: m,
Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), filepath: p,
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
}, nil }, nil
} }
func WriteManifest(name string, config *Layer, layers []*Layer) error { func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
manifest := ManifestV2{ manifests, err := GetManifestPath()
if err != nil {
return err
}
p := filepath.Join(manifests, name.Filepath())
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
return err
}
f, err := os.Create(p)
if err != nil {
return err
}
defer f.Close()
m := ManifestV2{
SchemaVersion: 2, SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json", MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config, Config: config,
Layers: layers, Layers: layers,
} }
var b bytes.Buffer return json.NewEncoder(f).Encode(m)
if err := json.NewEncoder(&b).Encode(manifest); err != nil { }
return err
} func Manifests() (map[model.Name]*Manifest, error) {
manifests, err := GetManifestPath()
modelpath := ParseModelPath(name) if err != nil {
manifestPath, err := modelpath.GetManifestPath() return nil, err
if err != nil { }
return err
} // TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil { if err != nil {
return err return nil, err
} }
return os.WriteFile(manifestPath, b.Bytes(), 0o644) ms := make(map[model.Name]*Manifest)
for _, match := range matches {
fi, err := os.Stat(match)
if err != nil {
return nil, err
}
if !fi.IsDir() {
rel, err := filepath.Rel(manifests, match)
if err != nil {
slog.Warn("bad filepath", "path", match, "error", err)
continue
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest name", "path", rel, "error", err)
continue
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
continue
}
ms[n] = m
}
}
return ms, nil
}
func GetManifestPath() (string, error) {
dir, err := modelsDir()
if err != nil {
return "", err
}
path := filepath.Join(dir, "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", err
}
return path, nil
} }

94
server/manifest_test.go Normal file
View File

@ -0,0 +1,94 @@
package server
import (
"encoding/json"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/types/model"
)
// createManifest writes a JSON-encoded, zero-valued ManifestV2 to
// <path>/manifests/<name>, creating any missing parent directories. Any
// failure aborts the calling test.
func createManifest(t *testing.T, path, name string) {
	t.Helper()

	target := filepath.Join(path, "manifests", name)
	if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
		t.Fatal(err)
	}

	file, err := os.Create(target)
	if err != nil {
		t.Fatal(err)
	}
	defer file.Close()

	enc := json.NewEncoder(file)
	if err := enc.Encode(ManifestV2{}); err != nil {
		t.Fatal(err)
	}
}
// TestManifests fills a temporary models directory with manifest files at
// various relative paths and checks the result of Manifests against them:
// every path that parses to a valid model name must be present, and no path
// that parses to an invalid name may be.
func TestManifests(t *testing.T) {
	llama3Tags := []string{
		"latest",
		"q4_0",
		"q4_1",
		"q8_0",
		"q5_0",
		"q5_1",
		"q2_K",
		"q3_K_S",
		"q3_K_M",
		"q3_K_L",
		"q4_K_S",
		"q4_K_M",
		"q5_K_S",
		"q5_K_M",
		"q6_K",
	}

	var multiple []string
	for _, tag := range llama3Tags {
		multiple = append(multiple, filepath.Join("registry.ollama.ai", "library", "llama3", tag))
	}

	cases := map[string][]string{
		"empty": {},
		"single": {
			filepath.Join("host", "namespace", "model", "tag"),
		},
		"multiple": multiple,
		"hidden": {
			filepath.Join("host", "namespace", "model", "tag"),
			filepath.Join("host", "namespace", "model", ".hidden"),
		},
		"subdir": {
			filepath.Join("host", "namespace", "model", "tag", "one"),
			filepath.Join("host", "namespace", "model", "tag", "another", "one"),
		},
	}

	for label, paths := range cases {
		t.Run(label, func(t *testing.T) {
			dir := t.TempDir()
			t.Setenv("OLLAMA_MODELS", dir)

			for _, p := range paths {
				createManifest(t, dir, p)
			}

			ms, err := Manifests()
			if err != nil {
				t.Fatal(err)
			}

			var names []model.Name
			for name := range ms {
				names = append(names, name)
			}

			for _, p := range paths {
				name := model.ParseNameFromFilepath(p)
				valid := name.IsValid()
				found := slices.Contains(names, name)

				switch {
				case !valid && found:
					t.Errorf("unexpected invalid name: %s", p)
				case valid && !found:
					t.Errorf("missing valid name: %s", p)
				}
			}
		})
	}
}

View File

@ -23,16 +23,14 @@ type layerWithGGML struct {
} }
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
modelpath := ParseModelPath(name.String()) m, err := ParseNamedManifest(name)
manifest, _, err := GetManifest(modelpath)
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil { if err := PullModel(ctx, name, registryOptions{}, fn); err != nil {
return nil, err return nil, err
} }
modelpath = ParseModelPath(name.String()) m, err = ParseNamedManifest(name)
manifest, _, err = GetManifest(modelpath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -40,8 +38,8 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
return nil, err return nil, err
} }
for _, layer := range manifest.Layers { for _, layer := range m.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -2,105 +2,16 @@ package server
import ( import (
"errors" "errors"
"fmt"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"strings" "strings"
) )
type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Repository string
Tag string
}
const (
DefaultRegistry = "registry.ollama.ai"
DefaultNamespace = "library"
DefaultTag = "latest"
DefaultProtocolScheme = "https"
)
var ( var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrInvalidDigestFormat = errors.New("invalid digest format") ErrInvalidDigestFormat = errors.New("invalid digest format")
) )
func ParseModelPath(name string) ModelPath {
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}
before, after, found := strings.Cut(name, "://")
if found {
mp.ProtocolScheme = before
name = after
}
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo
mp.Tag = tag
}
return mp
}
var errModelPathInvalid = errors.New("invalid model path")
func (mp ModelPath) Validate() error {
if mp.Repository == "" {
return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
}
if strings.Contains(mp.Tag, ":") {
return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
}
return nil
}
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}
func (mp ModelPath) GetFullTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
// modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set. // modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
// The models directory is where Ollama stores its model files and manifests. // The models directory is where Ollama stores its model files and manifests.
func modelsDir() (string, error) { func modelsDir() (string, error) {
@ -114,37 +25,6 @@ func modelsDir() (string, error) {
return filepath.Join(home, ".ollama", "models"), nil return filepath.Join(home, ".ollama", "models"), nil
} }
// GetManifestPath returns the location of the manifest file for this model
// path: <models dir>/manifests/<registry>/<namespace>/<repository>/<tag>.
//
// No directory is created here; the caller must create it if needed. An error
// is returned only when the models directory cannot be determined.
func (mp ModelPath) GetManifestPath() (string, error) {
	root, err := modelsDir()
	if err != nil {
		return "", err
	}

	elems := []string{root, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag}
	return filepath.Join(elems...), nil
}
// BaseURL returns a newly allocated URL whose scheme is the model path's
// protocol scheme and whose host is its registry. All other URL fields are
// left at their zero values.
func (mp ModelPath) BaseURL() *url.URL {
	base := new(url.URL)
	base.Scheme = mp.ProtocolScheme
	base.Host = mp.Registry
	return base
}
// GetManifestPath returns the "manifests" directory inside the models
// directory, creating it (and any missing parents) with mode 0o755 first.
//
// An error is returned when the models directory cannot be determined or the
// directory cannot be created.
func GetManifestPath() (string, error) {
	root, err := modelsDir()
	if err != nil {
		return "", err
	}

	manifests := filepath.Join(root, "manifests")
	if mkErr := os.MkdirAll(manifests, 0o755); mkErr != nil {
		return "", mkErr
	}

	return manifests, nil
}
func GetBlobsPath(digest string) (string, error) { func GetBlobsPath(digest string) (string, error) {
dir, err := modelsDir() dir, err := modelsDir()
if err != nil { if err != nil {

View File

@ -68,88 +68,3 @@ func TestGetBlobsPath(t *testing.T) {
}) })
} }
} }
// TestParseModelPath verifies that ParseModelPath splits a model reference
// into scheme, registry, namespace, repository and tag. The table expects
// "https", DefaultRegistry, DefaultNamespace and DefaultTag to be filled in
// for whichever components the input omits.
func TestParseModelPath(t *testing.T) {
	cases := []struct {
		name string
		arg  string
		want ModelPath
	}{
		{
			name: "full path https",
			arg:  "https://example.com/ns/repo:tag",
			want: ModelPath{
				ProtocolScheme: "https",
				Registry:       "example.com",
				Namespace:      "ns",
				Repository:     "repo",
				Tag:            "tag",
			},
		},
		{
			name: "full path http",
			arg:  "http://example.com/ns/repo:tag",
			want: ModelPath{
				ProtocolScheme: "http",
				Registry:       "example.com",
				Namespace:      "ns",
				Repository:     "repo",
				Tag:            "tag",
			},
		},
		{
			name: "no protocol",
			arg:  "example.com/ns/repo:tag",
			want: ModelPath{
				ProtocolScheme: "https",
				Registry:       "example.com",
				Namespace:      "ns",
				Repository:     "repo",
				Tag:            "tag",
			},
		},
		{
			name: "no registry",
			arg:  "ns/repo:tag",
			want: ModelPath{
				ProtocolScheme: "https",
				Registry:       DefaultRegistry,
				Namespace:      "ns",
				Repository:     "repo",
				Tag:            "tag",
			},
		},
		{
			name: "no namespace",
			arg:  "repo:tag",
			want: ModelPath{
				ProtocolScheme: "https",
				Registry:       DefaultRegistry,
				Namespace:      DefaultNamespace,
				Repository:     "repo",
				Tag:            "tag",
			},
		},
		{
			name: "no tag",
			arg:  "repo",
			want: ModelPath{
				ProtocolScheme: "https",
				Registry:       DefaultRegistry,
				Namespace:      DefaultNamespace,
				Repository:     "repo",
				Tag:            DefaultTag,
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			// ModelPath is compared with != so any differing field fails the case.
			if got := ParseModelPath(tt.arg); got != tt.want {
				t.Errorf("got: %q want: %q", got, tt.want)
			}
		})
	}
}

View File

@ -75,45 +75,43 @@ func isSupportedImageType(image []byte) bool {
} }
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.GenerateRequest var r api.GenerateRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
n := model.ParseName(r.Model)
// validate the request // validate the request
switch { switch {
case req.Model == "": case !n.IsValid():
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
case len(req.Format) > 0 && req.Format != "json": case len(r.Format) > 0 && r.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0): case r.Raw && (r.Template != "" || r.System != "" || len(r.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return return
} }
for _, img := range req.Images { for _, img := range r.Images {
if !isSupportedImageType(img) { if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return return
} }
} }
model, err := GetModel(req.Model) model, err := GetModel(n)
if err != nil { if err != nil {
var pErr *fs.PathError var pErr *fs.PathError
if errors.As(err, &pErr) { if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", r.Model)})
return return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -125,17 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
opts, err := modelOptions(model, req.Options) opts, err := modelOptions(model, r.Options)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
var sessionDuration time.Duration var sessionDuration time.Duration
if req.KeepAlive == nil { if r.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration() sessionDuration = getDefaultSessionDuration()
} else { } else {
sessionDuration = req.KeepAlive.Duration sessionDuration = r.KeepAlive.Duration
} }
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
@ -150,10 +148,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// an empty request loads the model // an empty request loads the model
// note: for a short while template was used in lieu // note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too // of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" { if r.Prompt == "" && r.Template == "" && r.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{ c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Model: req.Model, Model: r.Model,
Done: true, Done: true,
DoneReason: "load", DoneReason: "load",
}) })
@ -164,37 +162,37 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var prompt string var prompt string
switch { switch {
case req.Raw: case r.Raw:
prompt = req.Prompt prompt = r.Prompt
case req.Prompt != "": case r.Prompt != "":
if req.Template == "" { if r.Template == "" {
req.Template = model.Template r.Template = model.Template
} }
if req.System == "" { if r.System == "" {
req.System = model.System r.System = model.System
} }
slog.Debug("generate handler", "prompt", req.Prompt) slog.Debug("generate handler", "prompt", r.Prompt)
slog.Debug("generate handler", "template", req.Template) slog.Debug("generate handler", "template", r.Template)
slog.Debug("generate handler", "system", req.System) slog.Debug("generate handler", "system", r.System)
var sb strings.Builder var sb strings.Builder
for i := range req.Images { for i := range r.Images {
fmt.Fprintf(&sb, "[img-%d] ", i) fmt.Fprintf(&sb, "[img-%d] ", i)
} }
sb.WriteString(req.Prompt) sb.WriteString(r.Prompt)
p, err := Prompt(req.Template, req.System, sb.String(), "", true) p, err := Prompt(r.Template, r.System, sb.String(), "", true)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sb.Reset() sb.Reset()
if req.Context != nil { if r.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context) prev, err := runner.llama.Detokenize(c.Request.Context(), r.Context)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -215,33 +213,33 @@ func (s *Server) GenerateHandler(c *gin.Context) {
go func() { go func() {
defer close(ch) defer close(ch)
fn := func(r llm.CompletionResponse) { fn := func(comp llm.CompletionResponse) {
// Build up the full response // Build up the full response
if _, err := generated.WriteString(r.Content); err != nil { if _, err := generated.WriteString(comp.Content); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
return return
} }
resp := api.GenerateResponse{ resp := api.GenerateResponse{
Model: req.Model, Model: r.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Done: r.Done, Done: comp.Done,
Response: r.Content, DoneReason: comp.DoneReason,
DoneReason: r.DoneReason, Response: comp.Content,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: comp.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: comp.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: comp.EvalCount,
EvalDuration: r.EvalDuration, EvalDuration: comp.EvalDuration,
}, },
} }
if r.Done { if comp.Done {
resp.TotalDuration = time.Since(checkpointStart) resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw { if !r.Raw {
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false) p, err := Prompt(r.Template, r.System, r.Prompt, generated.String(), false)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -254,7 +252,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
resp.Context = append(req.Context, tokens...) resp.Context = append(r.Context, tokens...)
} }
} }
@ -262,17 +260,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
var images []llm.ImageData var images []llm.ImageData
for i := range req.Images { for i := range r.Images {
images = append(images, llm.ImageData{ images = append(images, llm.ImageData{
ID: i, ID: i,
Data: req.Images[i], Data: r.Images[i],
}) })
} }
// Start prediction // Start prediction
req := llm.CompletionRequest{ req := llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Format: req.Format, Format: r.Format,
Images: images, Images: images,
Options: opts, Options: opts,
} }
@ -281,7 +279,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
}() }()
if req.Stream != nil && !*req.Stream { if r.Stream != nil && !*r.Stream {
// Accumulate responses into the final response // Accumulate responses into the final response
var final api.GenerateResponse var final api.GenerateResponse
var sb strings.Builder var sb strings.Builder
@ -339,44 +337,43 @@ func getDefaultSessionDuration() time.Duration {
} }
func (s *Server) EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest var r api.EmbeddingRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
if req.Model == "" { n := model.ParseName(r.Model)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", r.Model)})
return return
} }
model, err := GetModel(req.Model) model, err := GetModel(n)
if err != nil { if err != nil {
var pErr *fs.PathError var pErr *fs.PathError
if errors.As(err, &pErr) { if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", r.Model)})
return return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
opts, err := modelOptions(model, req.Options) opts, err := modelOptions(model, r.Options)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
var sessionDuration time.Duration var sessionDuration time.Duration
if req.KeepAlive == nil { if r.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration() sessionDuration = getDefaultSessionDuration()
} else { } else {
sessionDuration = req.KeepAlive.Duration sessionDuration = r.KeepAlive.Duration
} }
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
@ -389,12 +386,12 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
} }
// an empty request loads the model // an empty request loads the model
if req.Prompt == "" { if r.Prompt == "" {
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}}) c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
return return
} }
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt) embedding, err := runner.llama.Embedding(c.Request.Context(), r.Prompt)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@ -408,24 +405,18 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
} }
func (s *Server) PullModelHandler(c *gin.Context) { func (s *Server) PullModelHandler(c *gin.Context) {
var req api.PullRequest var r api.PullRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
var model string n := model.ParseName(cmp.Or(r.Model, r.Name))
if req.Model != "" { if !n.IsValid() {
model = req.Model c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
@ -436,19 +427,15 @@ func (s *Server) PullModelHandler(c *gin.Context) {
ch <- r ch <- r
} }
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil { if err := PullModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if r.Stream != nil && !*r.Stream {
waitForStream(c, ch) waitForStream(c, ch)
return return
} }
@ -457,24 +444,18 @@ func (s *Server) PullModelHandler(c *gin.Context) {
} }
func (s *Server) PushModelHandler(c *gin.Context) { func (s *Server) PushModelHandler(c *gin.Context) {
var req api.PushRequest var r api.PushRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
var model string n := model.ParseName(cmp.Or(r.Model, r.Name))
if req.Model != "" { if !n.IsValid() {
model = req.Model c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
@ -485,19 +466,15 @@ func (s *Server) PushModelHandler(c *gin.Context) {
ch <- r ch <- r
} }
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil { if err := PushModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if r.Stream != nil && !*r.Stream {
waitForStream(c, ch) waitForStream(c, ch)
return return
} }
@ -506,8 +483,8 @@ func (s *Server) PushModelHandler(c *gin.Context) {
} }
func (s *Server) CreateModelHandler(c *gin.Context) { func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest var r api.CreateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
} else if err != nil { } else if err != nil {
@ -515,30 +492,30 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
return return
} }
name := model.ParseName(cmp.Or(req.Model, req.Name)) name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() { if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return return
} }
if req.Path == "" && req.Modelfile == "" { if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return return
} }
var r io.Reader = strings.NewReader(req.Modelfile) var rd io.Reader = strings.NewReader(r.Modelfile)
if req.Path != "" && req.Modelfile == "" { if r.Path != "" && r.Modelfile == "" {
f, err := os.Open(req.Path) f, err := os.Open(r.Path)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return return
} }
defer f.Close() defer f.Close()
r = f rd = f
} }
modelfile, err := model.ParseFile(r) f, err := model.ParseFile(rd)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
@ -554,17 +531,13 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
quantization := req.Quantization quantization := cmp.Or(r.Quantize, r.Quantization)
if req.Quantize != "" { if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
quantization = req.Quantize
}
if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if r.Stream != nil && !*r.Stream {
waitForStream(c, ch) waitForStream(c, ch)
return return
} }
@ -573,75 +546,58 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
} }
func (s *Server) DeleteModelHandler(c *gin.Context) { func (s *Server) DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest var r api.DeleteRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
var model string n := model.ParseName(cmp.Or(r.Model, r.Name))
if req.Model != "" { if !n.IsValid() {
model = req.Model c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
if err := DeleteModel(model); err != nil { m, err := ParseNamedManifest(n)
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
manifestsPath, err := GetManifestPath()
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if err := PruneDirectory(manifestsPath); err != nil { if err := m.Remove(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
c.JSON(http.StatusOK, nil) if err := m.RemoveLayers(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} }
func (s *Server) ShowModelHandler(c *gin.Context) { func (s *Server) ShowModelHandler(c *gin.Context) {
var req api.ShowRequest var r api.ShowRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
if req.Model != "" { n := model.ParseName(cmp.Or(r.Model, r.Name))
// noop if !n.IsValid() {
} else if req.Name != "" { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
req.Model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
resp, err := GetModelInfo(req) resp, err := GetModelInfo(n, r)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", r.Model)})
} else { } else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} }
@ -651,8 +607,8 @@ func (s *Server) ShowModelHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { func GetModelInfo(name model.Name, req api.ShowRequest) (*api.ShowResponse, error) {
model, err := GetModel(req.Model) model, err := GetModel(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -710,7 +666,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
var sb strings.Builder var sb strings.Builder
fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"") fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:") fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName) fmt.Fprintf(&sb, "# FROM %s\n\n", name.DisplayShortest())
fmt.Fprint(&sb, model.String()) fmt.Fprint(&sb, model.String())
resp.Modelfile = sb.String() resp.Modelfile = sb.String()
@ -718,72 +674,42 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
func (s *Server) ListModelsHandler(c *gin.Context) { func (s *Server) ListModelsHandler(c *gin.Context) {
manifests, err := GetManifestPath() ms, err := Manifests()
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
var models []api.ModelResponse var models []api.ModelResponse
if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error { for n, m := range ms {
if !info.IsDir() { f, err := m.Config.Open()
rel, err := filepath.Rel(manifests, path) if err != nil {
if err != nil { slog.Warn("bad manifest filepath", "name", n, "error", err)
return err continue
} }
defer f.Close()
if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil { var cf ConfigV2
return err if err := json.NewDecoder(f).Decode(&cf); err != nil {
} else if hidden { slog.Warn("bad manifest config", "name", n, "error", err)
return nil continue
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest filepath", "path", rel)
return nil
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
return nil
}
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest config filepath", "name", n, "error", err)
return nil
}
defer f.Close()
var c ConfigV2
if err := json.NewDecoder(f).Decode(&c); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
return nil
}
// tag should never be masked
models = append(models, api.ModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.Digest,
ModifiedAt: info.ModTime(),
Details: api.ModelDetails{
Format: c.ModelFormat,
Family: c.ModelFamily,
Families: c.ModelFamilies,
ParameterSize: c.ModelType,
QuantizationLevel: c.FileType,
},
})
} }
return nil // tag should never be masked
}); err != nil { models = append(models, api.ModelResponse{
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) Model: n.DisplayShortest(),
return Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
Families: cf.ModelFamilies,
ParameterSize: cf.ModelType,
QuantizationLevel: cf.FileType,
},
})
} }
slices.SortStableFunc(models, func(i, j api.ModelResponse) int { slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
@ -1029,17 +955,15 @@ func Serve(ln net.Listener) error {
if !envconfig.NoPrune { if !envconfig.NoPrune {
// clean up unused layers and manifests // clean up unused layers and manifests
if err := PruneLayers(); err != nil { layers, err := Layers()
return err
}
manifestsPath, err := GetManifestPath()
if err != nil { if err != nil {
return err return err
} }
if err := PruneDirectory(manifestsPath); err != nil { for _, layer := range layers {
return err if err := layer.Remove(); err != nil {
return err
}
} }
} }
@ -1155,27 +1079,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.ChatRequest var req api.ChatRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
// validate the request n := model.ParseName(req.Model)
switch { if !n.IsValid() {
case req.Model == "": c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", req.Model)})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return return
} }
model, err := GetModel(req.Model) model, err := GetModel(n)
if err != nil { if err != nil {
var pErr *fs.PathError var pErr *fs.PathError
if errors.As(err, &pErr) { if errors.As(err, &pErr) {

View File

@ -53,6 +53,8 @@ func Test_Routes(t *testing.T) {
} }
createTestModel := func(t *testing.T, name string) { createTestModel := func(t *testing.T, name string) {
t.Helper()
fname := createTestFile(t, "ollama-model") fname := createTestFile(t, "ollama-model")
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
@ -61,7 +63,7 @@ func Test_Routes(t *testing.T) {
fn := func(resp api.ProgressResponse) { fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status) t.Logf("Status: %s", resp.Status)
} }
err = CreateModel(context.TODO(), name, "", "", modelfile, fn) err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
assert.Nil(t, err) assert.Nil(t, err)
} }
@ -144,9 +146,9 @@ func Test_Routes(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, resp.StatusCode, 200) assert.Equal(t, resp.StatusCode, 200)
model, err := GetModel("t-bone") m, err := GetModel(model.ParseName("t-bone"))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "t-bone:latest", model.ShortName) assert.Equal(t, "t-bone:latest", m.Name.DisplayShortest())
}, },
}, },
{ {
@ -165,9 +167,9 @@ func Test_Routes(t *testing.T) {
req.Body = io.NopCloser(bytes.NewReader(jsonData)) req.Body = io.NopCloser(bytes.NewReader(jsonData))
}, },
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
model, err := GetModel("beefsteak") m, err := GetModel(model.ParseName("beefsteak"))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "beefsteak:latest", model.ShortName) assert.Equal(t, "beefsteak:latest", m.Name.DisplayShortest())
}, },
}, },
{ {

View File

@ -310,7 +310,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
// show a generalized compatibility error until there is a better way to // show a generalized compatibility error until there is a better way to
// check for model compatibility // check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") { if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.Name.DisplayShortest())
} }
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err) slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
req.errCh <- err req.errCh <- err

View File

@ -16,6 +16,7 @@ import (
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/server/envconfig" "github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/model"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -107,12 +108,12 @@ func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.
return scenario.srv, nil return scenario.srv, nil
} }
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle { func newScenario(t *testing.T, ctx context.Context, name model.Name, estimatedVRAM uint64) *bundle {
scenario := &bundle{} scenario := &bundle{}
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx) scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
t.Helper() t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName) f, err := os.CreateTemp(t.TempDir(), name.Model)
assert.Nil(t, err) assert.Nil(t, err)
defer f.Close() defer f.Close()
@ -134,7 +135,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
assert.Nil(t, err) assert.Nil(t, err)
fname := f.Name() fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname} model := &Model{Name: name, ModelPath: fname}
scenario.ggml, err = llm.LoadModel(model.ModelPath) scenario.ggml, err = llm.LoadModel(model.ModelPath)
require.NoError(t, err) require.NoError(t, err)
@ -155,24 +156,24 @@ func TestRequests(t *testing.T) {
defer done() defer done()
// Same model, same request // Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10) scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1"), 10)
scenario1a.req.sessionDuration = 0 scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1", 11) scenario1b := newScenario(t, ctx, model.ParseName("ollama-model-1"), 11)
scenario1b.req.model = scenario1a.req.model scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = 0 scenario1b.req.sessionDuration = 0
// simple reload of same model // simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20) scenario2a := newScenario(t, ctx, model.ParseName("ollama-model-1"), 20)
scenario2a.req.model = scenario1a.req.model scenario2a.req.model = scenario1a.req.model
scenario2a.ggml = scenario1a.ggml scenario2a.ggml = scenario1a.ggml
// Multiple loaded models // Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte) scenario3a := newScenario(t, ctx, model.ParseName("ollama-model-3a"), 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte) scenario3b := newScenario(t, ctx, model.ParseName("ollama-model-3b"), 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30) scenario3c := newScenario(t, ctx, model.ParseName("ollama-model-4a"), 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded scenario3d := newScenario(t, ctx, model.ParseName("ollama-model-3c"), 30) // Needs prior unloaded
s := InitScheduler(ctx) s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList { s.getGpuFn = func() gpu.GpuInfoList {
@ -310,11 +311,11 @@ func TestGetRunner(t *testing.T) {
defer done() defer done()
// Same model, same request // Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1a"), 10)
scenario1a.req.sessionDuration = 0 scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10) scenario1b := newScenario(t, ctx, model.ParseName("ollama-model-1b"), 10)
scenario1b.req.sessionDuration = 0 scenario1b.req.sessionDuration = 0
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10) scenario1c := newScenario(t, ctx, model.ParseName("ollama-model-1c"), 10)
scenario1c.req.sessionDuration = 0 scenario1c.req.sessionDuration = 0
envconfig.MaxQueuedRequests = 1 envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx) s := InitScheduler(ctx)
@ -370,7 +371,7 @@ func TestPrematureExpired(t *testing.T) {
defer done() defer done()
// Same model, same request // Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1a"), 10)
s := InitScheduler(ctx) s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList { s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"} g := gpu.GpuInfo{Library: "metal"}

View File

@ -12,12 +12,14 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/types/model"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@ -55,9 +57,10 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
} }
if b.From != "" { if b.From != "" {
n := model.ParseName(b.From)
values := requestURL.Query() values := requestURL.Query()
values.Add("mount", b.Digest) values.Add("mount", b.Digest)
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository()) values.Add("from", path.Join(n.Namespace, n.Model))
requestURL.RawQuery = values.Encode() requestURL.RawQuery = values.Encode()
} }
@ -360,40 +363,46 @@ func (p *progressWriter) Rollback() {
p.written = 0 p.written = 0
} }
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error { type uploadOptions struct {
requestURL := mp.BaseURL() name model.Name
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest) baseURL *url.URL
layer *Layer
regOpts *registryOptions
fn func(api.ProgressResponse)
}
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts) func uploadBlob(ctx context.Context, opts uploadOptions) error {
requestURL := opts.baseURL.JoinPath("v2", opts.name.Namespace, opts.name.Model, "blobs", opts.layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts.regOpts)
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
case err != nil: case err != nil:
return err return err
default: default:
defer resp.Body.Close() defer resp.Body.Close()
fn(api.ProgressResponse{ opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("pushing %s", layer.Digest[7:19]), Status: fmt.Sprintf("pushing %s", opts.layer.Digest[7:19]),
Digest: layer.Digest, Digest: opts.layer.Digest,
Total: layer.Size, Total: opts.layer.Size,
Completed: layer.Size, Completed: opts.layer.Size,
}) })
return nil return nil
} }
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer}) data, ok := blobUploadManager.LoadOrStore(opts.layer.Digest, &blobUpload{Layer: opts.layer})
upload := data.(*blobUpload) upload := data.(*blobUpload)
if !ok { if !ok {
requestURL := mp.BaseURL() requestURL := opts.baseURL.JoinPath("v2", opts.name.Namespace, opts.name.Model, "blobs/uploads/")
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/") if err := upload.Prepare(ctx, requestURL, opts.regOpts); err != nil {
if err := upload.Prepare(ctx, requestURL, opts); err != nil { blobUploadManager.Delete(opts.layer.Digest)
blobUploadManager.Delete(layer.Digest)
return err return err
} }
// nolint: contextcheck //nolint:contextcheck
go upload.Run(context.Background(), opts) go upload.Run(context.Background(), opts.regOpts)
} }
return upload.Wait(ctx, fn) return upload.Wait(ctx, opts.fn)
} }