Compare commits

...

10 Commits

Author SHA1 Message Date
Michael Yang
980070dce6 remove last bits of ParseModelPath 2024-05-13 14:25:37 -07:00
Michael Yang
af11838245 update push to use model.Name 2024-05-13 14:25:27 -07:00
Michael Yang
7cb2fd3555 fix(server): prune files 2024-05-13 14:24:52 -07:00
Michael Yang
0aeaeaa058 update pull handler to use model.Name 2024-05-13 14:24:05 -07:00
Michael Yang
b91cf0893d update create handler to use model.Name 2024-05-13 13:41:13 -07:00
Michael Yang
8215545c6d more resilient Manifests 2024-05-13 13:40:50 -07:00
Michael Yang
34a1dbe6ec filepath.Join 2024-05-13 13:40:50 -07:00
Michael Yang
6f2a09abfd remove DeleteModel 2024-05-13 13:40:50 -07:00
Michael Yang
14f9dc4e6a routes: use Manifests for ListHandler 2024-05-13 13:40:50 -07:00
Michael Yang
eeba2cbae3 update delete handler to use model.Name 2024-05-13 13:40:50 -07:00
15 changed files with 608 additions and 864 deletions

View File

@ -24,7 +24,8 @@ func terminate(cmd *exec.Cmd) error {
if err != nil {
return err
}
defer dll.Release() // nolint: errcheck
//nolint:errcheck
defer dll.Release()
pid := cmd.Process.Pid
@ -73,7 +74,8 @@ func isProcessExited(pid int) (bool, error) {
if err != nil {
return false, fmt.Errorf("failed to open process: %v", err)
}
defer windows.CloseHandle(hProcess) // nolint: errcheck
//nolint:errcheck
defer windows.CloseHandle(hProcess)
var exitCode uint32
err = windows.GetExitCodeProcess(hProcess, &exitCode)

View File

@ -81,7 +81,7 @@ func (i *Instance) Readline() (string, error) {
defer func() {
fd := int(syscall.Stdin)
// nolint: errcheck
//nolint:errcheck
UnsetRawMode(fd, i.Terminal.termios)
i.Terminal.rawmode = false
}()

View File

@ -23,6 +23,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/types/model"
)
const maxRetries = 6
@ -332,15 +333,16 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
}
}
type downloadOpts struct {
mp ModelPath
type downloadOptions struct {
name model.Name
baseURL *url.URL
digest string
regOpts *registryOptions
fn func(api.ProgressResponse)
}
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error {
func downloadBlob(ctx context.Context, opts downloadOptions) error {
fp, err := GetBlobsPath(opts.digest)
if err != nil {
return err
@ -365,14 +367,13 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload)
if !ok {
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
requestURL := opts.baseURL.JoinPath("blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return err
}
// nolint: contextcheck
//nolint:contextcheck
go download.Run(context.Background(), requestURL, opts.regOpts)
}

View File

@ -4,18 +4,16 @@ import (
"bytes"
"cmp"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"log/slog"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
@ -41,9 +39,8 @@ type registryOptions struct {
}
type Model struct {
Name string `json:"name"`
Name model.Name
Config ConfigV2
ShortName string
ModelPath string
ParentModel string
AdapterPaths []string
@ -160,46 +157,17 @@ type RootFS struct {
DiffIDs []string `json:"diff_ids"`
}
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
return nil, "", err
}
if _, err = os.Stat(fp); err != nil {
return nil, "", err
}
var manifest *ManifestV2
bts, err := os.ReadFile(fp)
if err != nil {
return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
}
shaSum := sha256.Sum256(bts)
shaStr := hex.EncodeToString(shaSum[:])
if err := json.Unmarshal(bts, &manifest); err != nil {
return nil, "", err
}
return manifest, shaStr, nil
}
func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
manifest, digest, err := GetManifest(mp)
func GetModel(name model.Name) (*Model, error) {
manifest, err := ParseNamedManifest(name)
if err != nil {
return nil, err
}
model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
Template: "{{ .Prompt }}",
License: []string{},
Name: name,
Digest: manifest.digest,
Template: "{{ .Prompt }}",
License: []string{},
}
filename, err := GetBlobsPath(manifest.Config.Digest)
@ -314,7 +282,7 @@ func realpath(rel, from string) string {
return abspath
}
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
@ -546,16 +514,10 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
}
}
unref := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range manifest.Layers {
if !slices.Contains(digests, layer.Digest) {
unref[layer.Digest] = struct{}{}
}
}
if manifest.Config.Digest != layer.Digest {
unref[manifest.Config.Digest] = struct{}{}
if !envconfig.NoPrune {
if old, err := ParseNamedManifest(name); err == nil {
//nolint:errcheck
defer old.RemoveLayers()
}
}
@ -564,12 +526,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return err
}
if !envconfig.NoPrune {
if err := deleteUnusedLayers(nil, unref); err != nil {
return err
}
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
@ -613,207 +569,42 @@ func CopyModel(src, dst model.Name) error {
return err
}
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}) error {
fp, err := GetManifestPath()
if err != nil {
return err
}
walkFunc := func(path string, info os.FileInfo, _ error) error {
if info.IsDir() {
return nil
}
dir, file := filepath.Split(path)
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
tag := strings.Join([]string{dir, file}, ":")
fmp := ParseModelPath(tag)
// skip the manifest we're trying to delete
if skipModelPath != nil && skipModelPath.GetFullTagname() == fmp.GetFullTagname() {
return nil
}
// save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp)
if err != nil {
// nolint: nilerr
return nil
}
for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
return nil
}
if err := filepath.Walk(fp, walkFunc); err != nil {
return err
}
// only delete the files which are still in the deleteMap
for k := range deleteMap {
fp, err := GetBlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
continue
}
if err := os.Remove(fp); err != nil {
slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err))
continue
}
}
return nil
}
func PruneLayers() error {
deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("")
if err != nil {
return err
}
blobs, err := os.ReadDir(p)
if err != nil {
slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", p, err))
return err
}
for _, blob := range blobs {
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
_, err := GetBlobsPath(name)
if err != nil {
if errors.Is(err, ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
}
}
continue
}
deleteMap[name] = struct{}{}
}
slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap)))
err = deleteUnusedLayers(nil, deleteMap)
if err != nil {
return err
}
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
return nil
}
func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}
entries, err = os.ReadDir(path)
if err != nil {
return err
}
if len(entries) > 0 {
return nil
}
return os.Remove(path)
}
return nil
}
func DeleteModel(name string) error {
mp := ParseModelPath(name)
manifest, _, err := GetManifest(mp)
if err != nil {
return err
}
deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
}
deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
err = os.Remove(fp)
if err != nil {
slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
return err
}
return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
}
manifest, _, err := GetManifest(mp)
m, err := ParseNamedManifest(name)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)
scheme := "https"
if opts.Insecure {
scheme = "http"
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
if err != nil {
return err
}
for _, layer := range append(m.Layers, m.Config) {
if err := uploadBlob(ctx, uploadOptions{name: name, baseURL: baseURL, layer: layer, regOpts: &opts, fn: fn}); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
}
fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
requestURL := baseURL.JoinPath("v2", name.Namespace, name.Model, "manifests", name.Tag)
manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(m)
if err != nil {
return err
}
headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), &opts)
if err != nil {
return err
}
@ -824,118 +615,72 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
func PullModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
old, _ := ParseNamedManifest(name)
var manifest *ManifestV2
var err error
var noprune string
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
if !envconfig.NoPrune {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if manifest != nil {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{}
}
deleteMap[manifest.Config.Digest] = struct{}{}
}
if !name.IsFullyQualified() {
return model.Unqualified(name)
}
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
scheme := "https"
if opts.Insecure {
scheme = "http"
}
baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
if err != nil {
return err
}
fn(api.ProgressResponse{Status: "pulling manifest"})
manifest, err = pullModelManifest(ctx, mp, regOpts)
m, err := pullModelManifest(ctx, name, baseURL, &opts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}
var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)
layers := append(m.Layers, m.Config)
for _, layer := range layers {
if err := downloadBlob(
ctx,
downloadOpts{
mp: mp,
downloadOptions{
name: name,
baseURL: baseURL,
digest: layer.Digest,
regOpts: regOpts,
regOpts: &opts,
fn: fn,
}); err != nil {
return err
}
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
if err := layer.Verify(); err != nil {
_ = layer.Remove()
return err
}
}
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
if err != nil {
if err := WriteManifest(name, m.Config, m.Layers); err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
err = os.WriteFile(fp, manifestJSON, 0o644)
if err != nil {
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
return err
}
if noprune == "" {
if !envconfig.NoPrune && old != nil {
fn(api.ProgressResponse{Status: "removing any unused layers"})
err = deleteUnusedLayers(nil, deleteMap)
if err != nil {
return err
}
_ = old.RemoveLayers()
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, opts *registryOptions) (*ManifestV2, error) {
requestURL := baseURL.JoinPath("manifests", name.Tag)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
if err != nil {
return nil, err
}
@ -949,17 +694,6 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
return m, err
}
// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
func GetSHA256Digest(r io.Reader) (string, int64) {
h := sha256.New()
n, err := io.Copy(h, r)
if err != nil {
log.Fatal(err)
}
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
var errUnauthorized = fmt.Errorf("unauthorized: access denied")
// getTokenSubject returns the subject of a JWT token, it does not validate the token
@ -1119,25 +853,3 @@ func parseRegistryChallenge(authStr string) registryChallenge {
Scope: getValue(authStr, "scope"),
}
}
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
if err != nil {
return err
}
f, err := os.Open(fp)
if err != nil {
return err
}
defer f.Close()
fileDigest, _ := GetSHA256Digest(f)
if digest != fileDigest {
return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
}
return nil
}

View File

@ -4,7 +4,10 @@ import (
"crypto/sha256"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
)
type Layer struct {
@ -88,3 +91,81 @@ func (l *Layer) Open() (io.ReadCloser, error) {
return os.Open(blob)
}
func (l *Layer) Remove() error {
ms, err := Manifests()
if err != nil {
return err
}
for _, m := range ms {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == l.Digest {
// something is using this layer
return nil
}
}
}
p, err := GetBlobsPath("")
if err != nil {
return err
}
return os.Remove(filepath.Join(p, l.Digest))
}
func (l *Layer) Verify() error {
rc, err := l.Open()
if err != nil {
return err
}
defer rc.Close()
sha256sum := sha256.New()
if _, err := io.Copy(sha256sum, rc); err != nil {
return err
}
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
if digest != l.Digest {
return fmt.Errorf("digest mismatch: %s != %s", digest, l.Digest)
}
return nil
}
func Layers() (map[string]*Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {
return nil, err
}
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(blobs, "*"))
if err != nil {
return nil, err
}
ds := make(map[string]*Layer)
for _, match := range matches {
rel, err := filepath.Rel(blobs, match)
if err != nil {
slog.Warn("bad filepath", "path", match, "error", err)
continue
}
// TODO(mxyng): this should ideally use model.Digest but
// that's currently incompatible with the manifest digest
d := strings.Replace(rel, "sha256-", "sha256:", 1)
layer, err := NewLayerFromLayer(d, "", "")
if err != nil {
slog.Warn("bad blob", "digest", d, "error", err)
layer = &Layer{Digest: rel}
}
ds[d] = layer
}
return ds, nil
}

View File

@ -1,11 +1,11 @@
package server
import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
@ -14,7 +14,10 @@ import (
type Manifest struct {
ManifestV2
Digest string `json:"-"`
filepath string
fi os.FileInfo
digest string
}
func (m *Manifest) Size() (size int64) {
@ -25,9 +28,67 @@ func (m *Manifest) Size() (size int64) {
return
}
func ParseNamedManifest(name model.Name) (*Manifest, error) {
if !name.IsFullyQualified() {
return nil, model.Unqualified(name)
func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
}
manifests, err := GetManifestPath()
if err != nil {
return err
}
return pruneEmptyDirectory(manifests)
}
func pruneEmptyDirectory(p string) error {
fi, err := os.Lstat(p)
if err != nil {
return err
}
if fi.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(p)
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() {
if err := pruneEmptyDirectory(filepath.Join(p, entry.Name())); err != nil {
return err
}
}
}
entries, err = os.ReadDir(p)
if err != nil {
return err
}
if len(entries) == 0 {
if err := os.Remove(p); err != nil {
return err
}
}
}
return nil
}
func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); err != nil {
return err
}
}
return nil
}
func ParseNamedManifest(n model.Name) (*Manifest, error) {
if !n.IsFullyQualified() {
return nil, model.Unqualified(n)
}
manifests, err := GetManifestPath()
@ -35,45 +96,115 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) {
return nil, err
}
var manifest ManifestV2
manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
p := filepath.Join(manifests, n.Filepath())
var m ManifestV2
f, err := os.Open(p)
if err != nil {
return nil, err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return nil, err
}
sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
return nil, err
}
return &Manifest{
ManifestV2: manifest,
Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
ManifestV2: m,
filepath: p,
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
}, nil
}
func WriteManifest(name string, config *Layer, layers []*Layer) error {
manifest := ManifestV2{
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
manifests, err := GetManifestPath()
if err != nil {
return err
}
p := filepath.Join(manifests, name.Filepath())
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
return err
}
f, err := os.Create(p)
if err != nil {
return err
}
defer f.Close()
m := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
return err
}
modelpath := ParseModelPath(name)
manifestPath, err := modelpath.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
return err
}
return os.WriteFile(manifestPath, b.Bytes(), 0o644)
return json.NewEncoder(f).Encode(m)
}
func Manifests() (map[model.Name]*Manifest, error) {
manifests, err := GetManifestPath()
if err != nil {
return nil, err
}
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
if err != nil {
return nil, err
}
ms := make(map[model.Name]*Manifest)
for _, match := range matches {
fi, err := os.Stat(match)
if err != nil {
return nil, err
}
if !fi.IsDir() {
rel, err := filepath.Rel(manifests, match)
if err != nil {
slog.Warn("bad filepath", "path", match, "error", err)
continue
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest name", "path", rel, "error", err)
continue
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
continue
}
ms[n] = m
}
}
return ms, nil
}
func GetManifestPath() (string, error) {
dir, err := modelsDir()
if err != nil {
return "", err
}
path := filepath.Join(dir, "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", err
}
return path, nil
}

