x/model: replace part fields with array of parts
This makes building strings and reasoning about parts easier.
This commit is contained in:
parent
45d8d22785
commit
14a6f85e9e
209
x/model/name.go
209
x/model/name.go
@ -1,9 +1,11 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"errors"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"slices"
|
||||
@ -41,12 +43,15 @@ func (k NamePart) String() string {
|
||||
|
||||
// Levels of concreteness
|
||||
const (
|
||||
Invalid NamePart = iota
|
||||
Host
|
||||
Host NamePart = iota
|
||||
Namespace
|
||||
Model
|
||||
Tag
|
||||
Build
|
||||
|
||||
NumParts = Build + 1
|
||||
|
||||
Invalid = NamePart(-1)
|
||||
)
|
||||
|
||||
// Name is an opaque reference to a model. It holds the parts of a model
|
||||
@ -84,13 +89,8 @@ const (
|
||||
//
|
||||
// To update parts of a Name with defaults, use [Fill].
|
||||
type Name struct {
|
||||
_ structs.Incomparable
|
||||
|
||||
host string
|
||||
namespace string
|
||||
model string
|
||||
tag string
|
||||
build string
|
||||
_ structs.Incomparable
|
||||
parts [NumParts]string
|
||||
}
|
||||
|
||||
// ParseName parses s into a Name. The input string must be a valid string
|
||||
@ -127,20 +127,10 @@ type Name struct {
|
||||
func ParseName(s string) Name {
|
||||
var r Name
|
||||
for kind, part := range NameParts(s) {
|
||||
switch kind {
|
||||
case Host:
|
||||
r.host = part
|
||||
case Namespace:
|
||||
r.namespace = part
|
||||
case Model:
|
||||
r.model = part
|
||||
case Tag:
|
||||
r.tag = part
|
||||
case Build:
|
||||
r.build = part
|
||||
case Invalid:
|
||||
if kind == Invalid {
|
||||
return Name{}
|
||||
}
|
||||
r.parts[kind] = part
|
||||
}
|
||||
if !r.Valid() {
|
||||
return Name{}
|
||||
@ -152,18 +142,16 @@ func ParseName(s string) Name {
|
||||
//
|
||||
// The returned Name will only be valid if dst is valid.
|
||||
func Fill(dst, src Name) Name {
|
||||
return Name{
|
||||
model: cmp.Or(dst.model, src.model),
|
||||
host: cmp.Or(dst.host, src.host),
|
||||
namespace: cmp.Or(dst.namespace, src.namespace),
|
||||
tag: cmp.Or(dst.tag, src.tag),
|
||||
build: cmp.Or(dst.build, src.build),
|
||||
var r Name
|
||||
for i := range r.parts {
|
||||
r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// WithBuild returns a copy of r with the build set to the given string.
|
||||
func (r Name) WithBuild(build string) Name {
|
||||
r.build = build
|
||||
r.parts[Build] = build
|
||||
return r
|
||||
}
|
||||
|
||||
@ -188,9 +176,15 @@ func (r Name) MapHash() uint64 {
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
func (r Name) slice(from, to NamePart) Name {
|
||||
var v Name
|
||||
copy(v.parts[from:to+1], r.parts[from:to+1])
|
||||
return v
|
||||
}
|
||||
|
||||
// DisplayModel returns the a display string composed of the model only.
|
||||
func (r Name) DisplayModel() string {
|
||||
return r.model
|
||||
return r.parts[Model]
|
||||
}
|
||||
|
||||
// DisplayFullest returns the fullest possible display string in form:
|
||||
@ -202,12 +196,7 @@ func (r Name) DisplayModel() string {
|
||||
// It does not include the build part. For the fullest possible display
|
||||
// string with the build, use [Name.String].
|
||||
func (r Name) DisplayFullest() string {
|
||||
return (Name{
|
||||
host: r.host,
|
||||
namespace: r.namespace,
|
||||
model: r.model,
|
||||
tag: r.tag,
|
||||
}).String()
|
||||
return r.slice(Host, Tag).String()
|
||||
}
|
||||
|
||||
// DisplayShort returns the fullest possible display string in form:
|
||||
@ -216,10 +205,7 @@ func (r Name) DisplayFullest() string {
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
func (r Name) DisplayShort() string {
|
||||
return (Name{
|
||||
model: r.model,
|
||||
tag: r.tag,
|
||||
}).String()
|
||||
return r.slice(Model, Tag).String()
|
||||
}
|
||||
|
||||
// DisplayLong returns the fullest possible display string in form:
|
||||
@ -228,11 +214,36 @@ func (r Name) DisplayShort() string {
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
func (r Name) DisplayLong() string {
|
||||
return (Name{
|
||||
namespace: r.namespace,
|
||||
model: r.model,
|
||||
tag: r.tag,
|
||||
}).String()
|
||||
return r.slice(Namespace, Tag).String()
|
||||
}
|
||||
|
||||
var seps = [...]string{
|
||||
Host: "/",
|
||||
Namespace: "/",
|
||||
Model: ":",
|
||||
Tag: "+",
|
||||
Build: "",
|
||||
}
|
||||
|
||||
func (r Name) WriteTo(w io.Writer) (n int64, err error) {
|
||||
for i := range r.parts {
|
||||
if r.parts[i] == "" {
|
||||
continue
|
||||
}
|
||||
if n > 0 {
|
||||
n1, err := io.WriteString(w, seps[i-1])
|
||||
n += int64(n1)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
n1, err := io.WriteString(w, r.parts[i])
|
||||
n += int64(n1)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var builderPool = sync.Pool{
|
||||
@ -241,6 +252,9 @@ var builderPool = sync.Pool{
|
||||
},
|
||||
}
|
||||
|
||||
// TODO(bmizerany): Add WriteTo and use in String and MarshalText with
|
||||
// strings.Builder and bytes.Buffer, respectively.
|
||||
|
||||
// String returns the fullest possible display string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>
|
||||
@ -251,33 +265,10 @@ var builderPool = sync.Pool{
|
||||
// [Name.DisplayFullest].
|
||||
func (r Name) String() string {
|
||||
b := builderPool.Get().(*strings.Builder)
|
||||
b.Reset()
|
||||
defer builderPool.Put(b)
|
||||
b.Grow(0 +
|
||||
len(r.host) +
|
||||
len(r.namespace) +
|
||||
len(r.model) +
|
||||
len(r.tag) +
|
||||
len(r.build) +
|
||||
4, // 4 possible separators
|
||||
)
|
||||
if r.host != "" {
|
||||
b.WriteString(r.host)
|
||||
b.WriteString("/")
|
||||
}
|
||||
if r.namespace != "" {
|
||||
b.WriteString(r.namespace)
|
||||
b.WriteString("/")
|
||||
}
|
||||
b.WriteString(r.model)
|
||||
if r.tag != "" {
|
||||
b.WriteString(":")
|
||||
b.WriteString(r.tag)
|
||||
}
|
||||
if r.build != "" {
|
||||
b.WriteString("+")
|
||||
b.WriteString(r.build)
|
||||
}
|
||||
b.Reset()
|
||||
b.Grow(50) // arbitrarily long enough for most names
|
||||
_, _ = r.WriteTo(b)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
@ -286,13 +277,11 @@ func (r Name) String() string {
|
||||
// returns a string that includes all parts of the Name, with missing parts
|
||||
// replaced with a ("?").
|
||||
func (r Name) GoString() string {
|
||||
return (Name{
|
||||
host: cmp.Or(r.host, "?"),
|
||||
namespace: cmp.Or(r.namespace, "?"),
|
||||
model: cmp.Or(r.model, "?"),
|
||||
tag: cmp.Or(r.tag, "?"),
|
||||
build: cmp.Or(r.build, "?"),
|
||||
}).String()
|
||||
var v Name
|
||||
for i := range r.parts {
|
||||
v.parts[i] = cmp.Or(r.parts[i], "?")
|
||||
}
|
||||
return v.String()
|
||||
}
|
||||
|
||||
// LogValue implements slog.Valuer.
|
||||
@ -300,18 +289,25 @@ func (r Name) LogValue() slog.Value {
|
||||
return slog.StringValue(r.GoString())
|
||||
}
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (r Name) MarshalText() ([]byte, error) {
|
||||
// unsafeBytes is safe here because we gurantee that the string is
|
||||
// never used after this function returns.
|
||||
//
|
||||
// TODO: We can remove this if https://github.com/golang/go/issues/62384
|
||||
// lands.
|
||||
return unsafeBytes(r.String()), nil
|
||||
var bufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(bytes.Buffer)
|
||||
},
|
||||
}
|
||||
|
||||
func unsafeBytes(s string) []byte {
|
||||
return *(*[]byte)(unsafe.Pointer(&s))
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (r Name) MarshalText() ([]byte, error) {
|
||||
b := bufPool.Get().(*bytes.Buffer)
|
||||
b.Reset()
|
||||
b.Grow(50) // arbitrarily long enough for most names
|
||||
defer bufPool.Put(b)
|
||||
_, err := r.WriteTo(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: We can remove this alloc if/when
|
||||
// https://github.com/golang/go/issues/62384 lands.
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler.
|
||||
@ -329,13 +325,13 @@ func unsafeString(b []byte) string {
|
||||
// Complete reports whether the Name is fully qualified. That is it has a
|
||||
// domain, namespace, name, tag, and build.
|
||||
func (r Name) Complete() bool {
|
||||
return !slices.Contains(r.Parts(), "")
|
||||
return !slices.Contains(r.parts[:], "")
|
||||
}
|
||||
|
||||
// CompleteNoBuild is like [Name.Complete] but it does not require the
|
||||
// build part to be present.
|
||||
func (r Name) CompleteNoBuild() bool {
|
||||
return !slices.Contains(r.Parts()[:4], "")
|
||||
return !slices.Contains(r.parts[:Tag], "")
|
||||
}
|
||||
|
||||
// EqualFold reports whether r and o are equivalent model names, ignoring
|
||||
@ -350,27 +346,23 @@ func (r Name) EqualFold(o Name) bool {
|
||||
//
|
||||
// For simple equality checks, use [Name.EqualFold].
|
||||
func (r Name) CompareFold(o Name) int {
|
||||
return cmp.Or(
|
||||
compareFold(r.host, o.host),
|
||||
compareFold(r.namespace, o.namespace),
|
||||
compareFold(r.model, o.model),
|
||||
compareFold(r.tag, o.tag),
|
||||
compareFold(r.build, o.build),
|
||||
)
|
||||
for i := range r.parts {
|
||||
if n := compareFold(r.parts[i], o.parts[i]); n != 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func compareFold(a, b string) int {
|
||||
// fast-path for unequal lengths
|
||||
if n := cmp.Compare(len(a), len(b)); n != 0 {
|
||||
return n
|
||||
}
|
||||
for i := 0; i < len(a) && i < len(b); i++ {
|
||||
ca, cb := downcase(a[i]), downcase(b[i])
|
||||
if n := cmp.Compare(ca, cb); n != 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 0
|
||||
return cmp.Compare(len(a), len(b))
|
||||
}
|
||||
|
||||
func downcase(c byte) byte {
|
||||
@ -387,13 +379,7 @@ func downcase(c byte) byte {
|
||||
//
|
||||
// The length of the returned slice is always 5.
|
||||
func (r Name) Parts() []string {
|
||||
return []string{
|
||||
r.host,
|
||||
r.namespace,
|
||||
r.model,
|
||||
r.tag,
|
||||
r.build,
|
||||
}
|
||||
return slices.Clone(r.parts[:])
|
||||
}
|
||||
|
||||
// Parts returns a sequence of the parts of a Name string from most specific
|
||||
@ -492,7 +478,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
|
||||
func (r Name) Valid() bool {
|
||||
// Parts ensures we only have valid parts, so no need to validate
|
||||
// them here, only check if we have a name or not.
|
||||
return r.model != ""
|
||||
return r.parts[Model] != ""
|
||||
}
|
||||
|
||||
// isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-]
|
||||
@ -520,3 +506,10 @@ func isValidByte(kind NamePart, c byte) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sumLens(a []string) (sum int) {
|
||||
for _, n := range a {
|
||||
sum += len(n)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -11,7 +11,21 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
var testNames = map[string]Name{
|
||||
type fields struct {
|
||||
host, namespace, model, tag, build string
|
||||
}
|
||||
|
||||
func fieldsFromName(p Name) fields {
|
||||
return fields{
|
||||
host: p.parts[Host],
|
||||
namespace: p.parts[Namespace],
|
||||
model: p.parts[Model],
|
||||
tag: p.parts[Tag],
|
||||
build: p.parts[Build],
|
||||
}
|
||||
}
|
||||
|
||||
var testNames = map[string]fields{
|
||||
"mistral:latest": {model: "mistral", tag: "latest"},
|
||||
"mistral": {model: "mistral"},
|
||||
"mistral:30B": {model: "mistral", tag: "30B"},
|
||||
@ -23,7 +37,7 @@ var testNames = map[string]Name{
|
||||
"llama2": {model: "llama2"},
|
||||
"user/model": {namespace: "user", model: "model"},
|
||||
"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"example.com/ns/mistral:7b+x": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
|
||||
"example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
|
||||
|
||||
// preserves case for build
|
||||
"x+b": {model: "x", build: "b"},
|
||||
@ -73,7 +87,7 @@ func TestNameParts(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNamePartString(t *testing.T) {
|
||||
if g := NamePart(-1).String(); g != "Unknown" {
|
||||
if g := NamePart(-2).String(); g != "Unknown" {
|
||||
t.Errorf("Unknown part = %q; want %q", g, "Unknown")
|
||||
}
|
||||
for kind, name := range kindNames {
|
||||
@ -83,34 +97,6 @@ func TestNamePartString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartTooLong(t *testing.T) {
|
||||
for i := Host; i <= Build; i++ {
|
||||
t.Run(i.String(), func(t *testing.T) {
|
||||
var p Name
|
||||
switch i {
|
||||
case Host:
|
||||
p.host = strings.Repeat("a", MaxNamePartLen+1)
|
||||
case Namespace:
|
||||
p.namespace = strings.Repeat("a", MaxNamePartLen+1)
|
||||
case Model:
|
||||
p.model = strings.Repeat("a", MaxNamePartLen+1)
|
||||
case Tag:
|
||||
p.tag = strings.Repeat("a", MaxNamePartLen+1)
|
||||
case Build:
|
||||
p.build = strings.Repeat("a", MaxNamePartLen+1)
|
||||
}
|
||||
s := strings.Trim(p.String(), "+:/")
|
||||
if len(s) != MaxNamePartLen+1 {
|
||||
t.Errorf("len(String()) = %d; want %d", len(s), MaxNamePartLen+1)
|
||||
t.Logf("String() = %q", s)
|
||||
}
|
||||
if ParseName(p.String()).Valid() {
|
||||
t.Errorf("Valid(%q) = true; want false", p)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
for baseName, want := range testNames {
|
||||
for _, prefix := range []string{"", "https://", "http://"} {
|
||||
@ -119,19 +105,20 @@ func TestParseName(t *testing.T) {
|
||||
s := prefix + baseName
|
||||
|
||||
t.Run(s, func(t *testing.T) {
|
||||
got := ParseName(s)
|
||||
if !got.EqualFold(want) {
|
||||
name := ParseName(s)
|
||||
got := fieldsFromName(name)
|
||||
if got != want {
|
||||
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
// test round-trip
|
||||
if !ParseName(got.String()).EqualFold(got) {
|
||||
t.Errorf("String() = %s; want %s", got.String(), baseName)
|
||||
if !ParseName(name.String()).EqualFold(name) {
|
||||
t.Errorf("String() = %s; want %s", name.String(), baseName)
|
||||
}
|
||||
|
||||
if got.Valid() && got.model == "" {
|
||||
if name.Valid() && name.DisplayModel() == "" {
|
||||
t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model)
|
||||
} else if !got.Valid() && got.DisplayModel() != "" {
|
||||
} else if !name.Valid() && name.DisplayModel() != "" {
|
||||
t.Errorf("Valid() = false; Model() = %q; want empty name", got.model)
|
||||
}
|
||||
})
|
||||
@ -405,7 +392,7 @@ func TestNameTextMarshal(t *testing.T) {
|
||||
t.Fatal("MarshalText() = 0; want non-zero")
|
||||
}
|
||||
})
|
||||
if allocs > 1 {
|
||||
if allocs > 0 {
|
||||
// TODO: Update when/if this lands:
|
||||
// https://github.com/golang/go/issues/62384
|
||||
//
|
||||
@ -414,6 +401,16 @@ func TestNameTextMarshal(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameStringAllocs(t *testing.T) {
|
||||
name := ParseName("example.com/ns/mistral:latest+Q4_0")
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(name.String())
|
||||
})
|
||||
if allocs > 1 {
|
||||
t.Errorf("String allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleFill() {
|
||||
defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
|
||||
r := Fill(ParseName("mistral"), defaults)
|
||||
|
Loading…
x
Reference in New Issue
Block a user