x/model: make equality checks case-insensitive

This commit is contained in:
Blake Mizerany 2024-04-04 16:04:06 -07:00
parent 92b7e40fde
commit bfe89d6fa0
2 changed files with 37 additions and 7 deletions

View File

@ -2,9 +2,12 @@ package model
import (
"cmp"
"hash/maphash"
"iter"
"slices"
"strings"
"github.com/ollama/ollama/x/types/structs"
)
const MaxNameLength = 255
@ -36,6 +39,8 @@ var kindNames = map[NamePart]string{
//
// Users or Name must check Valid before using it.
type Name struct {
_ structs.Incomparable
host string
namespace string
model string
@ -43,6 +48,27 @@ type Name struct {
build string
}
var mapHashSeed = maphash.MakeSeed()
// MapHash returns a case insensitive hash for use in maps and equality
// checks. For a convienent way to compare names, use [EqualFold].
func (r Name) MapHash() uint64 {
// correctly hash the parts with case insensitive comparison
var h maphash.Hash
h.SetSeed(mapHashSeed)
for _, part := range r.Parts() {
// downcase the part for hashing
for i := range part {
c := part[i]
if c >= 'A' && c <= 'Z' {
c = c - 'A' + 'a'
}
h.WriteByte(c)
}
}
return h.Sum64()
}
// Format returns a string representation of the ref with the given
// concreteness. If a part is missing, it is replaced with a loud
// placeholder.
@ -135,6 +161,10 @@ func (r Name) Model() string { return r.model }
func (r Name) Tag() string { return r.tag }
func (r Name) Build() string { return r.build }
func (r Name) EqualFold(o Name) bool {
return r.MapHash() == o.MapHash()
}
// ParseName parses s into a Name. The input string must be a valid form of
// a model name in the form:
//

View File

@ -49,21 +49,21 @@ func TestNameParts(t *testing.T) {
}
func TestParseName(t *testing.T) {
for s, want := range testNames {
for baseName, want := range testNames {
for _, prefix := range []string{"", "https://", "http://"} {
// We should get the same results with or without the
// http(s) prefixes
s := prefix + s
s := prefix + baseName
t.Run(s, func(t *testing.T) {
got := ParseName(s)
if got != want {
if !got.EqualFold(want) {
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
}
// test round-trip
if ParseName(got.String()) != got {
t.Errorf("String() = %s; want %s", got.String(), s)
if !ParseName(got.String()).EqualFold(got) {
t.Errorf("String() = %s; want %s", got.String(), baseName)
}
if got.Valid() && got.Model() == "" {
@ -190,7 +190,7 @@ func FuzzParseName(f *testing.F) {
f.Fuzz(func(t *testing.T, s string) {
r0 := ParseName(s)
if !r0.Valid() {
if r0 != (Name{}) {
if !r0.EqualFold(Name{}) {
t.Errorf("expected invalid path to be zero value; got %#v", r0)
}
t.Skipf("invalid path: %q", s)
@ -207,7 +207,7 @@ func FuzzParseName(f *testing.F) {
}
r1 := ParseName(r0.String())
if r0 != r1 {
if !r0.EqualFold(r1) {
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
}