diff --git a/app/lifecycle/server_windows.go b/app/lifecycle/server_windows.go index cd4244ff..5f9fe124 100644 --- a/app/lifecycle/server_windows.go +++ b/app/lifecycle/server_windows.go @@ -24,7 +24,8 @@ func terminate(cmd *exec.Cmd) error { if err != nil { return err } - defer dll.Release() // nolint: errcheck + //nolint:errcheck + defer dll.Release() pid := cmd.Process.Pid @@ -73,7 +74,8 @@ func isProcessExited(pid int) (bool, error) { if err != nil { return false, fmt.Errorf("failed to open process: %v", err) } - defer windows.CloseHandle(hProcess) // nolint: errcheck + //nolint:errcheck + defer windows.CloseHandle(hProcess) var exitCode uint32 err = windows.GetExitCodeProcess(hProcess, &exitCode) diff --git a/readline/readline.go b/readline/readline.go index 6fa45391..b9d34401 100644 --- a/readline/readline.go +++ b/readline/readline.go @@ -81,7 +81,7 @@ func (i *Instance) Readline() (string, error) { defer func() { fd := int(syscall.Stdin) - // nolint: errcheck + //nolint:errcheck UnsetRawMode(fd, i.Terminal.termios) i.Terminal.rawmode = false }() diff --git a/server/download.go b/server/download.go index 935af9c1..faa06dd2 100644 --- a/server/download.go +++ b/server/download.go @@ -372,7 +372,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error { return err } - // nolint: contextcheck + //nolint:contextcheck go download.Run(context.Background(), requestURL, opts.regOpts) } diff --git a/server/images.go b/server/images.go index 94057a49..67253ec7 100644 --- a/server/images.go +++ b/server/images.go @@ -314,7 +314,7 @@ func realpath(rel, from string) string { return abspath } -func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) { +func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) { config := ConfigV2{ OS: "linux", Architecture: "amd64", @@ -546,16 +546,10 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m } } - unref := make(map[string]struct{}) - if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil { - for _, layer := range manifest.Layers { - if !slices.Contains(digests, layer.Digest) { - unref[layer.Digest] = struct{}{} - } - } - - if manifest.Config.Digest != layer.Digest { - unref[manifest.Config.Digest] = struct{}{} + if !envconfig.NoPrune { + if old, err := ParseNamedManifest(name); err == nil { + //nolint:errcheck + defer old.RemoveLayers() } } @@ -564,12 +558,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m return err } - if !envconfig.NoPrune { - if err := deleteUnusedLayers(nil, unref); err != nil { - return err - } - } - fn(api.ProgressResponse{Status: "success"}) return nil } @@ -637,7 +625,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}) // save (i.e. delete from the deleteMap) any files used in other manifests manifest, _, err := GetManifest(fmp) if err != nil { - // nolint: nilerr + //nolint:nilerr return nil } diff --git a/server/manifest.go b/server/manifest.go index a5251298..d0675724 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -1,7 +1,6 @@ package server import ( - "bytes" "crypto/sha256" "encoding/json" "fmt" @@ -34,12 +33,6 @@ func (m *Manifest) Remove() error { return err } - for _, layer := range append(m.Layers, m.Config) { - if err := layer.Remove(); err != nil { - return err - } - } - manifests, err := GetManifestPath() if err != nil { return err @@ -48,6 +41,16 @@ func (m *Manifest) Remove() error { return PruneDirectory(manifests) } +func (m *Manifest) RemoveLayers() error { + for _, layer := range append(m.Layers, m.Config) { + if err := layer.Remove(); err != nil { + return err + } + } + + return nil +} + func ParseNamedManifest(n model.Name) (*Manifest, error) { if !n.IsFullyQualified() { return nil, model.Unqualified(n) @@ -85,30 +88,31 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { }, nil } -func WriteManifest(name string, config *Layer, layers []*Layer) error { - manifest := ManifestV2{ +func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { + manifests, err := GetManifestPath() + if err != nil { + return err + } + + p := filepath.Join(manifests, name.Filepath()) + if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil { + return err + } + + f, err := os.Create(p) + if err != nil { + return err + } + defer f.Close() + + m := ManifestV2{ SchemaVersion: 2, MediaType: "application/vnd.docker.distribution.manifest.v2+json", Config: config, Layers: layers, } - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(manifest); err != nil { - return err - } - - modelpath := ParseModelPath(name) - manifestPath, err := modelpath.GetManifestPath() - if err != nil { - return err - } - - if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil { - return err - } - - return os.WriteFile(manifestPath, b.Bytes(), 0o644) + return json.NewEncoder(f).Encode(m) } func Manifests() (map[model.Name]*Manifest, error) { diff --git a/server/model.go b/server/model.go index eea5d13a..d1cacfe1 100644 --- a/server/model.go +++ b/server/model.go @@ -23,16 +23,14 @@ type layerWithGGML struct { } func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { - modelpath := ParseModelPath(name.String()) - manifest, _, err := GetManifest(modelpath) + m, err := ParseNamedManifest(name) switch { case errors.Is(err, os.ErrNotExist): if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil { return nil, err } - modelpath = ParseModelPath(name.String()) - manifest, _, err = GetManifest(modelpath) + m, err = ParseNamedManifest(name) if err != nil { return nil, err } @@ -40,8 +38,8 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return nil, err } - for _, layer := range manifest.Layers { - layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) + for _, layer := range m.Layers { + layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest()) if err != nil { return nil, err } diff --git a/server/routes.go b/server/routes.go index 5d6770c4..bf15079c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -506,8 +506,8 @@ func (s *Server) PushModelHandler(c *gin.Context) { } func (s *Server) CreateModelHandler(c *gin.Context) { - var req api.CreateRequest - if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { + var r api.CreateRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return } else if err != nil { @@ -515,30 +515,30 @@ func (s *Server) CreateModelHandler(c *gin.Context) { return } - name := model.ParseName(cmp.Or(req.Model, req.Name)) + name := model.ParseName(cmp.Or(r.Model, r.Name)) if !name.IsValid() { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"}) return } - if req.Path == "" && req.Modelfile == "" { + if r.Path == "" && r.Modelfile == "" { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) return } - var r io.Reader = strings.NewReader(req.Modelfile) - if req.Path != "" && req.Modelfile == "" { - f, err := os.Open(req.Path) + var rd io.Reader = strings.NewReader(r.Modelfile) + if r.Path != "" && r.Modelfile == "" { + f, err := os.Open(r.Path) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) return } defer f.Close() - r = f + rd = f } - modelfile, err := model.ParseFile(r) + f, err := model.ParseFile(rd) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -554,17 +554,13 @@ func (s *Server) CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - quantization := req.Quantization - if req.Quantize != "" { - quantization = req.Quantize - } - - if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil { + quantization := cmp.Or(r.Quantize, r.Quantization) + if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() - if req.Stream != nil && !*req.Stream { + if r.Stream != nil && !*r.Stream { waitForStream(c, ch) return } @@ -598,6 +594,11 @@ func (s *Server) DeleteModelHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + + if err := m.RemoveLayers(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } } func (s *Server) ShowModelHandler(c *gin.Context) { diff --git a/server/routes_test.go b/server/routes_test.go index 896dc27b..a5e9da23 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -53,6 +53,8 @@ func Test_Routes(t *testing.T) { } createTestModel := func(t *testing.T, name string) { + t.Helper() + fname := createTestFile(t, "ollama-model") r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) @@ -61,7 +63,7 @@ func Test_Routes(t *testing.T) { fn := func(resp api.ProgressResponse) { t.Logf("Status: %s", resp.Status) } - err = CreateModel(context.TODO(), name, "", "", modelfile, fn) + err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn) assert.Nil(t, err) } diff --git a/server/upload.go b/server/upload.go index 9b52238a..aa775518 100644 --- a/server/upload.go +++ b/server/upload.go @@ -391,7 +391,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryO return err } - // nolint: contextcheck + //nolint:contextcheck go upload.Run(context.Background(), opts) }