x/model: add MarshalText and UnmarshalText to Name

This commit is contained in:
Blake Mizerany 2024-04-06 00:10:12 -07:00
parent e201627c63
commit 45d8d22785
2 changed files with 107 additions and 1 deletions

View File

@ -8,6 +8,8 @@ import (
"log/slog"
"slices"
"strings"
"sync"
"unsafe"
"github.com/ollama/ollama/x/types/structs"
)
@ -233,6 +235,12 @@ func (r Name) DisplayLong() string {
}).String()
}
var builderPool = sync.Pool{
New: func() interface{} {
return &strings.Builder{}
},
}
// String returns the fullest possible display string in form:
//
// <host>/<namespace>/<model>:<tag>+<build>
@ -242,7 +250,17 @@ func (r Name) DisplayLong() string {
// For the fullest possible display string without the build, use
// [Name.DisplayFullest].
func (r Name) String() string {
var b strings.Builder
b := builderPool.Get().(*strings.Builder)
b.Reset()
defer builderPool.Put(b)
b.Grow(0 +
len(r.host) +
len(r.namespace) +
len(r.model) +
len(r.tag) +
len(r.build) +
4, // 4 possible separators
)
if r.host != "" {
b.WriteString(r.host)
b.WriteString("/")
@ -282,6 +300,32 @@ func (r Name) LogValue() slog.Value {
return slog.StringValue(r.GoString())
}
// MarshalText implements encoding.TextMarshaler.
func (r Name) MarshalText() ([]byte, error) {
// unsafeBytes is safe here because we gurantee that the string is
// never used after this function returns.
//
// TODO: We can remove this if https://github.com/golang/go/issues/62384
// lands.
return unsafeBytes(r.String()), nil
}
func unsafeBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(&s))
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (r *Name) UnmarshalText(text []byte) error {
// unsafeString is safe here because the contract of UnmarshalText
// that text belongs to us for the duration of the call.
*r = ParseName(unsafeString(text))
return nil
}
func unsafeString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// Complete reports whether the Name is fully qualified. That is it has a
// domain, namespace, name, tag, and build.
func (r Name) Complete() bool {

View File

@ -3,6 +3,7 @@ package model
import (
"bytes"
"cmp"
"errors"
"fmt"
"log/slog"
"slices"
@ -352,6 +353,67 @@ func TestFill(t *testing.T) {
}
}
func TestNameTextMarshal(t *testing.T) {
cases := []struct {
in string
want string
wantErr error
}{
{"example.com/mistral:latest+Q4_0", "", nil},
{"mistral:latest+Q4_0", "mistral:latest+Q4_0", nil},
{"mistral:latest", "mistral:latest", nil},
{"mistral", "mistral", nil},
{"mistral:7b", "mistral:7b", nil},
{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest+Q4_0", nil},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in)
got, err := p.MarshalText()
if !errors.Is(err, tt.wantErr) {
t.Fatalf("MarshalText() error = %v; want %v", err, tt.wantErr)
}
if string(got) != tt.want {
t.Errorf("MarshalText() = %q; want %q", got, tt.want)
}
var r Name
if err := r.UnmarshalText(got); err != nil {
t.Fatalf("UnmarshalText() error = %v; want nil", err)
}
if !r.EqualFold(p) {
t.Errorf("UnmarshalText() = %q; want %q", r, p)
}
})
}
var data []byte
name := ParseName("example.com/ns/mistral:latest+Q4_0")
if !name.Complete() {
// sanity check
t.Fatal("name is not complete")
}
allocs := testing.AllocsPerRun(1000, func() {
var err error
data, err = name.MarshalText()
if err != nil {
t.Fatal(err)
}
if len(data) == 0 {
t.Fatal("MarshalText() = 0; want non-zero")
}
})
if allocs > 1 {
// TODO: Update when/if this lands:
// https://github.com/golang/go/issues/62384
//
// Currently, the best we can do is 1 alloc.
t.Errorf("MarshalText allocs = %v; want <= 1", allocs)
}
}
func ExampleFill() {
defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
r := Fill(ParseName("mistral"), defaults)