Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Yang
6761aca1e1 update pull handler to use model.Name 2024-08-28 14:00:40 -07:00
Michael Yang
3e24edd9ed update push to use model.Name 2024-08-28 12:07:17 -07:00
5 changed files with 92 additions and 138 deletions

View File

@ -24,6 +24,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/types/model"
) )
const maxRetries = 6 const maxRetries = 6
@ -451,15 +452,16 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
} }
} }
type downloadOpts struct { type downloadOptions struct {
mp ModelPath name model.Name
baseURL *url.URL
digest string digest string
regOpts *registryOptions regOpts *registryOptions
fn func(api.ProgressResponse) fn func(api.ProgressResponse)
} }
// downloadBlob downloads a blob from the registry and stores it in the blobs directory // downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) { func downloadBlob(ctx context.Context, opts downloadOptions) (cacheHit bool, _ error) {
fp, err := GetBlobsPath(opts.digest) fp, err := GetBlobsPath(opts.digest)
if err != nil { if err != nil {
return false, err return false, err
@ -484,8 +486,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest}) data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload) download := data.(*blobDownload)
if !ok { if !ok {
requestURL := opts.mp.BaseURL() requestURL := opts.baseURL.JoinPath("blobs", opts.digest)
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil { if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest) blobDownloadManager.Delete(opts.digest)
return false, err return false, err

View File

@ -16,6 +16,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"slices" "slices"
@ -795,45 +796,40 @@ func PruneDirectory(path string) error {
return nil return nil
} }
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) m, err := ParseNamedManifest(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errors.New("insecure protocol http")
}
manifest, _, err := GetManifest(mp)
if err != nil { if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err return err
} }
var layers []Layer scheme := "https"
layers = append(layers, manifest.Layers...) if opts.Insecure {
if manifest.Config.Digest != "" { scheme = "http"
layers = append(layers, manifest.Config)
} }
for _, layer := range layers { baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { if err != nil {
return err
}
for _, layer := range append(m.Layers, m.Config) {
if err := uploadBlob(ctx, uploadOptions{name: name, baseURL: baseURL, layer: layer, regOpts: &opts, fn: fn}); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err)) slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err return err
} }
} }
fn(api.ProgressResponse{Status: "pushing manifest"}) fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := mp.BaseURL() requestURL := baseURL.JoinPath("manifests", name.Tag)
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
manifestJSON, err := json.Marshal(manifest) manifestJSON, err := json.Marshal(m)
if err != nil { if err != nil {
return err return err
} }
headers := make(http.Header) headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json") headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts) resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), &opts)
if err != nil { if err != nil {
return err return err
} }
@ -844,118 +840,83 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil return nil
} }
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { func PullModel(ctx context.Context, name model.Name, opts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mm, _ := ParseNamedManifest(name)
// build deleteMap to prune unused layers scheme := "https"
deleteMap := make(map[string]struct{}) if opts.Insecure {
manifest, _, err := GetManifest(mp) scheme = "http"
if errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
} else {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{}
}
if manifest.Config.Digest != "" {
deleteMap[manifest.Config.Digest] = struct{}{}
}
} }
if mp.ProtocolScheme == "http" && !regOpts.Insecure { baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
return errors.New("insecure protocol http") if err != nil {
return err
} }
fn(api.ProgressResponse{Status: "pulling manifest"}) fn(api.ProgressResponse{Status: "pulling manifest"})
m, err := pullModelManifest(ctx, name, baseURL, opts)
manifest, err = pullModelManifest(ctx, mp, regOpts)
if err != nil { if err != nil {
return fmt.Errorf("pull model manifest: %s", err) return fmt.Errorf("pull model manifest: %s", err)
} }
var layers []Layer layers := append(m.Layers, m.Config)
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}
skipVerify := make(map[string]bool) skipVerify := make(map[string]bool)
for _, layer := range layers { for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{ hit, err := downloadBlob(ctx, downloadOptions{
mp: mp, name: name,
baseURL: baseURL,
digest: layer.Digest, digest: layer.Digest,
regOpts: regOpts, regOpts: opts,
fn: fn, fn: fn,
}) })
if err != nil { if err != nil {
return err return err
} }
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest) skipVerify[layer.Digest] = hit
} }
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"}) fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers { for _, layer := range layers {
if skipVerify[layer.Digest] { if !skipVerify[layer.Digest] {
continue if err := verifyBlob(layer.Digest); errors.Is(err, errDigestMismatch) {
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob // something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest) fp, err := GetBlobsPath(layer.Digest)
if err != nil { if err != nil {
return err return err
} }
if err := os.Remove(fp); err != nil { if err := os.Remove(fp); err != nil {
// log this, but return the original error // log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err)) slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
} }
} else if err != nil {
return err
} }
return err
} }
} }
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, m.Config, m.Layers); err != nil {
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err return err
} }
fp, err := mp.GetManifestPath() if !envconfig.NoPrune() && mm != nil {
if err != nil { fn(api.ProgressResponse{Status: "pruning old layers"})
return err _ = mm.RemoveLayers()
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
err = os.WriteFile(fp, manifestJSON, 0o644)
if err != nil {
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
return err
}
if !envconfig.NoPrune() && len(deleteMap) > 0 {
fn(api.ProgressResponse{Status: "removing unused layers"})
if err := deleteUnusedLayers(deleteMap); err != nil {
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't remove unused layers: %v", err)})
}
} }
fn(api.ProgressResponse{Status: "success"}) fn(api.ProgressResponse{Status: "success"})
return nil return nil
} }
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) { func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, opts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) requestURL := baseURL.JoinPath("manifests", name.Tag)
headers := make(http.Header) headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json") headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts) resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1105,6 +1066,7 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
return nil, err return nil, err
} }
slog.Debug("request upstream", "method", method, "request", requestURL.Redacted(), "status", resp.StatusCode)
return resp, nil return resp, nil
} }

