diff --git a/x/model/name.go b/x/model/name.go index 77ab08af..2f30ae2a 100644 --- a/x/model/name.go +++ b/x/model/name.go @@ -6,6 +6,7 @@ import ( "database/sql" "database/sql/driver" "errors" + "fmt" "hash/maphash" "io" "iter" @@ -13,6 +14,7 @@ import ( "slices" "strings" "sync" + "unicode" "github.com/ollama/ollama/x/types/structs" ) @@ -105,7 +107,7 @@ type Name struct { // ParseName parses s into a Name. The input string must be a valid string // representation of a model name in the form: // -// //:+ +// //:+@- // // The name part is required, all others are optional. If a part is missing, // it is left empty in the returned Name. If a part is invalid, the zero Ref @@ -120,6 +122,7 @@ type Name struct { // "mistral:7b+x" // "example.com/mike/mistral:latest+Q4_0" // "example.com/bruce/mistral:latest" +// "example.com/mistral:7b+Q4_0@sha256-1234567890abcdef" // // Examples of invalid paths: // @@ -141,10 +144,10 @@ func ParseName(s string) Name { } r.parts[kind] = part } - if !r.Valid() { - return Name{} + if r.Valid() || r.Resolved() { + return r } - return r + return Name{} } // Fill fills in the missing parts of dst with the parts of src. @@ -238,15 +241,19 @@ var seps = [...]string{ // WriteTo implements io.WriterTo. It writes the fullest possible display // string in form: // -// //:+ +// //:+@- // // Missing parts and their seperators are not written. +// +// The full digest is always prefixed with "@". That is if [Name.Valid] +// reports false and [Name.Resolved] reports true, then the string is +// returned as "@-". func (r Name) WriteTo(w io.Writer) (n int64, err error) { for i := range r.parts { if r.parts[i] == "" { continue } - if n > 0 { + if n > 0 || NamePart(i) == Digest { n1, err := io.WriteString(w, seps[i-1]) n += int64(n1) if err != nil { @@ -382,6 +389,22 @@ func (r Name) CompleteNoBuild() bool { return !slices.Contains(r.parts[:Build], "") } +// Resolved reports true if the Name has a valid digest. +// +// It is possible to have a valid Name, or a complete Name that is not +// resolved. +func (r Name) Resolved() bool { + return r.parts[Digest] != "" +} + +// Digest returns the digest part of the Name, if any. +// +// If Digest returns a non-empty string, then [Name.Resolved] will return +// true, and digest is considered valid. +func (r Name) Digest() string { + return r.parts[Digest] +} + // EqualFold reports whether r and o are equivalent model names, ignoring // case. func (r Name) EqualFold(o Name) bool { @@ -452,11 +475,29 @@ func Parts(s string) iter.Seq2[NamePart, string] { yield(Invalid, "") return } + switch s[i] { case '@': switch state { case Digest: - if !yieldValid(Digest, s[i+1:j]) { + part := s[i+1:] + if isValidDigest(part) { + if !yield(Digest, part) { + return + } + if i == 0 { + // The name is in + // the form of + // "@digest". This + // is valid ans so + // we want to skip + // the final + // validation for + // any other state. + return + } + } else { + yield(Invalid, "") return } state, j, partLen = Build, i, 0 @@ -552,9 +593,41 @@ func isValidByte(kind NamePart, c byte) bool { return false } -func sumLens(a []string) (sum int) { - for _, n := range a { - sum += len(n) - } - return +// isValidDigest returns true if the given string in the form of +// "-", and is in the form of [a-z0-9]+ +// and is a valid hex string. +// +// It does not check if the digest is a valid hash for the given digest +// type, or restrict the digest type to a known set of types. This is left +// up to ueers of this package. +func isValidDigest(s string) bool { + typ, digest, ok := strings.Cut(s, "-") + res := ok && isValidDigestType(typ) && isValidHex(digest) + fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res) + return res +} + +func isValidDigestType(s string) bool { + if len(s) == 0 { + return false + } + for _, r := range s { + if !unicode.IsLower(r) && !unicode.IsDigit(r) { + return false + } + } + return true +} + +func isValidHex(s string) bool { + if len(s) == 0 { + return false + } + for i := range s { + c := s[i] + if c < '0' || c > '9' && c < 'a' || c > 'f' { + return false + } + } + return true } diff --git a/x/model/name_test.go b/x/model/name_test.go index 99d9f74a..4f8f5822 100644 --- a/x/model/name_test.go +++ b/x/model/name_test.go @@ -49,8 +49,16 @@ var testNames = map[string]fields{ "example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"}, "example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"}, + // invalid digest + "mistral:latest@invalid256-": {}, + "mistral:latest@-123": {}, + "mistral:latest@!-123": {}, + "mistral:latest@1-!": {}, + "mistral:latest@": {}, + // resolved - "x@123": {model: "x", digest: "123"}, + "x@sha123-1": {model: "x", digest: "sha123-1"}, + "@sha456-2": {digest: "sha456-2"}, // preserves case for build "x+b": {model: "x", build: "b"}, @@ -109,6 +117,53 @@ func TestNamePartString(t *testing.T) { } } +func TestIsValidDigestType(t *testing.T) { + cases := []struct { + in string + want bool + }{ + {"sha256", true}, + {"blake2", true}, + + {"", false}, + {"-sha256", false}, + {"sha256-", false}, + {"Sha256", false}, + {"sha256(", false}, + {" sha256", false}, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + if g := isValidDigestType(tt.in); g != tt.want { + t.Errorf("isValidDigestType(%q) = %v; want %v", tt.in, g, tt.want) + } + }) + } +} + +func TestIsValidDigest(t *testing.T) { + cases := []struct { + in string + want bool + }{ + {"", false}, + {"sha256-123", true}, + {"sha256-1234567890abcdef", true}, + {"sha256-1234567890abcdef1234567890abcdeffffffffffffffffffffffffffffffffffffffffff", true}, + {"!sha256-123", false}, + {"sha256-123!", false}, + {"sha256-", false}, + {"-123", false}, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + if g := isValidDigest(tt.in); g != tt.want { + t.Errorf("isValidDigest(%q) = %v; want %v", tt.in, g, tt.want) + } + }) + } +} + func TestParseName(t *testing.T) { for baseName, want := range testNames { for _, prefix := range []string{"", "https://", "http://"} { @@ -117,6 +172,10 @@ func TestParseName(t *testing.T) { s := prefix + baseName t.Run(s, func(t *testing.T) { + for kind, part := range Parts(s) { + t.Logf("Part: %s: %q", kind, part) + } + name := ParseName(s) got := fieldsFromName(name) if got != want { @@ -133,6 +192,12 @@ func TestParseName(t *testing.T) { } else if !name.Valid() && name.DisplayModel() != "" { t.Errorf("Valid() = false; Model() = %q; want empty name", got.model) } + + if name.Resolved() && name.Digest() == "" { + t.Errorf("Resolved() = true; Digest() = %q; want non-empty digest", got.digest) + } else if !name.Resolved() && name.Digest() != "" { + t.Errorf("Resolved() = false; Digest() = %q; want empty digest", got.digest) + } }) } }