update create handler to use model.Name

This commit is contained in:
Michael Yang 2024-05-08 14:36:08 -07:00
parent 8215545c6d
commit b91cf0893d
9 changed files with 66 additions and 71 deletions

View File

@ -24,7 +24,8 @@ func terminate(cmd *exec.Cmd) error {
if err != nil { if err != nil {
return err return err
} }
defer dll.Release() // nolint: errcheck //nolint:errcheck
defer dll.Release()
pid := cmd.Process.Pid pid := cmd.Process.Pid
@ -73,7 +74,8 @@ func isProcessExited(pid int) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("failed to open process: %v", err) return false, fmt.Errorf("failed to open process: %v", err)
} }
defer windows.CloseHandle(hProcess) // nolint: errcheck //nolint:errcheck
defer windows.CloseHandle(hProcess)
var exitCode uint32 var exitCode uint32
err = windows.GetExitCodeProcess(hProcess, &exitCode) err = windows.GetExitCodeProcess(hProcess, &exitCode)

View File

@ -81,7 +81,7 @@ func (i *Instance) Readline() (string, error) {
defer func() { defer func() {
fd := int(syscall.Stdin) fd := int(syscall.Stdin)
// nolint: errcheck //nolint:errcheck
UnsetRawMode(fd, i.Terminal.termios) UnsetRawMode(fd, i.Terminal.termios)
i.Terminal.rawmode = false i.Terminal.rawmode = false
}() }()

View File

@ -372,7 +372,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
return err return err
} }
// nolint: contextcheck //nolint:contextcheck
go download.Run(context.Background(), requestURL, opts.regOpts) go download.Run(context.Background(), requestURL, opts.regOpts)
} }

View File

@ -314,7 +314,7 @@ func realpath(rel, from string) string {
return abspath return abspath
} }
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) { func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{ config := ConfigV2{
OS: "linux", OS: "linux",
Architecture: "amd64", Architecture: "amd64",
@ -546,16 +546,10 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
} }
} }
unref := make(map[string]struct{}) if !envconfig.NoPrune {
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil { if old, err := ParseNamedManifest(name); err == nil {
for _, layer := range manifest.Layers { //nolint:errcheck
if !slices.Contains(digests, layer.Digest) { defer old.RemoveLayers()
unref[layer.Digest] = struct{}{}
}
}
if manifest.Config.Digest != layer.Digest {
unref[manifest.Config.Digest] = struct{}{}
} }
} }
@ -564,12 +558,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return err return err
} }
if !envconfig.NoPrune {
if err := deleteUnusedLayers(nil, unref); err != nil {
return err
}
}
fn(api.ProgressResponse{Status: "success"}) fn(api.ProgressResponse{Status: "success"})
return nil return nil
} }
@ -637,7 +625,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{})
// save (i.e. delete from the deleteMap) any files used in other manifests // save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp) manifest, _, err := GetManifest(fmp)
if err != nil { if err != nil {
// nolint: nilerr //nolint:nilerr
return nil return nil
} }

View File