94
server/manifest_test.go Normal file
View File

@ -0,0 +1,94 @@
package server
import (
"encoding/json"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/types/model"
)
func createManifest(t *testing.T, path, name string) {
t.Helper()
p := filepath.Join(path, "manifests", name)
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
t.Fatal(err)
}
f, err := os.Create(p)
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
t.Fatal(err)
}
}
func TestManifests(t *testing.T) {
cases := map[string][]string{
"empty": {},
"single": {
filepath.Join("host", "namespace", "model", "tag"),
},
"multiple": {
filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
},
"hidden": {
filepath.Join("host", "namespace", "model", "tag"),
filepath.Join("host", "namespace", "model", ".hidden"),
},
"subdir": {
filepath.Join("host", "namespace", "model", "tag", "one"),
filepath.Join("host", "namespace", "model", "tag", "another", "one"),
},
}
for n, wants := range cases {
t.Run(n, func(t *testing.T) {
d := t.TempDir()
t.Setenv("OLLAMA_MODELS", d)
for _, want := range wants {
createManifest(t, d, want)
}
ms, err := Manifests()
if err != nil {
t.Fatal(err)
}
var ns []model.Name
for k := range ms {
ns = append(ns, k)
}
for _, want := range wants {
n := model.ParseNameFromFilepath(want)
if !n.IsValid() && slices.Contains(ns, n) {
t.Errorf("unexpected invalid name: %s", want)
} else if n.IsValid() && !slices.Contains(ns, n) {
t.Errorf("missing valid name: %s", want)
}
}
})
}
}

