From 45d8d2278577a6d017706dd557a8d055cded54da Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sat, 6 Apr 2024 00:10:12 -0700 Subject: [PATCH] x/model: add MarshalText and UnmarshalText to Name --- x/model/name.go | 46 +++++++++++++++++++++++++++++++- x/model/name_test.go | 62 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/x/model/name.go b/x/model/name.go index 0f06a33c..5ef1392d 100644 --- a/x/model/name.go +++ b/x/model/name.go @@ -8,6 +8,8 @@ import ( "log/slog" "slices" "strings" + "sync" + "unsafe" "github.com/ollama/ollama/x/types/structs" ) @@ -233,6 +235,12 @@ func (r Name) DisplayLong() string { }).String() } +var builderPool = sync.Pool{ + New: func() interface{} { + return &strings.Builder{} + }, +} + // String returns the fullest possible display string in form: // // //:+ @@ -242,7 +250,17 @@ func (r Name) DisplayLong() string { // For the fullest possible display string without the build, use // [Name.DisplayFullest]. func (r Name) String() string { - var b strings.Builder + b := builderPool.Get().(*strings.Builder) + b.Reset() + defer builderPool.Put(b) + b.Grow(0 + + len(r.host) + + len(r.namespace) + + len(r.model) + + len(r.tag) + + len(r.build) + + 4, // 4 possible separators + ) if r.host != "" { b.WriteString(r.host) b.WriteString("/") @@ -282,6 +300,32 @@ func (r Name) LogValue() slog.Value { return slog.StringValue(r.GoString()) } +// MarshalText implements encoding.TextMarshaler. +func (r Name) MarshalText() ([]byte, error) { + // unsafeBytes is safe here because we gurantee that the string is + // never used after this function returns. + // + // TODO: We can remove this if https://github.com/golang/go/issues/62384 + // lands. + return unsafeBytes(r.String()), nil +} + +func unsafeBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer(&s)) +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (r *Name) UnmarshalText(text []byte) error { + // unsafeString is safe here because the contract of UnmarshalText + // that text belongs to us for the duration of the call. + *r = ParseName(unsafeString(text)) + return nil +} + +func unsafeString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + // Complete reports whether the Name is fully qualified. That is it has a // domain, namespace, name, tag, and build. func (r Name) Complete() bool { diff --git a/x/model/name_test.go b/x/model/name_test.go index bf2937f7..4dcaa052 100644 --- a/x/model/name_test.go +++ b/x/model/name_test.go @@ -3,6 +3,7 @@ package model import ( "bytes" "cmp" + "errors" "fmt" "log/slog" "slices" @@ -352,6 +353,67 @@ func TestFill(t *testing.T) { } } +func TestNameTextMarshal(t *testing.T) { + cases := []struct { + in string + want string + wantErr error + }{ + {"example.com/mistral:latest+Q4_0", "", nil}, + {"mistral:latest+Q4_0", "mistral:latest+Q4_0", nil}, + {"mistral:latest", "mistral:latest", nil}, + {"mistral", "mistral", nil}, + {"mistral:7b", "mistral:7b", nil}, + {"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest+Q4_0", nil}, + } + + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + p := ParseName(tt.in) + got, err := p.MarshalText() + if !errors.Is(err, tt.wantErr) { + t.Fatalf("MarshalText() error = %v; want %v", err, tt.wantErr) + } + if string(got) != tt.want { + t.Errorf("MarshalText() = %q; want %q", got, tt.want) + } + + var r Name + if err := r.UnmarshalText(got); err != nil { + t.Fatalf("UnmarshalText() error = %v; want nil", err) + } + if !r.EqualFold(p) { + t.Errorf("UnmarshalText() = %q; want %q", r, p) + } + }) + } + + var data []byte + name := ParseName("example.com/ns/mistral:latest+Q4_0") + if !name.Complete() { + // sanity check + t.Fatal("name is not complete") + } + + allocs := testing.AllocsPerRun(1000, func() { + var err error + data, err = name.MarshalText() + if err != nil { + t.Fatal(err) + } + if len(data) == 0 { + t.Fatal("MarshalText() = 0; want non-zero") + } + }) + if allocs > 1 { + // TODO: Update when/if this lands: + // https://github.com/golang/go/issues/62384 + // + // Currently, the best we can do is 1 alloc. + t.Errorf("MarshalText allocs = %v; want <= 1", allocs) + } +} + func ExampleFill() { defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0") r := Fill(ParseName("mistral"), defaults)