add --upgrade-all
flag to refresh any stale models
This commit is contained in:
parent
b5cf31b460
commit
021b1bdc4a
12
api/types.go
12
api/types.go
@ -183,11 +183,12 @@ type CopyRequest struct {
|
||||
}
|
||||
|
||||
type PullRequest struct {
|
||||
Model string `json:"model"`
|
||||
Insecure bool `json:"insecure,omitempty"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Insecure bool `json:"insecure,omitempty"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
CurrentDigest string `json:"current_digest,omitempty"`
|
||||
|
||||
// Name is deprecated, see Model
|
||||
Name string `json:"name"`
|
||||
@ -241,6 +242,7 @@ type GenerateResponse struct {
|
||||
|
||||
type ModelDetails struct {
|
||||
ParentModel string `json:"parent_model"`
|
||||
Digest string `json:"digest"`
|
||||
Format string `json:"format"`
|
||||
Family string `json:"family"`
|
||||
Families []string `json:"families"`
|
||||
|
64
cmd/cmd.go
64
cmd/cmd.go
@ -11,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@ -357,6 +358,62 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
upgradeAll, err := cmd.Flags().GetBool("upgrade-all")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !upgradeAll {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("no model specified to pull")
|
||||
}
|
||||
return pull(cmd, args[0], "")
|
||||
}
|
||||
|
||||
fp, err := server.GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Digest string
|
||||
}
|
||||
|
||||
var modelList []modelInfo
|
||||
|
||||
walkFunc := func(path string, info os.FileInfo, _ error) error {
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
dir, file := filepath.Split(path)
|
||||
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
|
||||
tag := strings.Join([]string{dir, file}, ":")
|
||||
|
||||
model, err := server.GetModel(tag)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelList = append(modelList, modelInfo{tag, "sha256:" + model.Digest})
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = filepath.Walk(fp, walkFunc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, m := range modelList {
|
||||
err = pull(cmd, m.Name, m.Digest)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("couldn't pull model '%s'", m.Name))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func pull(cmd *cobra.Command, name string, currentDigest string) error {
|
||||
insecure, err := cmd.Flags().GetBool("insecure")
|
||||
if err != nil {
|
||||
return err
|
||||
@ -368,7 +425,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
defer p.StopWithoutClear()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
|
||||
@ -402,7 +459,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: args[0], Insecure: insecure}
|
||||
request := api.PullRequest{Name: name, Insecure: insecure, CurrentDigest: currentDigest}
|
||||
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -884,12 +941,13 @@ func NewCLI() *cobra.Command {
|
||||
pullCmd := &cobra.Command{
|
||||
Use: "pull MODEL",
|
||||
Short: "Pull a model from a registry",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Args: cobra.RangeArgs(0, 1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: PullHandler,
|
||||
}
|
||||
|
||||
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
pullCmd.Flags().Bool("upgrade-all", false, "Upgrade all models if they're out of date")
|
||||
|
||||
pushCmd := &cobra.Command{
|
||||
Use: "push MODEL",
|
||||
|
@ -52,6 +52,10 @@ func (p *Progress) Stop() bool {
|
||||
return stopped
|
||||
}
|
||||
|
||||
func (p *Progress) StopWithoutClear() bool {
|
||||
return p.stop()
|
||||
}
|
||||
|
||||
func (p *Progress) StopAndClear() bool {
|
||||
fmt.Fprint(p.w, "\033[?25l")
|
||||
defer fmt.Fprint(p.w, "\033[?25h")
|
||||
|
@ -471,7 +471,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
fn(api.ProgressResponse{Status: "pulling model"})
|
||||
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
|
||||
if err := PullModel(ctx, c.Args, "", &RegistryOptions{}, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1041,7 +1041,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
return nil
|
||||
}
|
||||
|
||||
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
func PullModel(ctx context.Context, name, currentDigest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
var manifest *ManifestV2
|
||||
@ -1069,13 +1069,23 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
return fmt.Errorf("insecure protocol http")
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
if currentDigest == "" {
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
}
|
||||
|
||||
manifest, err = pullModelManifest(ctx, mp, regOpts)
|
||||
manifest, err = pullModelManifest(ctx, mp, currentDigest, regOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pull model manifest: %s", err)
|
||||
}
|
||||
|
||||
if currentDigest != "" {
|
||||
if manifest == nil {
|
||||
// we already have the model
|
||||
return nil
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "upgrading " + mp.GetShortTagname()})
|
||||
}
|
||||
|
||||
var layers []*Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
layers = append(layers, manifest.Config)
|
||||
@ -1147,17 +1157,27 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
return nil
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, currentDigest string, regOpts *RegistryOptions) (*ManifestV2, error) {
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
|
||||
if currentDigest != "" {
|
||||
headers.Set("If-None-Match", currentDigest)
|
||||
}
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// todo we can potentially read the manifest locally and return it here
|
||||
if resp.StatusCode == http.StatusNotModified {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var m *ManifestV2
|
||||
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
||||
return nil, err
|
||||
|
@ -451,7 +451,7 @@ func PullModelHandler(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
if err := PullModel(ctx, model, regOpts, fn); err != nil {
|
||||
if err := PullModel(ctx, model, req.CurrentDigest, regOpts, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@ -673,6 +673,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
|
||||
modelDetails := api.ModelDetails{
|
||||
ParentModel: model.ParentModel,
|
||||
Digest: "sha256:" + model.Digest,
|
||||
Format: model.Config.ModelFormat,
|
||||
Family: model.Config.ModelFamily,
|
||||
Families: model.Config.ModelFamilies,
|
||||
|
Loading…
x
Reference in New Issue
Block a user