View File

@ -23,16 +23,14 @@ type layerWithGGML struct {
}
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
modelpath := ParseModelPath(name.String())
manifest, _, err := GetManifest(modelpath)
m, err := ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
if err := PullModel(ctx, name, registryOptions{}, fn); err != nil {
return nil, err
}
modelpath = ParseModelPath(name.String())
manifest, _, err = GetManifest(modelpath)
m, err = ParseNamedManifest(name)
if err != nil {
return nil, err
}
@ -40,8 +38,8 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
return nil, err
}
for _, layer := range manifest.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
for _, layer := range m.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil {
return nil, err
}

View File

@ -2,105 +2,16 @@ package server
import (
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
)
type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Repository string
Tag string
}
const (
DefaultRegistry = "registry.ollama.ai"
DefaultNamespace = "library"
DefaultTag = "latest"
DefaultProtocolScheme = "https"
)
var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrInvalidDigestFormat = errors.New("invalid digest format")
)
func ParseModelPath(name string) ModelPath {
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}
before, after, found := strings.Cut(name, "://")
if found {
mp.ProtocolScheme = before
name = after
}
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo
mp.Tag = tag
}
return mp
}
var errModelPathInvalid = errors.New("invalid model path")
func (mp ModelPath) Validate() error {
if mp.Repository == "" {
return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
}
if strings.Contains(mp.Tag, ":") {
return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
}
return nil
}
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}
func (mp ModelPath) GetFullTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
// modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
// The models directory is where Ollama stores its model files and manifests.
func modelsDir() (string, error) {
@ -114,37 +25,6 @@ func modelsDir() (string, error) {
return filepath.Join(home, ".ollama", "models"), nil
}
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
dir, err := modelsDir()
if err != nil {
return "", err
}
return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
}
func (mp ModelPath) BaseURL() *url.URL {
return &url.URL{
Scheme: mp.ProtocolScheme,
Host: mp.Registry,
}
}
func GetManifestPath() (string, error) {
dir, err := modelsDir()
if err != nil {
return "", err
}
path := filepath.Join(dir, "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", err
}
return path, nil
}
func GetBlobsPath(digest string) (string, error) {
dir, err := modelsDir()
if err != nil {

View File

@ -68,88 +68,3 @@ func TestGetBlobsPath(t *testing.T) {
})
}
}
func TestParseModelPath(t *testing.T) {
tests := []struct {
name string
arg string
want ModelPath
}{
{
"full path https",
"https://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"full path http",
"http://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "http",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no protocol",
"example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no registry",
"ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no namespace",
"repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: "tag",
},
},
{
"no tag",
"repo",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: DefaultTag,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := ParseModelPath(tc.arg)
if got != tc.want {
t.Errorf("got: %q want: %q", got, tc.want)
}
})
}
}

