x/model: replace part fields with array of parts

This makes building strings and reasoning about parts easier.
This commit is contained in:
Blake Mizerany 2024-04-06 13:37:33 -07:00
parent 45d8d22785
commit 14a6f85e9e
2 changed files with 136 additions and 146 deletions

View File

@ -1,9 +1,11 @@
package model
import (
"bytes"
"cmp"
"errors"
"hash/maphash"
"io"
"iter"
"log/slog"
"slices"
@ -41,12 +43,15 @@ func (k NamePart) String() string {
// Levels of concreteness
const (
Invalid NamePart = iota
Host
Host NamePart = iota
Namespace
Model
Tag
Build
NumParts = Build + 1
Invalid = NamePart(-1)
)
// Name is an opaque reference to a model. It holds the parts of a model
@ -84,13 +89,8 @@ const (
//
// To update parts of a Name with defaults, use [Fill].
type Name struct {
_ structs.Incomparable
host string
namespace string
model string
tag string
build string
_ structs.Incomparable
parts [NumParts]string
}
// ParseName parses s into a Name. The input string must be a valid string
@ -127,20 +127,10 @@ type Name struct {
func ParseName(s string) Name {
var r Name
for kind, part := range NameParts(s) {
switch kind {
case Host:
r.host = part
case Namespace:
r.namespace = part
case Model:
r.model = part
case Tag:
r.tag = part
case Build:
r.build = part
case Invalid:
if kind == Invalid {
return Name{}
}
r.parts[kind] = part
}
if !r.Valid() {
return Name{}
@ -152,18 +142,16 @@ func ParseName(s string) Name {
//
// The returned Name will only be valid if dst is valid.
func Fill(dst, src Name) Name {
return Name{
model: cmp.Or(dst.model, src.model),
host: cmp.Or(dst.host, src.host),
namespace: cmp.Or(dst.namespace, src.namespace),
tag: cmp.Or(dst.tag, src.tag),
build: cmp.Or(dst.build, src.build),
var r Name
for i := range r.parts {
r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
}
return r
}
// WithBuild returns a copy of r with the build set to the given string.
func (r Name) WithBuild(build string) Name {
r.build = build
r.parts[Build] = build
return r
}
@ -188,9 +176,15 @@ func (r Name) MapHash() uint64 {
return h.Sum64()
}
// slice returns a copy of r that keeps only the parts in the inclusive
// range [from, to]. Every part outside that range is left empty.
func (r Name) slice(from, to NamePart) Name {
	var kept Name
	end := to + 1 // to is inclusive; slice bounds are exclusive
	copy(kept.parts[from:end], r.parts[from:end])
	return kept
}
// DisplayModel returns a display string composed of the model only.
func (r Name) DisplayModel() string {
return r.model
return r.parts[Model]
}
// DisplayFullest returns the fullest possible display string in form:
@ -202,12 +196,7 @@ func (r Name) DisplayModel() string {
// It does not include the build part. For the fullest possible display
// string with the build, use [Name.String].
func (r Name) DisplayFullest() string {
return (Name{
host: r.host,
namespace: r.namespace,
model: r.model,
tag: r.tag,
}).String()
return r.slice(Host, Tag).String()
}
// DisplayShort returns the fullest possible display string in form:
@ -216,10 +205,7 @@ func (r Name) DisplayFullest() string {
//
// If any part is missing, it is omitted from the display string.
func (r Name) DisplayShort() string {
return (Name{
model: r.model,
tag: r.tag,
}).String()
return r.slice(Model, Tag).String()
}
// DisplayLong returns the fullest possible display string in form:
@ -228,11 +214,36 @@ func (r Name) DisplayShort() string {
//
// If any part is missing, it is omitted from the display string.
func (r Name) DisplayLong() string {
return (Name{
namespace: r.namespace,
model: r.model,
tag: r.tag,
}).String()
return r.slice(Namespace, Tag).String()
}
// seps maps each part to the separator that is written after its slot
// when a later part follows. WriteTo emits seps[i-1] immediately before a
// non-empty part i once something has already been written, so the entry
// for Build is never used as a separator and is left empty.
var seps = [...]string{
Host: "/",
Namespace: "/",
Model: ":",
Tag: "+",
Build: "",
}
// WriteTo writes the non-empty parts of r to w in part order, joining
// them with the separators from seps. Empty parts are skipped entirely,
// and no separator is written before the first part that is emitted.
//
// It returns the total number of bytes written and the first error
// reported by w, if any. Writing stops at the first error.
func (r Name) WriteTo(w io.Writer) (n int64, err error) {
	for i, part := range r.parts {
		if part == "" {
			continue
		}
		if n > 0 {
			// Output already has content, so this part needs the
			// separator that belongs to the slot just before it.
			wrote, werr := io.WriteString(w, seps[i-1])
			n += int64(wrote)
			if werr != nil {
				return n, werr
			}
		}
		wrote, werr := io.WriteString(w, part)
		n += int64(wrote)
		if werr != nil {
			return n, werr
		}
	}
	return n, nil
}
var builderPool = sync.Pool{
@ -241,6 +252,9 @@ var builderPool = sync.Pool{
},
}
// String and MarshalText are both built on WriteTo, using a pooled
// strings.Builder and a pooled bytes.Buffer, respectively.
// String returns the fullest possible display string in form:
//
// <host>/<namespace>/<model>:<tag>+<build>
@ -251,33 +265,10 @@ var builderPool = sync.Pool{
// [Name.DisplayFullest].
func (r Name) String() string {
b := builderPool.Get().(*strings.Builder)
b.Reset()
defer builderPool.Put(b)
b.Grow(0 +
len(r.host) +
len(r.namespace) +
len(r.model) +
len(r.tag) +
len(r.build) +
4, // 4 possible separators
)
if r.host != "" {
b.WriteString(r.host)
b.WriteString("/")
}
if r.namespace != "" {
b.WriteString(r.namespace)
b.WriteString("/")
}
b.WriteString(r.model)
if r.tag != "" {
b.WriteString(":")
b.WriteString(r.tag)
}
if r.build != "" {
b.WriteString("+")
b.WriteString(r.build)
}
b.Reset()
b.Grow(50) // arbitrarily long enough for most names
_, _ = r.WriteTo(b)
return b.String()
}
@ -286,13 +277,11 @@ func (r Name) String() string {
// returns a string that includes all parts of the Name, with missing parts
// replaced with a ("?").
func (r Name) GoString() string {
return (Name{
host: cmp.Or(r.host, "?"),
namespace: cmp.Or(r.namespace, "?"),
model: cmp.Or(r.model, "?"),
tag: cmp.Or(r.tag, "?"),
build: cmp.Or(r.build, "?"),
}).String()
var v Name
for i := range r.parts {
v.parts[i] = cmp.Or(r.parts[i], "?")
}
return v.String()
}
// LogValue implements slog.Valuer.
@ -300,18 +289,25 @@ func (r Name) LogValue() slog.Value {
return slog.StringValue(r.GoString())
}
// MarshalText implements encoding.TextMarshaler.
func (r Name) MarshalText() ([]byte, error) {
// unsafeBytes is safe here because we gurantee that the string is
// never used after this function returns.
//
// TODO: We can remove this if https://github.com/golang/go/issues/62384
// lands.
return unsafeBytes(r.String()), nil
// bufPool recycles the bytes.Buffer values that MarshalText uses as
// scratch space. Buffers taken from the pool may hold old contents and
// must be Reset before use.
var bufPool = sync.Pool{
	New: func() any { return &bytes.Buffer{} },
}
func unsafeBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(&s))
// MarshalText implements encoding.TextMarshaler.
//
// The text is the same form produced by WriteTo. The returned slice is a
// fresh copy owned by the caller; it never aliases the pooled scratch
// buffer used to build it.
func (r Name) MarshalText() ([]byte, error) {
	b := bufPool.Get().(*bytes.Buffer)
	defer bufPool.Put(b)
	b.Reset()
	b.Grow(50) // arbitrarily long enough for most names
	if _, err := r.WriteTo(b); err != nil {
		return nil, err
	}
	// b goes back to the pool when this function returns, and b.Bytes()
	// aliases the buffer's internal storage. The result must be copied out
	// so that a later user of the pooled buffer cannot overwrite bytes the
	// caller is still holding.
	//
	// TODO: We can remove this alloc if/when
	// https://github.com/golang/go/issues/62384 lands.
	return bytes.Clone(b.Bytes()), nil
}
// UnmarshalText implements encoding.TextUnmarshaler.
@ -329,13 +325,13 @@ func unsafeString(b []byte) string {
// Complete reports whether the Name is fully qualified. That is, it has a
// host, namespace, model, tag, and build.
func (r Name) Complete() bool {
return !slices.Contains(r.Parts(), "")
return !slices.Contains(r.parts[:], "")
}
// CompleteNoBuild is like [Name.Complete] but it does not require the
// build part to be present.
func (r Name) CompleteNoBuild() bool {
	// The upper bound of a slice expression is exclusive, so it has to be
	// Tag+1 for the tag itself to be checked; only the build part is
	// left out.
	return !slices.Contains(r.parts[:Tag+1], "")
}
// EqualFold reports whether r and o are equivalent model names, ignoring
@ -350,27 +346,23 @@ func (r Name) EqualFold(o Name) bool {
//
// For simple equality checks, use [Name.EqualFold].
func (r Name) CompareFold(o Name) int {
return cmp.Or(
compareFold(r.host, o.host),
compareFold(r.namespace, o.namespace),
compareFold(r.model, o.model),
compareFold(r.tag, o.tag),
compareFold(r.build, o.build),
)
for i := range r.parts {
if n := compareFold(r.parts[i], o.parts[i]); n != 0 {
return n
}
}
return 0
}
func compareFold(a, b string) int {
// fast-path for unequal lengths
if n := cmp.Compare(len(a), len(b)); n != 0 {
return n
}
for i := 0; i < len(a) && i < len(b); i++ {
ca, cb := downcase(a[i]), downcase(b[i])
if n := cmp.Compare(ca, cb); n != 0 {
return n
}
}
return 0
return cmp.Compare(len(a), len(b))
}
func downcase(c byte) byte {
@ -387,13 +379,7 @@ func downcase(c byte) byte {
//
// The length of the returned slice is always 5.
func (r Name) Parts() []string {
return []string{
r.host,
r.namespace,
r.model,
r.tag,
r.build,
}
return slices.Clone(r.parts[:])
}
// NameParts returns a sequence of the parts of a Name string from most specific
@ -492,7 +478,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
func (r Name) Valid() bool {
// Parts ensures we only have valid parts, so no need to validate
// them here, only check if we have a name or not.
return r.model != ""
return r.parts[Model] != ""
}
// isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-]
@ -520,3 +506,10 @@ func isValidByte(kind NamePart, c byte) bool {
}
return false
}
// sumLens returns the combined length, in bytes, of every string in a.
// It returns 0 for a nil or empty slice.
func sumLens(a []string) (sum int) {
	for i := range a {
		sum += len(a[i])
	}
	return sum
}

View File

@ -11,7 +11,21 @@ import (
"testing"
)
var testNames = map[string]Name{
// fields holds the five parts of a Name as plain strings. Unlike Name it
// is comparable, so tests can check parsed results against expected
// values with == and !=.
type fields struct {
host, namespace, model, tag, build string
}
// fieldsFromName copies each part of p into the matching field of a
// fields value.
func fieldsFromName(p Name) fields {
	var f fields
	f.host = p.parts[Host]
	f.namespace = p.parts[Namespace]
	f.model = p.parts[Model]
	f.tag = p.parts[Tag]
	f.build = p.parts[Build]
	return f
}
var testNames = map[string]fields{
"mistral:latest": {model: "mistral", tag: "latest"},
"mistral": {model: "mistral"},
"mistral:30B": {model: "mistral", tag: "30B"},
@ -23,7 +37,7 @@ var testNames = map[string]Name{
"llama2": {model: "llama2"},
"user/model": {namespace: "user", model: "model"},
"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
"example.com/ns/mistral:7b+x": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
"example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
// preserves case for build
"x+b": {model: "x", build: "b"},
@ -73,7 +87,7 @@ func TestNameParts(t *testing.T) {
}
func TestNamePartString(t *testing.T) {
if g := NamePart(-1).String(); g != "Unknown" {
if g := NamePart(-2).String(); g != "Unknown" {
t.Errorf("Unknown part = %q; want %q", g, "Unknown")
}
for kind, name := range kindNames {
@ -83,34 +97,6 @@ func TestNamePartString(t *testing.T) {
}
}
func TestPartTooLong(t *testing.T) {
for i := Host; i <= Build; i++ {
t.Run(i.String(), func(t *testing.T) {
var p Name
switch i {
case Host:
p.host = strings.Repeat("a", MaxNamePartLen+1)
case Namespace:
p.namespace = strings.Repeat("a", MaxNamePartLen+1)
case Model:
p.model = strings.Repeat("a", MaxNamePartLen+1)
case Tag:
p.tag = strings.Repeat("a", MaxNamePartLen+1)
case Build:
p.build = strings.Repeat("a", MaxNamePartLen+1)
}
s := strings.Trim(p.String(), "+:/")
if len(s) != MaxNamePartLen+1 {
t.Errorf("len(String()) = %d; want %d", len(s), MaxNamePartLen+1)
t.Logf("String() = %q", s)
}
if ParseName(p.String()).Valid() {
t.Errorf("Valid(%q) = true; want false", p)
}
})
}
}
func TestParseName(t *testing.T) {
for baseName, want := range testNames {
for _, prefix := range []string{"", "https://", "http://"} {
@ -119,19 +105,20 @@ func TestParseName(t *testing.T) {
s := prefix + baseName
t.Run(s, func(t *testing.T) {
got := ParseName(s)
if !got.EqualFold(want) {
name := ParseName(s)
got := fieldsFromName(name)
if got != want {
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
}
// test round-trip
if !ParseName(got.String()).EqualFold(got) {
t.Errorf("String() = %s; want %s", got.String(), baseName)
if !ParseName(name.String()).EqualFold(name) {
t.Errorf("String() = %s; want %s", name.String(), baseName)
}
if got.Valid() && got.model == "" {
if name.Valid() && name.DisplayModel() == "" {
t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model)
} else if !got.Valid() && got.DisplayModel() != "" {
} else if !name.Valid() && name.DisplayModel() != "" {
t.Errorf("Valid() = false; Model() = %q; want empty name", got.model)
}
})
@ -405,7 +392,7 @@ func TestNameTextMarshal(t *testing.T) {
t.Fatal("MarshalText() = 0; want non-zero")
}
})
if allocs > 1 {
if allocs > 0 {
// TODO: Update when/if this lands:
// https://github.com/golang/go/issues/62384
//
@ -414,6 +401,16 @@ func TestNameTextMarshal(t *testing.T) {
}
}
// TestNameStringAllocs checks that Name.String needs at most one
// allocation per call for a fully populated name.
func TestNameStringAllocs(t *testing.T) {
	name := ParseName("example.com/ns/mistral:latest+Q4_0")
	allocs := testing.AllocsPerRun(1000, func() {
		keep(name.String())
	})
	// One allocation is allowed, so the failure message states the same
	// bound that the condition enforces.
	if allocs > 1 {
		t.Errorf("String allocs = %v; want <= 1", allocs)
	}
}
func ExampleFill() {
defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
r := Fill(ParseName("mistral"), defaults)