From 55ac7254a376de4d46111d3ea21a0a66eae8e7e9 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Wed, 24 Apr 2024 15:20:21 -0700 Subject: [PATCH] add param count tag to latest push by default - add a tag for the model parameter count by default when a new latest model is pushed - allow skipping this default with an environment flag --- server/images.go | 73 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 24 deletions(-) diff --git a/server/images.go b/server/images.go index dd44a0f4..cb708804 100644 --- a/server/images.go +++ b/server/images.go @@ -115,6 +115,27 @@ func (c *ConfigV2) SetFileType(fileType string) { } } +// GetConfig reads the config file from the manifest and returns the ConfigV2 struct +func GetConfig(manifest *ManifestV2) (*ConfigV2, error) { + filename, err := GetBlobsPath(manifest.Config.Digest) + if err != nil { + return nil, err + } + + configFile, err := os.Open(filename) + if err != nil { + return nil, err + } + defer configFile.Close() + + var config ConfigV2 + if err := json.NewDecoder(configFile).Decode(&config); err != nil { + return nil, err + } + + return &config, nil +} + type RootFS struct { Type string `json:"type"` DiffIDs []string `json:"diff_ids"` @@ -172,20 +193,11 @@ func GetModel(name string) (*Model, error) { Size: manifest.GetTotalSize(), } - filename, err := GetBlobsPath(manifest.Config.Digest) + config, err := GetConfig(manifest) if err != nil { - return nil, err - } - - configFile, err := os.Open(filename) - if err != nil { - return nil, err - } - defer configFile.Close() - - if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil { - return nil, err + return nil, fmt.Errorf("get manifest config: %v", err) } + model.Config = *config for _, layer := range manifest.Layers { filename, err := GetBlobsPath(layer.Digest) @@ -979,22 +991,35 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu } } + tags := []string{mp.Tag} + // OLLAMA_SKIP_DEFAULT_TAGS is an undocumented configuration option to skip adding the default tag + if mp.Tag == DefaultTag && (os.Getenv("OLLAMA_SKIP_DEFAULT_TAGS") != "") { + // also add the parameter count as a tag by default + config, err := GetConfig(manifest) + if err != nil { + return fmt.Errorf("push manifest config: %v", err) + } + tags = append(tags, strings.ToLower(config.ModelType)) + } + fn(api.ProgressResponse{Status: "pushing manifest"}) - requestURL := mp.BaseURL() - requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) + for _, tag := range tags { + requestURL := mp.BaseURL() + requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", tag) - manifestJSON, err := json.Marshal(manifest) - if err != nil { - return err - } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + return err + } - headers := make(http.Header) - headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json") - resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts) - if err != nil { - return err + headers := make(http.Header) + headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json") + resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts) + if err != nil { + return err + } + resp.Body.Close() } - defer resp.Body.Close() fn(api.ProgressResponse{Status: "success"})