View File

@ -75,45 +75,43 @@ func isSupportedImageType(image []byte) bool {
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.GenerateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
n := model.ParseName(r.Model)
// validate the request
switch {
case req.Model == "":
case !n.IsValid():
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
case len(r.Format) > 0 && r.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
case r.Raw && (r.Template != "" || r.System != "" || len(r.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
for _, img := range req.Images {
for _, img := range r.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
}
model, err := GetModel(req.Model)
model, err := GetModel(n)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", r.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -125,17 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
opts, err := modelOptions(model, req.Options)
opts, err := modelOptions(model, r.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
if r.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
sessionDuration = r.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
@ -150,10 +148,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
if r.Prompt == "" && r.Template == "" && r.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Model: r.Model,
Done: true,
DoneReason: "load",
})
@ -164,37 +162,37 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var prompt string
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template == "" {
req.Template = model.Template
case r.Raw:
prompt = r.Prompt
case r.Prompt != "":
if r.Template == "" {
r.Template = model.Template
}
if req.System == "" {
req.System = model.System
if r.System == "" {
r.System = model.System
}
slog.Debug("generate handler", "prompt", req.Prompt)
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
slog.Debug("generate handler", "prompt", r.Prompt)
slog.Debug("generate handler", "template", r.Template)
slog.Debug("generate handler", "system", r.System)
var sb strings.Builder
for i := range req.Images {
for i := range r.Images {
fmt.Fprintf(&sb, "[img-%d] ", i)
}
sb.WriteString(req.Prompt)
sb.WriteString(r.Prompt)
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
p, err := Prompt(r.Template, r.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.Reset()
if req.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
if r.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), r.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@ -215,33 +213,33 @@ func (s *Server) GenerateHandler(c *gin.Context) {
go func() {
defer close(ch)
fn := func(r llm.CompletionResponse) {
fn := func(comp llm.CompletionResponse) {
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
if _, err := generated.WriteString(comp.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp := api.GenerateResponse{
Model: req.Model,
Model: r.Model,
CreatedAt: time.Now().UTC(),
Done: r.Done,
Response: r.Content,
DoneReason: r.DoneReason,
Done: comp.Done,
DoneReason: comp.DoneReason,
Response: comp.Content,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
PromptEvalCount: comp.PromptEvalCount,
PromptEvalDuration: comp.PromptEvalDuration,
EvalCount: comp.EvalCount,
EvalDuration: comp.EvalDuration,
},
}
if r.Done {
if comp.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
if !r.Raw {
p, err := Prompt(r.Template, r.System, r.Prompt, generated.String(), false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@ -254,7 +252,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
resp.Context = append(req.Context, tokens...)
resp.Context = append(r.Context, tokens...)
}
}
@ -262,17 +260,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
var images []llm.ImageData
for i := range req.Images {
for i := range r.Images {
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
Data: r.Images[i],
})
}
// Start prediction
req := llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Format: r.Format,
Images: images,
Options: opts,
}
@ -281,7 +279,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
}()
if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
// Accumulate responses into the final response
var final api.GenerateResponse
var sb strings.Builder
@ -339,44 +337,43 @@ func getDefaultSessionDuration() time.Duration {
}
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.EmbeddingRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(r.Model)
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", r.Model)})
return
}
model, err := GetModel(req.Model)
model, err := GetModel(n)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", r.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
opts, err := modelOptions(model, r.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
if r.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
sessionDuration = r.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
@ -389,12 +386,12 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
}
// an empty request loads the model
if req.Prompt == "" {
if r.Prompt == "" {
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
return
}
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
embedding, err := runner.llama.Embedding(c.Request.Context(), r.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@ -408,24 +405,18 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
}
func (s *Server) PullModelHandler(c *gin.Context) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.PullRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}
@ -436,19 +427,15 @@ func (s *Server) PullModelHandler(c *gin.Context) {
ch <- r
}
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil {
if err := PullModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
waitForStream(c, ch)
return
}
@ -457,24 +444,18 @@ func (s *Server) PullModelHandler(c *gin.Context) {
}
func (s *Server) PushModelHandler(c *gin.Context) {
var req api.PushRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.PushRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}
@ -485,19 +466,15 @@ func (s *Server) PushModelHandler(c *gin.Context) {
ch <- r
}
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil {
if err := PushModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
waitForStream(c, ch)
return
}
@ -506,8 +483,8 @@ func (s *Server) PushModelHandler(c *gin.Context) {
}
func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
var r api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
@ -515,30 +492,30 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
return
}
name := model.ParseName(cmp.Or(req.Model, req.Name))
name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}
if req.Path == "" && req.Modelfile == "" {
if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
}
var r io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
f, err := os.Open(req.Path)
var rd io.Reader = strings.NewReader(r.Modelfile)
if r.Path != "" && r.Modelfile == "" {
f, err := os.Open(r.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
defer f.Close()
r = f
rd = f
}
modelfile, err := model.ParseFile(r)
f, err := model.ParseFile(rd)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@ -554,17 +531,13 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
quantization := req.Quantization
if req.Quantize != "" {
quantization = req.Quantize
}
if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil {
quantization := cmp.Or(r.Quantize, r.Quantization)
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
waitForStream(c, ch)
return
}
@ -573,75 +546,58 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
}
func (s *Server) DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.DeleteRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}
if err := DeleteModel(model); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
manifestsPath, err := GetManifestPath()
m, err := ParseNamedManifest(n)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := PruneDirectory(manifestsPath); err != nil {
if err := m.Remove(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, nil)
if err := m.RemoveLayers(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
func (s *Server) ShowModelHandler(c *gin.Context) {
var req api.ShowRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.ShowRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model != "" {
// noop
} else if req.Name != "" {
req.Model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}
resp, err := GetModelInfo(req)
resp, err := GetModelInfo(n, r)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", r.Model)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
@ -651,8 +607,8 @@ func (s *Server) ShowModelHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
model, err := GetModel(req.Model)
func GetModelInfo(name model.Name, req api.ShowRequest) (*api.ShowResponse, error) {
model, err := GetModel(name)
if err != nil {
return nil, err
}
@ -710,7 +666,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
var sb strings.Builder
fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
fmt.Fprintf(&sb, "# FROM %s\n\n", name.DisplayShortest())
fmt.Fprint(&sb, model.String())
resp.Modelfile = sb.String()
@ -718,72 +674,42 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
func (s *Server) ListModelsHandler(c *gin.Context) {
manifests, err := GetManifestPath()
ms, err := Manifests()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var models []api.ModelResponse
if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
if !info.IsDir() {
rel, err := filepath.Rel(manifests, path)
if err != nil {
return err
}
for n, m := range ms {
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest filepath", "name", n, "error", err)
continue
}
defer f.Close()
if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
return err
} else if hidden {
return nil
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest filepath", "path", rel)
return nil
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
return nil
}
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest config filepath", "name", n, "error", err)
return nil
}
defer f.Close()
var c ConfigV2
if err := json.NewDecoder(f).Decode(&c); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
return nil
}
// tag should never be masked
models = append(models, api.ModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.Digest,
ModifiedAt: info.ModTime(),
Details: api.ModelDetails{
Format: c.ModelFormat,
Family: c.ModelFamily,
Families: c.ModelFamilies,
ParameterSize: c.ModelType,
QuantizationLevel: c.FileType,
},
})
var cf ConfigV2
if err := json.NewDecoder(f).Decode(&cf); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
continue
}
return nil
}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
// tag should never be masked
models = append(models, api.ModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
Families: cf.ModelFamilies,
ParameterSize: cf.ModelType,
QuantizationLevel: cf.FileType,
},
})
}
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
@ -1029,17 +955,15 @@ func Serve(ln net.Listener) error {
if !envconfig.NoPrune {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
}
manifestsPath, err := GetManifestPath()
layers, err := Layers()
if err != nil {
return err
}
if err := PruneDirectory(manifestsPath); err != nil {
return err
for _, layer := range layers {
if err := layer.Remove(); err != nil {
return err
}
}
}
@ -1155,27 +1079,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
n := model.ParseName(req.Model)
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", req.Model)})
return
}
model, err := GetModel(req.Model)
model, err := GetModel(n)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {

View File

@ -53,6 +53,8 @@ func Test_Routes(t *testing.T) {
}
createTestModel := func(t *testing.T, name string) {
t.Helper()
fname := createTestFile(t, "ollama-model")
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
@ -61,7 +63,7 @@ func Test_Routes(t *testing.T) {
fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status)
}
err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
assert.Nil(t, err)
}
@ -144,9 +146,9 @@ func Test_Routes(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, resp.StatusCode, 200)
model, err := GetModel("t-bone")
m, err := GetModel(model.ParseName("t-bone"))
assert.Nil(t, err)
assert.Equal(t, "t-bone:latest", model.ShortName)
assert.Equal(t, "t-bone:latest", m.Name.DisplayShortest())
},
},
{
@ -165,9 +167,9 @@ func Test_Routes(t *testing.T) {
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
model, err := GetModel("beefsteak")
m, err := GetModel(model.ParseName("beefsteak"))
assert.Nil(t, err)
assert.Equal(t, "beefsteak:latest", model.ShortName)
assert.Equal(t, "beefsteak:latest", m.Name.DisplayShortest())
},
},
{

View File

@ -310,7 +310,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.Name.DisplayShortest())
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
req.errCh <- err

