From 1407fd3d4a00b5d05b81c6482604a9f58c7c8449 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 7 Apr 2024 10:25:44 -0700 Subject: [PATCH] x/model: Name: implement sql.Scanner and driver.Valuer --- x/model/name.go | 34 +++++++++++++++++++++++++++++++++- x/model/name_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/x/model/name.go b/x/model/name.go index 147e2acc..77ab08af 100644 --- a/x/model/name.go +++ b/x/model/name.go @@ -3,6 +3,8 @@ package model import ( "bytes" "cmp" + "database/sql" + "database/sql/driver" "errors" "hash/maphash" "io" @@ -330,7 +332,7 @@ func (r *Name) UnmarshalText(text []byte) error { // called on an invalid/zero Name. If we allow UnmarshalText // on a valid Name, then the Name will be mutated, breaking // the immutability of the Name. - return errors.New("model.Name: UnmarshalText on valid Name") + return errors.New("model.Name: illegal UnmarshalText on valid Name") } // The contract of UnmarshalText is that we copy to keep the text. @@ -338,6 +340,36 @@ func (r *Name) UnmarshalText(text []byte) error { return nil } +var ( + _ driver.Valuer = Name{} + _ sql.Scanner = (*Name)(nil) +) + +// Scan implements [database/sql.Scanner]. +func (r *Name) Scan(src any) error { + if r.Valid() { + // The invariant of Scan is that it should only be called on an + // invalid/zero Name. If we allow Scan on a valid Name, then the + // Name will be mutated, breaking the immutability of the Name. + return errors.New("model.Name: illegal Scan on valid Name") + + } + switch v := src.(type) { + case string: + *r = ParseName(v) + return nil + case []byte: + *r = ParseName(string(v)) + return nil + } + return errors.New("model.Name: invalid Scan source") +} + +// Value implements [database/sql/driver.Valuer]. +func (r Name) Value() (driver.Value, error) { + return r.String(), nil +} + // Complete reports whether the Name is fully qualified. That is it has a // domain, namespace, name, tag, and build. func (r Name) Complete() bool { diff --git a/x/model/name_test.go b/x/model/name_test.go index 94818511..16140a86 100644 --- a/x/model/name_test.go +++ b/x/model/name_test.go @@ -406,6 +406,32 @@ func TestNameTextUnmarshalCallOnValidName(t *testing.T) { } } +func TestSQL(t *testing.T) { + t.Run("Scan for already valid Name", func(t *testing.T) { + p := mustParse("x") + if err := p.Scan("mistral:latest+Q4_0"); err == nil { + t.Error("Scan() = nil; want error") + } + }) + t.Run("Scan for invalid Name", func(t *testing.T) { + p := Name{} + if err := p.Scan("mistral:latest+Q4_0"); err != nil { + t.Errorf("Scan() = %v; want nil", err) + } + if p.String() != "mistral:latest+Q4_0" { + t.Errorf("String() = %q; want %q", p, "mistral:latest+Q4_0") + } + }) + t.Run("Value", func(t *testing.T) { + p := mustParse("x") + if g, err := p.Value(); err != nil { + t.Errorf("Value() error = %v; want nil", err) + } else if g != "x" { + t.Errorf("Value() = %q; want %q", g, "x") + } + }) +} + func TestNameTextMarshalAllocs(t *testing.T) { var data []byte name := ParseName("example.com/ns/mistral:latest+Q4_0")