From 7c7f56a7fbddb5c033a3e3e9d19cefec77bfb698 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Fri, 5 Apr 2024 18:24:50 -0700 Subject: [PATCH] x/model: limit part len, not entire len Limiting the whole name length comes naturally with part name length restrictions. This aligns with Docker's registry behavior. --- x/model/file.go | 10 +-- x/model/name.go | 144 +++++++++++++++++++++++++------------------ x/model/name_test.go | 40 ++++++++++-- 3 files changed, 123 insertions(+), 71 deletions(-) diff --git a/x/model/file.go b/x/model/file.go index 6fdfb61c..22fa4698 100644 --- a/x/model/file.go +++ b/x/model/file.go @@ -1,14 +1,8 @@ // Package model implements the File and Name types for working with and // representing Modelfiles and model Names. // -// The Name type is designed for safety and correctness. It is an opaque -// reference to a model, and holds the parts of a model, casing preserved, -// but is not directly comparable with other Names since model names can be -// represented with different caseing depending on the use case. -// -// Names should never be compared manually parsed. Instead, use the -// [Name.EqualFold] method to compare two names in a case-insensitive -// manner, and [ParseName] to create a Name from a string, safely. +// The Name type should be used when working with model names, and the File +// type should be used when working with Modelfiles. 
package model import ( diff --git a/x/model/name.go b/x/model/name.go index e24b2d1c..047030d5 100644 --- a/x/model/name.go +++ b/x/model/name.go @@ -2,6 +2,7 @@ package model import ( "cmp" + "errors" "hash/maphash" "iter" "slices" @@ -10,21 +11,19 @@ import ( "github.com/ollama/ollama/x/types/structs" ) -const MaxNameLength = 255 - -type NamePart int - -// Levels of concreteness -const ( - Invalid NamePart = iota - Host - Namespace - Model - Tag - Build +// Errors +var ( + // ErrIncompleteName is not used by this package, but is exported so that + // other packages do not need to invent their own error type when they + // need to return an error for an incomplete name. + ErrIncompleteName = errors.New("incomplete model name") ) -var kindNames = map[NamePart]string{ +const MaxNamePartLen = 128 + +type NamePartKind int + +var kindNames = map[NamePartKind]string{ Invalid: "Invalid", Host: "Host", Namespace: "Namespace", @@ -33,12 +32,36 @@ var kindNames = map[NamePart]string{ Build: "Build", } -// Name is an opaque reference to a model. It holds the parts of a model, -// casing preserved, and provides methods for comparing and manipulating -// them in a case-insensitive manner. +func (k NamePartKind) String() string { + return cmp.Or(kindNames[k], "!(UNKNOWN PART KIND)") +} + +// Levels of concreteness +const ( + Invalid NamePartKind = iota + Host + Namespace + Model + Tag + Build +) + +// Name is an opaque reference to a model. It holds the parts of a model +// with the case preserved, but is not directly comparable with other Names +// since model names can be represented with different casing depending on +// the use case. For instance, "Mistral" and "mistral" are the same model +// but each version may have come from different sources (e.g. copied from a +// Web page, or from a file path). // -// To create a Name, use [ParseName]. To compare two names, use -// [Name.EqualFold]. To use a name as a key in a map, use [Name.MapHash]. 
+// Valid Names can ONLY be constructed by calling [ParseName]. +// +// A Name is valid if and only if it has a valid Model part. The other parts +// are optional. +// +// A Name is considered "complete" if it has all parts present. To check if a +// Name is complete, use [Name.Complete]. +// +// To compare two names in a case-insensitive manner, use [Name.EqualFold]. // // The parts of a Name are: // @@ -124,7 +147,7 @@ func ParseName(s string) Name { // Fill fills in the missing parts of dst with the parts of src. // -// Use this for merging a fully qualified ref with a partial ref, such as +// Use this for merging a fully qualified Name with a partial Name, such as // when filling in a missing parts with defaults. // // The returned Name will only be valid if dst is valid. @@ -144,6 +167,23 @@ func (r Name) WithBuild(build string) Name { return r } +// Has reports whether the Name has the given part kind. +func (r Name) Has(kind NamePartKind) bool { + switch kind { + case Host: + return r.host != "" + case Namespace: + return r.namespace != "" + case Model: + return r.model != "" + case Tag: + return r.tag != "" + case Build: + return r.build != "" + } + return false +} + var mapHashSeed = maphash.MakeSeed() // MapHash returns a case insensitive hash for use in maps and equality @@ -165,9 +205,10 @@ func (r Name) MapHash() uint64 { return h.Sum64() } -// Format returns a string representation of the ref with the given -// concreteness. If a part is missing, it is replaced with a loud -// placeholder. 
+func (r Name) DisplayModel() string { + return r.model +} + func (r Name) DisplayFull() string { return (Name{ host: cmp.Or(r.host, "!(MISSING DOMAIN)"), @@ -178,27 +219,7 @@ func (r Name) DisplayFull() string { }).String() } -func (r Name) DisplayModel() string { - return r.model -} - -func (r Name) Has(kind NamePart) bool { - switch kind { - case Host: - return r.host != "" - case Namespace: - return r.namespace != "" - case Model: - return r.model != "" - case Tag: - return r.tag != "" - case Build: - return r.build != "" - } - return false -} - -// DisplayCompact returns a compact display string of the ref with only the +// DisplayCompact returns a compact display string of the Name with only the // model and tag parts. func (r Name) DisplayCompact() string { return (Name{ @@ -207,7 +228,7 @@ func (r Name) DisplayCompact() string { }).String() } -// DisplayShort returns a short display string of the ref with only the +// DisplayShort returns a short display string of the Name with only the // model, tag, and build parts. func (r Name) DisplayShort() string { return (Name{ @@ -217,7 +238,7 @@ func (r Name) DisplayShort() string { }).String() } -// DisplayLong returns a long display string of the ref including namespace, +// DisplayLong returns a long display string of the Name including namespace, // model, tag, and build parts. func (r Name) DisplayLong() string { return (Name{ @@ -228,7 +249,7 @@ func (r Name) DisplayLong() string { }).String() } -// String returns the fully qualified ref string. +// String returns the fully qualified Name string. func (r Name) String() string { var b strings.Builder if r.host != "" { @@ -251,7 +272,7 @@ func (r Name) String() string { return b.String() } -// Complete reports whether the ref is fully qualified. That is it has a +// Complete reports whether the Name is fully qualified. That is it has a // domain, namespace, name, tag, and build. 
func (r Name) Complete() bool { return r.Valid() && !slices.Contains(r.Parts(), "") @@ -262,7 +283,7 @@ func (r Name) Complete() bool { // TODO(bmizerany): LogValue // TODO(bmizerany): driver.Value? (MarshalText etc should be enough) -// Parts returns the parts of the ref in order of concreteness. +// Parts returns the parts of the Name in order of concreteness. // // The length of the returned slice is always 5. func (r Name) Parts() []string { @@ -287,7 +308,7 @@ func (r Name) EqualFold(o Name) bool { return r.MapHash() == o.MapHash() } -// Parts returns a sequence of the parts of a ref string from most specific +// Parts returns a sequence of the parts of a Name string from most specific // to least specific. // // It normalizes the input string by removing "http://" and "https://" only. @@ -295,8 +316,8 @@ func (r Name) EqualFold(o Name) bool { // // As a special case, question marks are ignored so they may be used as // placeholders for missing parts in string literals. -func NameParts(s string) iter.Seq2[NamePart, string] { - return func(yield func(NamePart, string) bool) { +func NameParts(s string) iter.Seq2[NamePartKind, string] { + return func(yield func(NamePartKind, string) bool) { if strings.HasPrefix(s, "http://") { s = s[len("http://"):] } @@ -304,11 +325,11 @@ func NameParts(s string) iter.Seq2[NamePart, string] { s = s[len("https://"):] } - if len(s) > MaxNameLength || len(s) == 0 { + if len(s) > MaxNamePartLen || len(s) == 0 { return } - yieldValid := func(kind NamePart, part string) bool { + yieldValid := func(kind NamePartKind, part string) bool { if !isValidPart(kind, part) { yield(Invalid, "") return false @@ -316,8 +337,13 @@ func NameParts(s string) iter.Seq2[NamePart, string] { return yield(kind, part) } + partLen := 0 state, j := Build, len(s) for i := len(s) - 1; i >= 0; i-- { + if partLen++; partLen > MaxNamePartLen { + yield(Invalid, "") + return + } switch s[i] { case '+': switch state { @@ -325,7 +351,7 @@ func NameParts(s string) 
iter.Seq2[NamePart, string] { if !yieldValid(Build, s[i+1:j]) { return } - state, j = Tag, i + state, j, partLen = Tag, i, 0 default: yield(Invalid, "") return @@ -336,7 +362,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] { if !yieldValid(Tag, s[i+1:j]) { return } - state, j = Model, i + state, j, partLen = Model, i, 0 default: yield(Invalid, "") return @@ -352,7 +378,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] { if !yieldValid(Namespace, s[i+1:j]) { return } - state, j = Host, i + state, j, partLen = Host, i, 0 default: yield(Invalid, "") return @@ -373,7 +399,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] { } } -// Valid returns true if the ref has a valid nick. To know if a ref is +// Valid returns true if the Name has a valid nick. To know if a Name is // "complete", use Complete. func (r Name) Valid() bool { // Parts ensures we only have valid parts, so no need to validate @@ -382,7 +408,7 @@ func (r Name) Valid() bool { } // isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-] -func isValidPart(kind NamePart, s string) bool { +func isValidPart(kind NamePartKind, s string) bool { if s == "" { return false } @@ -394,7 +420,7 @@ func isValidPart(kind NamePart, s string) bool { return true } -func isValidByte(kind NamePart, c byte) bool { +func isValidByte(kind NamePartKind, c byte) bool { if kind == Namespace && c == '.' 
{ return false } diff --git a/x/model/name_test.go b/x/model/name_test.go index 4c0d6ae7..a4276e95 100644 --- a/x/model/name_test.go +++ b/x/model/name_test.go @@ -52,8 +52,8 @@ var testNames = map[string]Name{ "file:///etc/passwd:latest": {}, "file:///etc/passwd:latest+u": {}, - strings.Repeat("a", MaxNameLength): {model: strings.Repeat("a", MaxNameLength)}, - strings.Repeat("a", MaxNameLength+1): {}, + strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)}, + strings.Repeat("a", MaxNamePartLen+1): {}, } func TestNameParts(t *testing.T) { @@ -64,6 +64,34 @@ func TestNameParts(t *testing.T) { } } +func TestPartTooLong(t *testing.T) { + for i := Host; i <= Build; i++ { + t.Run(i.String(), func(t *testing.T) { + var p Name + switch i { + case Host: + p.host = strings.Repeat("a", MaxNamePartLen+1) + case Namespace: + p.namespace = strings.Repeat("a", MaxNamePartLen+1) + case Model: + p.model = strings.Repeat("a", MaxNamePartLen+1) + case Tag: + p.tag = strings.Repeat("a", MaxNamePartLen+1) + case Build: + p.build = strings.Repeat("a", MaxNamePartLen+1) + } + s := strings.Trim(p.String(), "+:/") + if len(s) != MaxNamePartLen+1 { + t.Errorf("len(String()) = %d; want %d", len(s), MaxNamePartLen+1) + t.Logf("String() = %q", s) + } + if ParseName(p.String()).Valid() { + t.Errorf("Valid(%q) = true; want false", p) + } + }) + } +} + func TestParseName(t *testing.T) { for baseName, want := range testNames { for _, prefix := range []string{"", "https://", "http://"} { @@ -210,7 +238,7 @@ func FuzzParseName(f *testing.F) { } for _, p := range r0.Parts() { - if len(p) > MaxNameLength { + if len(p) > MaxNamePartLen { t.Errorf("part too long: %q", p) } } @@ -261,11 +289,15 @@ func ExampleFill() { func ExampleName_MapHash() { m := map[uint64]bool{} + // key 1 m[ParseName("mistral:latest+q4").MapHash()] = true m[ParseName("miSTRal:latest+Q4").MapHash()] = true m[ParseName("mistral:LATest+Q4").MapHash()] = true + // key 2 + 
m[ParseName("mistral:LATest").MapHash()] = true + fmt.Println(len(m)) // Output: - // 1 + // 2 }