add param count tag to latest push by default

- add a tag for the model parameter count by default when a new latest model is pushed
- allow skipping this default with an environment flag
This commit is contained in:
Bruce MacDonald 2024-04-24 15:20:21 -07:00
parent ade4b55520
commit 55ac7254a3

View File

@ -115,6 +115,27 @@ func (c *ConfigV2) SetFileType(fileType string) {
}
}
// GetConfig reads the config file from the manifest and returns the ConfigV2 struct
func GetConfig(manifest *ManifestV2) (*ConfigV2, error) {
filename, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return nil, err
}
configFile, err := os.Open(filename)
if err != nil {
return nil, err
}
defer configFile.Close()
var config ConfigV2
if err := json.NewDecoder(configFile).Decode(&config); err != nil {
return nil, err
}
return &config, nil
}
type RootFS struct {
Type string `json:"type"`
DiffIDs []string `json:"diff_ids"`
@ -172,20 +193,11 @@ func GetModel(name string) (*Model, error) {
Size: manifest.GetTotalSize(),
}
filename, err := GetBlobsPath(manifest.Config.Digest)
config, err := GetConfig(manifest)
if err != nil {
return nil, err
}
configFile, err := os.Open(filename)
if err != nil {
return nil, err
}
defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err
return nil, fmt.Errorf("get manifest config: %v", err)
}
model.Config = *config
for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest)
@ -979,22 +991,35 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
}
tags := []string{mp.Tag}
// OLLAMA_SKIP_DEFAULT_TAGS is an undocumented configuration option to skip adding the default tag
if mp.Tag == DefaultTag && (os.Getenv("OLLAMA_SKIP_DEFAULT_TAGS") != "") {
// also add the parameter count as a tag by default
config, err := GetConfig(manifest)
if err != nil {
return fmt.Errorf("push manifest config: %v", err)
}
tags = append(tags, strings.ToLower(config.ModelType))
}
fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
for _, tag := range tags {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", tag)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil {
return err
headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil {
return err
}
resp.Body.Close()
}
defer resp.Body.Close()
fn(api.ProgressResponse{Status: "success"})