From 61b287cf25ca033a29cf02a7f97f7fe31483c783 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Thu, 2 May 2024 14:36:46 -0700 Subject: [PATCH] types/model: make Name.Filepath substitute colons in host with ("%") This makes the filepath legal on all supported platforms. Fixes #4088 --- server/fixblobs.go | 42 ++++++++++++++++++++++++++++++++++ server/fixblobs_test.go | 49 ++++++++++++++++++++++++++++++++++++++++ server/routes.go | 7 ++++++ types/model/name.go | 4 ++-- types/model/name_test.go | 15 +++++++++--- 5 files changed, 112 insertions(+), 5 deletions(-) diff --git a/server/fixblobs.go b/server/fixblobs.go index 6e6b3114..ed95767e 100644 --- a/server/fixblobs.go +++ b/server/fixblobs.go @@ -3,6 +3,7 @@ package server import ( "os" "path/filepath" + "runtime" "strings" ) @@ -24,3 +25,44 @@ func fixBlobs(dir string) error { return nil }) } + +// fixManifests walks the provided dir and replaces (":") to ("%") for all +// manifest files on non-Windows systems. +func fixManifests(dir string) error { + if runtime.GOOS == "windows" { + return nil + } + return filepath.Walk(dir, func(oldPath string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + + var partNum int + newPath := []byte(oldPath) + for i := len(newPath) - 1; i >= 0; i-- { + if partNum > 3 { + break + } + if partNum == 3 { + if newPath[i] == ':' { + newPath[i] = '%' + break + } + continue + } + if newPath[i] == '/' { + partNum++ + } + } + + newDir, _ := filepath.Split(string(newPath)) + if err := os.MkdirAll(newDir, 0o755); err != nil { + return err + } + + return os.Rename(oldPath, string(newPath)) + }) +} diff --git a/server/fixblobs_test.go b/server/fixblobs_test.go index c2d4851a..a204de89 100644 --- a/server/fixblobs_test.go +++ b/server/fixblobs_test.go @@ -64,6 +64,55 @@ func TestFixBlobs(t *testing.T) { } } +func TestFixManifests(t *testing.T) { + cases := []struct { + path []string + want []string + }{ + {path: []string{}, want: []string{}}, + {path: []string{"h/n/m/t"}, want: []string{"h/n/m/t"}}, + {path: []string{"h:p/n/m/t"}, want: []string{"h%p/n/m/t"}}, + {path: []string{"x:y/h:p/n/m/t"}, want: []string{"x:y/h%p/n/m/t"}}, + } + + for _, tt := range cases { + t.Run(strings.Join(tt.path, "|"), func(t *testing.T) { + hasColon := slices.ContainsFunc(tt.path, func(s string) bool { return strings.Contains(s, ":") }) + if hasColon && runtime.GOOS == "windows" { + t.Skip("skipping test on windows") + } + + rootDir := t.TempDir() + for _, path := range tt.path { + fullPath := filepath.Join(rootDir, path) + fullDir, _ := filepath.Split(fullPath) + + t.Logf("creating dir %s", fullDir) + if err := os.MkdirAll(fullDir, 0o755); err != nil { + t.Fatal(err) + } + + t.Logf("writing file %s", fullPath) + if err := os.WriteFile(fullPath, nil, 0o644); err != nil { + t.Fatal(err) + } + } + + if err := fixManifests(rootDir); err != nil { + t.Fatal(err) + } + + got := slurpFiles(os.DirFS(rootDir)) + + slices.Sort(tt.want) + slices.Sort(got) + if !slices.Equal(got, tt.want) { + t.Fatalf("got = %v, want %v", got, tt.want) + } + }) + } +} + func slurpFiles(fsys fs.FS) []string { var sfs []string fn := func(path string, d fs.DirEntry, err error) error { diff --git a/server/routes.go b/server/routes.go index 480527f2..d590dbfd 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1046,6 +1046,13 @@ func Serve(ln net.Listener) error { if err := fixBlobs(blobsDir); err != nil { return err } + manifestsDir, err := GetManifestPath() + if err != nil { + return err + } + if err := fixManifests(manifestsDir); err != nil { + return err + } if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { // clean up unused layers and manifests diff --git a/types/model/name.go b/types/model/name.go index cb890b3a..36e5e3f8 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -161,7 +161,7 @@ func ParseNameBare(s string) Name { } scheme, host, ok := strings.Cut(s, "://") - if ! ok { + if !ok { host = scheme } n.Host = host @@ -243,7 +243,7 @@ func (n Name) Filepath() string { panic("illegal attempt to get filepath of invalid name") } return strings.ToLower(filepath.Join( - n.Host, + strings.Replace(n.Host, ":", "%", 1), n.Namespace, n.Model, n.Tag, diff --git a/types/model/name_test.go b/types/model/name_test.go index 47263c20..15825dd0 100644 --- a/types/model/name_test.go +++ b/types/model/name_test.go @@ -27,7 +27,7 @@ func TestParseNameParts(t *testing.T) { Model: "model", Tag: "tag", }, - wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"), + wantFilepath: filepath.Join("host%port", "namespace", "model", "tag"), }, { in: "host/namespace/model:tag", @@ -47,7 +47,7 @@ func TestParseNameParts(t *testing.T) { Model: "model", Tag: "tag", }, - wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"), + wantFilepath: filepath.Join("host%port", "namespace", "model", "tag"), }, { in: "host/namespace/model", @@ -65,7 +65,7 @@ func TestParseNameParts(t *testing.T) { Namespace: "namespace", Model: "model", }, - wantFilepath: filepath.Join("host:port", "namespace", "model", "latest"), + wantFilepath: filepath.Join("host%port", "namespace", "model", "latest"), }, { in: "namespace/model", @@ -127,6 +127,15 @@ func TestParseNameParts(t *testing.T) { }, wantValidDigest: true, }, + { + in: "y.com:443/n/model", + want: Name{ + Host: "y.com:443", + Namespace: "n", + Model: "model", + }, + wantFilepath: filepath.Join("y.com%443", "n", "model", "latest"), + }, } for _, tt := range cases {