diff --git a/x/model/name.go b/x/model/name.go index ff74668c..ccc80ad4 100644 --- a/x/model/name.go +++ b/x/model/name.go @@ -262,6 +262,44 @@ func (r Name) Complete() bool { return !slices.Contains(r.Parts(), "") } +// EqualFold reports whether r and o are equivalent model names, ignoring +// case. +func (r Name) EqualFold(o Name) bool { + return r.CompareFold(o) == 0 +} + +// CompareFold performs a case-insensitive comparison of two Names. It returns +// an integer comparing two Names lexicographically. The result will be 0 if +// r == o, -1 if r < o, and +1 if r > o. +// +// This can be used with [slice.SortFunc]. +func (r Name) CompareFold(o Name) int { + return cmp.Or( + compareFold(r.host, o.host), + compareFold(r.namespace, o.namespace), + compareFold(r.model, o.model), + compareFold(r.tag, o.tag), + compareFold(r.build, o.build), + ) +} + +func compareFold(a, b string) int { + for i := 0; i < len(a) && i < len(b); i++ { + ca, cb := downcase(a[i]), downcase(b[i]) + if n := cmp.Compare(ca, cb); n != 0 { + return n + } + } + return cmp.Compare(len(a), len(b)) +} + +func downcase(c byte) byte { + if c >= 'A' && c <= 'Z' { + return c + 'a' - 'A' + } + return c +} + // TODO(bmizerany): Compare // TODO(bmizerany): MarshalText/UnmarshalText // TODO(bmizerany): LogValue @@ -280,12 +318,6 @@ func (r Name) Parts() []string { } } -// EqualFold reports whether r and o are equivalent model names, ignoring -// case. -func (r Name) EqualFold(o Name) bool { - return r.MapHash() == o.MapHash() -} - // Parts returns a sequence of the parts of a Name string from most specific // to least specific. // diff --git a/x/model/name_test.go b/x/model/name_test.go index 63c5676c..73359987 100644 --- a/x/model/name_test.go +++ b/x/model/name_test.go @@ -2,6 +2,7 @@ package model import ( "fmt" + "slices" "strings" "testing" ) @@ -339,4 +340,23 @@ func ExampleName_MapHash() { // 2 } +func ExampleName_CompareFold_sort() { + names := []Name{ + ParseName("mistral:latest"), + ParseName("mistRal:7b+q4"), + ParseName("MIstral:7b"), + } + + slices.SortFunc(names, Name.CompareFold) + + for _, n := range names { + fmt.Println(n) + } + + // Output: + // MIstral:7b + // mistRal:7b+q4 + // mistral:latest +} + func keep[T any](v T) T { return v }