x/model: make equality checks case-insensitive
This commit is contained in:
parent
92b7e40fde
commit
bfe89d6fa0
@ -2,9 +2,12 @@ package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"hash/maphash"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
|
||||
const MaxNameLength = 255
|
||||
@ -36,6 +39,8 @@ var kindNames = map[NamePart]string{
|
||||
//
|
||||
// Users or Name must check Valid before using it.
|
||||
type Name struct {
|
||||
_ structs.Incomparable
|
||||
|
||||
host string
|
||||
namespace string
|
||||
model string
|
||||
@ -43,6 +48,27 @@ type Name struct {
|
||||
build string
|
||||
}
|
||||
|
||||
var mapHashSeed = maphash.MakeSeed()
|
||||
|
||||
// MapHash returns a case insensitive hash for use in maps and equality
|
||||
// checks. For a convienent way to compare names, use [EqualFold].
|
||||
func (r Name) MapHash() uint64 {
|
||||
// correctly hash the parts with case insensitive comparison
|
||||
var h maphash.Hash
|
||||
h.SetSeed(mapHashSeed)
|
||||
for _, part := range r.Parts() {
|
||||
// downcase the part for hashing
|
||||
for i := range part {
|
||||
c := part[i]
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
c = c - 'A' + 'a'
|
||||
}
|
||||
h.WriteByte(c)
|
||||
}
|
||||
}
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// Format returns a string representation of the ref with the given
|
||||
// concreteness. If a part is missing, it is replaced with a loud
|
||||
// placeholder.
|
||||
@ -135,6 +161,10 @@ func (r Name) Model() string { return r.model }
|
||||
func (r Name) Tag() string { return r.tag }
|
||||
func (r Name) Build() string { return r.build }
|
||||
|
||||
func (r Name) EqualFold(o Name) bool {
|
||||
return r.MapHash() == o.MapHash()
|
||||
}
|
||||
|
||||
// ParseName parses s into a Name. The input string must be a valid form of
|
||||
// a model name in the form:
|
||||
//
|
||||
|
@ -49,21 +49,21 @@ func TestNameParts(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
for s, want := range testNames {
|
||||
for baseName, want := range testNames {
|
||||
for _, prefix := range []string{"", "https://", "http://"} {
|
||||
// We should get the same results with or without the
|
||||
// http(s) prefixes
|
||||
s := prefix + s
|
||||
s := prefix + baseName
|
||||
|
||||
t.Run(s, func(t *testing.T) {
|
||||
got := ParseName(s)
|
||||
if got != want {
|
||||
if !got.EqualFold(want) {
|
||||
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
// test round-trip
|
||||
if ParseName(got.String()) != got {
|
||||
t.Errorf("String() = %s; want %s", got.String(), s)
|
||||
if !ParseName(got.String()).EqualFold(got) {
|
||||
t.Errorf("String() = %s; want %s", got.String(), baseName)
|
||||
}
|
||||
|
||||
if got.Valid() && got.Model() == "" {
|
||||
@ -190,7 +190,7 @@ func FuzzParseName(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, s string) {
|
||||
r0 := ParseName(s)
|
||||
if !r0.Valid() {
|
||||
if r0 != (Name{}) {
|
||||
if !r0.EqualFold(Name{}) {
|
||||
t.Errorf("expected invalid path to be zero value; got %#v", r0)
|
||||
}
|
||||
t.Skipf("invalid path: %q", s)
|
||||
@ -207,7 +207,7 @@ func FuzzParseName(f *testing.F) {
|
||||
}
|
||||
|
||||
r1 := ParseName(r0.String())
|
||||
if r0 != r1 {
|
||||
if !r0.EqualFold(r1) {
|
||||
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user