@ -1,7 +1,6 @@
package server package server
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -34,12 +33,6 @@ func (m *Manifest) Remove() error {
return err return err
} }
for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); err != nil {
return err
}
}
manifests, err := GetManifestPath() manifests, err := GetManifestPath()
if err != nil { if err != nil {
return err return err
@ -48,6 +41,16 @@ func (m *Manifest) Remove() error {
return PruneDirectory(manifests) return PruneDirectory(manifests)
} }
// RemoveLayers calls Remove on each entry of m.Layers, in order, and then on
// m.Config. It stops at the first failure and returns that error; it returns
// nil when every Remove call succeeds.
//
// The config layer is handled after the loop instead of being appended to
// m.Layers: append would write m.Config into any spare capacity of the
// m.Layers backing array (storage that may be shared with other slices) or,
// when there is no spare capacity, allocate a throwaway copy on every call.
func (m *Manifest) RemoveLayers() error {
	for _, layer := range m.Layers {
		if err := layer.Remove(); err != nil {
			return err
		}
	}

	if err := m.Config.Remove(); err != nil {
		return err
	}

	return nil
}
func ParseNamedManifest(n model.Name) (*Manifest, error) { func ParseNamedManifest(n model.Name) (*Manifest, error) {
if !n.IsFullyQualified() { if !n.IsFullyQualified() {
return nil, model.Unqualified(n) return nil, model.Unqualified(n)
@ -85,30 +88,31 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
}, nil }, nil
} }
func WriteManifest(name string, config *Layer, layers []*Layer) error { func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
manifest := ManifestV2{ manifests, err := GetManifestPath()
if err != nil {
return err
}
p := filepath.Join(manifests, name.Filepath())
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
return err
}
f, err := os.Create(p)
if err != nil {
return err
}
defer f.Close()
m := ManifestV2{
SchemaVersion: 2, SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json", MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config, Config: config,
Layers: layers, Layers: layers,
} }
var b bytes.Buffer return json.NewEncoder(f).Encode(m)
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
return err
}
modelpath := ParseModelPath(name)
manifestPath, err := modelpath.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
return err
}
return os.WriteFile(manifestPath, b.Bytes(), 0o644)
} }
func Manifests() (map[model.Name]*Manifest, error) { func Manifests() (map[model.Name]*Manifest, error) {

View File

@ -23,16 +23,14 @@ type layerWithGGML struct {
} }
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
modelpath := ParseModelPath(name.String()) m, err := ParseNamedManifest(name)
manifest, _, err := GetManifest(modelpath)
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil { if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err return nil, err
} }
modelpath = ParseModelPath(name.String()) m, err = ParseNamedManifest(name)
manifest, _, err = GetManifest(modelpath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -40,8 +38,8 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
return nil, err return nil, err
} }
for _, layer := range manifest.Layers { for _, layer := range m.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -506,8 +506,8 @@ func (s *Server) PushModelHandler(c *gin.Context) {
} }
func (s *Server) CreateModelHandler(c *gin.Context) { func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest var r api.CreateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
} else if err != nil { } else if err != nil {
@ -515,30 +515,30 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
return return
} }
name := model.ParseName(cmp.Or(req.Model, req.Name)) name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() { if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return return
} }
if req.Path == "" && req.Modelfile == "" { if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return return
} }
var r io.Reader = strings.NewReader(req.Modelfile) var rd io.Reader = strings.NewReader(r.Modelfile)
if req.Path != "" && req.Modelfile == "" { if r.Path != "" && r.Modelfile == "" {
f, err := os.Open(req.Path) f, err := os.Open(r.Path)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return return
} }
defer f.Close() defer f.Close()
r = f rd = f
} }
modelfile, err := model.ParseFile(r) f, err := model.ParseFile(rd)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
@ -554,17 +554,13 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
quantization := req.Quantization quantization := cmp.Or(r.Quantize, r.Quantization)
if req.Quantize != "" { if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
quantization = req.Quantize
}
if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if r.Stream != nil && !*r.Stream {
waitForStream(c, ch) waitForStream(c, ch)
return return
} }
@ -598,6 +594,11 @@ func (s *Server) DeleteModelHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if err := m.RemoveLayers(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} }
func (s *Server) ShowModelHandler(c *gin.Context) { func (s *Server) ShowModelHandler(c *gin.Context) {

View File

@ -53,6 +53,8 @@ func Test_Routes(t *testing.T) {
} }
createTestModel := func(t *testing.T, name string) { createTestModel := func(t *testing.T, name string) {
t.Helper()
fname := createTestFile(t, "ollama-model") fname := createTestFile(t, "ollama-model")
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
@ -61,7 +63,7 @@ func Test_Routes(t *testing.T) {
fn := func(resp api.ProgressResponse) { fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status) t.Logf("Status: %s", resp.Status)
} }
err = CreateModel(context.TODO(), name, "", "", modelfile, fn) err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
assert.Nil(t, err) assert.Nil(t, err)
} }

View File

@ -391,7 +391,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryO
return err return err
} }
// nolint: contextcheck //nolint:contextcheck
go upload.Run(context.Background(), opts) go upload.Run(context.Background(), opts)
} }