x/model: add MarshalText and UnmarshalText to Name
This commit is contained in:
parent
e201627c63
commit
45d8d22785
@ -8,6 +8,8 @@ import (
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
@ -233,6 +235,12 @@ func (r Name) DisplayLong() string {
|
||||
}).String()
|
||||
}
|
||||
|
||||
var builderPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &strings.Builder{}
|
||||
},
|
||||
}
|
||||
|
||||
// String returns the fullest possible display string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>
|
||||
@ -242,7 +250,17 @@ func (r Name) DisplayLong() string {
|
||||
// For the fullest possible display string without the build, use
|
||||
// [Name.DisplayFullest].
|
||||
func (r Name) String() string {
|
||||
var b strings.Builder
|
||||
b := builderPool.Get().(*strings.Builder)
|
||||
b.Reset()
|
||||
defer builderPool.Put(b)
|
||||
b.Grow(0 +
|
||||
len(r.host) +
|
||||
len(r.namespace) +
|
||||
len(r.model) +
|
||||
len(r.tag) +
|
||||
len(r.build) +
|
||||
4, // 4 possible separators
|
||||
)
|
||||
if r.host != "" {
|
||||
b.WriteString(r.host)
|
||||
b.WriteString("/")
|
||||
@ -282,6 +300,32 @@ func (r Name) LogValue() slog.Value {
|
||||
return slog.StringValue(r.GoString())
|
||||
}
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (r Name) MarshalText() ([]byte, error) {
|
||||
// unsafeBytes is safe here because we gurantee that the string is
|
||||
// never used after this function returns.
|
||||
//
|
||||
// TODO: We can remove this if https://github.com/golang/go/issues/62384
|
||||
// lands.
|
||||
return unsafeBytes(r.String()), nil
|
||||
}
|
||||
|
||||
func unsafeBytes(s string) []byte {
|
||||
return *(*[]byte)(unsafe.Pointer(&s))
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler.
|
||||
func (r *Name) UnmarshalText(text []byte) error {
|
||||
// unsafeString is safe here because the contract of UnmarshalText
|
||||
// that text belongs to us for the duration of the call.
|
||||
*r = ParseName(unsafeString(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
func unsafeString(b []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&b))
|
||||
}
|
||||
|
||||
// Complete reports whether the Name is fully qualified. That is it has a
|
||||
// domain, namespace, name, tag, and build.
|
||||
func (r Name) Complete() bool {
|
||||
|
@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
@ -352,6 +353,67 @@ func TestFill(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameTextMarshal(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
wantErr error
|
||||
}{
|
||||
{"example.com/mistral:latest+Q4_0", "", nil},
|
||||
{"mistral:latest+Q4_0", "mistral:latest+Q4_0", nil},
|
||||
{"mistral:latest", "mistral:latest", nil},
|
||||
{"mistral", "mistral", nil},
|
||||
{"mistral:7b", "mistral:7b", nil},
|
||||
{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest+Q4_0", nil},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := ParseName(tt.in)
|
||||
got, err := p.MarshalText()
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("MarshalText() error = %v; want %v", err, tt.wantErr)
|
||||
}
|
||||
if string(got) != tt.want {
|
||||
t.Errorf("MarshalText() = %q; want %q", got, tt.want)
|
||||
}
|
||||
|
||||
var r Name
|
||||
if err := r.UnmarshalText(got); err != nil {
|
||||
t.Fatalf("UnmarshalText() error = %v; want nil", err)
|
||||
}
|
||||
if !r.EqualFold(p) {
|
||||
t.Errorf("UnmarshalText() = %q; want %q", r, p)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var data []byte
|
||||
name := ParseName("example.com/ns/mistral:latest+Q4_0")
|
||||
if !name.Complete() {
|
||||
// sanity check
|
||||
t.Fatal("name is not complete")
|
||||
}
|
||||
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
var err error
|
||||
data, err = name.MarshalText()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Fatal("MarshalText() = 0; want non-zero")
|
||||
}
|
||||
})
|
||||
if allocs > 1 {
|
||||
// TODO: Update when/if this lands:
|
||||
// https://github.com/golang/go/issues/62384
|
||||
//
|
||||
// Currently, the best we can do is 1 alloc.
|
||||
t.Errorf("MarshalText allocs = %v; want <= 1", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleFill() {
|
||||
defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
|
||||
r := Fill(ParseName("mistral"), defaults)
|
||||
|
Loading…
x
Reference in New Issue
Block a user