View File

@ -34,7 +34,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
m, err := ParseNamedManifest(name) m, err := ParseNamedManifest(name)
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil { if err := PullModel(ctx, name, &registryOptions{}, fn); err != nil {
return nil, err return nil, err
} }

View File

@ -464,24 +464,22 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
} }
func (s *Server) PullHandler(c *gin.Context) { func (s *Server) PullHandler(c *gin.Context) {
var req api.PullRequest var r api.PullRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
name := model.ParseName(cmp.Or(req.Model, req.Name)) n := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() { if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return return
} }
if err := checkNameExists(name); err != nil { if err := checkNameExists(n); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@ -493,19 +491,15 @@ func (s *Server) PullHandler(c *gin.Context) {
ch <- r ch <- r
} }
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil { if err := PullModel(ctx, n, &registryOptions{Insecure: r.Insecure}, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if r.Stream != nil && !*r.Stream {
waitForStream(c, ch) waitForStream(c, ch)
return return
} }
@ -514,24 +508,18 @@ func (s *Server) PullHandler(c *gin.Context) {
} }
func (s *Server) PushHandler(c *gin.Context) { func (s *Server) PushHandler(c *gin.Context) {
var req api.PushRequest var r api.PushRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
var model string n := model.ParseName(cmp.Or(r.Model, r.Name))
if req.Model != "" { if !n.IsValid() {
model = req.Model c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
@ -542,19 +530,15 @@ func (s *Server) PushHandler(c *gin.Context) {
ch <- r ch <- r
} }
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil { if err := PushModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if r.Stream != nil && !*r.Stream {
waitForStream(c, ch) waitForStream(c, ch)
return return
} }

View File

@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/types/model"
) )
var blobUploadManager sync.Map var blobUploadManager sync.Map
@ -108,7 +109,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
offset += size offset += size
} }
slog.Info(fmt.Sprintf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))) slog.Info("uploading blob", "digest", b.Digest, "size", format.HumanBytes(b.Total), "parts", len(b.Parts), "size per part", format.HumanBytes(b.Parts[0].Size))
requestURL, err = url.Parse(location) requestURL, err = url.Parse(location)
if err != nil { if err != nil {
@ -362,40 +363,46 @@ func (p *progressWriter) Rollback() {
p.written = 0 p.written = 0
} }
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error { type uploadOptions struct {
requestURL := mp.BaseURL() name model.Name
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest) baseURL *url.URL
layer Layer
regOpts *registryOptions
fn func(api.ProgressResponse)
}
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts) func uploadBlob(ctx context.Context, opts uploadOptions) error {
requestURL := opts.baseURL.JoinPath("blobs", opts.layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts.regOpts)
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
case err != nil: case err != nil:
return err return err
default: default:
defer resp.Body.Close() defer resp.Body.Close()
fn(api.ProgressResponse{ opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("pushing %s", layer.Digest[7:19]), Status: fmt.Sprintf("pushing %s", opts.layer.Digest[7:19]),
Digest: layer.Digest, Digest: opts.layer.Digest,
Total: layer.Size, Total: opts.layer.Size,
Completed: layer.Size, Completed: opts.layer.Size,
}) })
return nil return nil
} }
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer}) data, ok := blobUploadManager.LoadOrStore(opts.layer.Digest, &blobUpload{Layer: opts.layer})
upload := data.(*blobUpload) upload := data.(*blobUpload)
if !ok { if !ok {
requestURL := mp.BaseURL() requestURL := opts.baseURL.JoinPath("blobs", "uploads")
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/") if err := upload.Prepare(ctx, requestURL, opts.regOpts); err != nil {
if err := upload.Prepare(ctx, requestURL, opts); err != nil { blobUploadManager.Delete(opts.layer.Digest)
blobUploadManager.Delete(layer.Digest)
return err return err
} }
//nolint:contextcheck //nolint:contextcheck
go upload.Run(context.Background(), opts) go upload.Run(context.Background(), opts.regOpts)
} }
return upload.Wait(ctx, fn) return upload.Wait(ctx, opts.fn)
} }