update pull handler to use model.Name
This commit is contained in:
parent
b91cf0893d
commit
0aeaeaa058
@ -23,6 +23,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const maxRetries = 6
|
||||
@ -332,15 +333,16 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
}
|
||||
}
|
||||
|
||||
type downloadOpts struct {
|
||||
mp ModelPath
|
||||
type downloadOptions struct {
|
||||
name model.Name
|
||||
baseURL *url.URL
|
||||
digest string
|
||||
regOpts *registryOptions
|
||||
fn func(api.ProgressResponse)
|
||||
}
|
||||
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) error {
|
||||
func downloadBlob(ctx context.Context, opts downloadOptions) error {
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -365,8 +367,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
|
||||
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
|
||||
download := data.(*blobDownload)
|
||||
if !ok {
|
||||
requestURL := opts.mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
|
||||
requestURL := opts.baseURL.JoinPath("blobs", opts.digest)
|
||||
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
|
||||
blobDownloadManager.Delete(opts.digest)
|
||||
return err
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@ -781,59 +782,43 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
return nil
|
||||
}
|
||||
|
||||
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
func PullModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
|
||||
old, _ := ParseNamedManifest(name)
|
||||
|
||||
var manifest *ManifestV2
|
||||
var err error
|
||||
var noprune string
|
||||
|
||||
// build deleteMap to prune unused layers
|
||||
deleteMap := make(map[string]struct{})
|
||||
|
||||
if !envconfig.NoPrune {
|
||||
manifest, _, err = GetManifest(mp)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
|
||||
if manifest != nil {
|
||||
for _, l := range manifest.Layers {
|
||||
deleteMap[l.Digest] = struct{}{}
|
||||
}
|
||||
deleteMap[manifest.Config.Digest] = struct{}{}
|
||||
}
|
||||
if !name.IsFullyQualified() {
|
||||
return model.Unqualified(name)
|
||||
}
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return fmt.Errorf("insecure protocol http")
|
||||
scheme := "https"
|
||||
if opts.Insecure {
|
||||
scheme = "http"
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
|
||||
manifest, err = pullModelManifest(ctx, mp, regOpts)
|
||||
m, err := pullModelManifest(ctx, name, baseURL, &opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pull model manifest: %s", err)
|
||||
}
|
||||
|
||||
var layers []*Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
layers = append(layers, manifest.Config)
|
||||
|
||||
layers := append(m.Layers, m.Config)
|
||||
for _, layer := range layers {
|
||||
if err := downloadBlob(
|
||||
ctx,
|
||||
downloadOpts{
|
||||
mp: mp,
|
||||
downloadOptions{
|
||||
name: name,
|
||||
baseURL: baseURL,
|
||||
digest: layer.Digest,
|
||||
regOpts: regOpts,
|
||||
regOpts: &opts,
|
||||
fn: fn,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
|
||||
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
|
||||
for _, layer := range layers {
|
||||
@ -854,45 +839,25 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
if err := WriteManifest(name, m.Config, m.Layers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.WriteFile(fp, manifestJSON, 0o644)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
|
||||
return err
|
||||
}
|
||||
|
||||
if noprune == "" {
|
||||
if !envconfig.NoPrune && old != nil {
|
||||
fn(api.ProgressResponse{Status: "removing any unused layers"})
|
||||
err = deleteUnusedLayers(nil, deleteMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = old.RemoveLayers()
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, opts *registryOptions) (*ManifestV2, error) {
|
||||
requestURL := baseURL.JoinPath("manifests", name.Tag)
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
m, err := ParseNamedManifest(name)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
|
||||
if err := PullModel(ctx, name, registryOptions{}, fn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -408,24 +408,18 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (s *Server) PullModelHandler(c *gin.Context) {
|
||||
var req api.PullRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
var r api.PullRequest
|
||||
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var model string
|
||||
if req.Model != "" {
|
||||
model = req.Model
|
||||
} else if req.Name != "" {
|
||||
model = req.Name
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
n := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
if !n.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||
return
|
||||
}
|
||||
|
||||
@ -436,19 +430,15 @@ func (s *Server) PullModelHandler(c *gin.Context) {
|
||||
ch <- r
|
||||
}
|
||||
|
||||
regOpts := ®istryOptions{
|
||||
Insecure: req.Insecure,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
if err := PullModel(ctx, model, regOpts, fn); err != nil {
|
||||
if err := PullModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
if r.Stream != nil && !*r.Stream {
|
||||
waitForStream(c, ch)
|
||||
return
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user