View File

@ -16,6 +16,7 @@ import (
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -107,12 +108,12 @@ func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.
return scenario.srv, nil
}
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
func newScenario(t *testing.T, ctx context.Context, name model.Name, estimatedVRAM uint64) *bundle {
scenario := &bundle{}
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
f, err := os.CreateTemp(t.TempDir(), name.Model)
assert.Nil(t, err)
defer f.Close()
@ -134,7 +135,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
assert.Nil(t, err)
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
model := &Model{Name: name, ModelPath: fname}
scenario.ggml, err = llm.LoadModel(model.ModelPath)
require.NoError(t, err)
@ -155,24 +156,24 @@ func TestRequests(t *testing.T) {
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1"), 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b := newScenario(t, ctx, model.ParseName("ollama-model-1"), 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = 0
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
scenario2a := newScenario(t, ctx, model.ParseName("ollama-model-1"), 20)
scenario2a.req.model = scenario1a.req.model
scenario2a.ggml = scenario1a.ggml
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
scenario3a := newScenario(t, ctx, model.ParseName("ollama-model-3a"), 1*format.GigaByte)
scenario3b := newScenario(t, ctx, model.ParseName("ollama-model-3b"), 24*format.GigaByte)
scenario3c := newScenario(t, ctx, model.ParseName("ollama-model-4a"), 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
scenario3d := newScenario(t, ctx, model.ParseName("ollama-model-3c"), 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
@ -310,11 +311,11 @@ func TestGetRunner(t *testing.T) {
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1a"), 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b := newScenario(t, ctx, model.ParseName("ollama-model-1b"), 10)
scenario1b.req.sessionDuration = 0
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c := newScenario(t, ctx, model.ParseName("ollama-model-1c"), 10)
scenario1c.req.sessionDuration = 0
envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx)
@ -370,7 +371,7 @@ func TestPrematureExpired(t *testing.T) {
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a := newScenario(t, ctx, model.ParseName("ollama-model-1a"), 10)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}

View File

@ -12,12 +12,14 @@ import (
"net/http"
"net/url"
"os"
"path"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/types/model"
"golang.org/x/sync/errgroup"
)
@ -55,9 +57,10 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
}
if b.From != "" {
n := model.ParseName(b.From)
values := requestURL.Query()
values.Add("mount", b.Digest)
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
values.Add("from", path.Join(n.Namespace, n.Model))
requestURL.RawQuery = values.Encode()
}
@ -360,40 +363,46 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
type uploadOptions struct {
name model.Name
baseURL *url.URL
layer *Layer
regOpts *registryOptions
fn func(api.ProgressResponse)
}
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
func uploadBlob(ctx context.Context, opts uploadOptions) error {
requestURL := opts.baseURL.JoinPath("v2", opts.name.Namespace, opts.name.Model, "blobs", opts.layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts.regOpts)
switch {
case errors.Is(err, os.ErrNotExist):
case err != nil:
return err
default:
defer resp.Body.Close()
fn(api.ProgressResponse{
Status: fmt.Sprintf("pushing %s", layer.Digest[7:19]),
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("pushing %s", opts.layer.Digest[7:19]),
Digest: opts.layer.Digest,
Total: opts.layer.Size,
Completed: opts.layer.Size,
})
return nil
}
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
data, ok := blobUploadManager.LoadOrStore(opts.layer.Digest, &blobUpload{Layer: opts.layer})
upload := data.(*blobUpload)
if !ok {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
blobUploadManager.Delete(layer.Digest)
requestURL := opts.baseURL.JoinPath("v2", opts.name.Namespace, opts.name.Model, "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobUploadManager.Delete(opts.layer.Digest)
return err
}
// nolint: contextcheck
go upload.Run(context.Background(), opts)
//nolint:contextcheck
go upload.Run(context.Background(), opts.regOpts)
}
return upload.Wait(ctx, fn)
return upload.Wait(ctx, opts.fn)
}