diff --git a/x/model/digest.go b/x/model/digest.go new file mode 100644 index 00000000..b985a4d0 --- /dev/null +++ b/x/model/digest.go @@ -0,0 +1,120 @@ +package model + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "log/slog" + "strings" + "unicode" +) + +// Digest is an opaque reference to a model digest. It holds the digest type +// and the digest itself. +// +// It is comparable with other Digests and can be used as a map key. +type Digest struct { + typ string + digest string +} + +func (d Digest) Type() string { return d.typ } +func (d Digest) Digest() string { return d.digest } +func (d Digest) Valid() bool { return d != Digest{} } + +func (d Digest) String() string { + if !d.Valid() { + return "" + } + return fmt.Sprintf("%s-%s", d.typ, d.digest) +} + +func (d Digest) MarshalText() ([]byte, error) { + return []byte(d.String()), nil +} + +func (d *Digest) UnmarshalText(text []byte) error { + if d.Valid() { + return errors.New("model.Digest: illegal UnmarshalText on valid Digest") + } + *d = ParseDigest(string(text)) + return nil +} + +func (d Digest) LogValue() slog.Value { + return slog.StringValue(d.String()) +} + +var ( + _ driver.Valuer = Digest{} + _ sql.Scanner = (*Digest)(nil) +) + +func (d *Digest) Scan(src any) error { + if d.Valid() { + return errors.New("model.Digest: illegal Scan on valid Digest") + } + switch v := src.(type) { + case string: + *d = ParseDigest(v) + return nil + case []byte: + *d = ParseDigest(string(v)) + return nil + } + return fmt.Errorf("model.Digest: invalid Scan source %T", src) +} + +func (d Digest) Value() (driver.Value, error) { + return d.String(), nil +} + +// ParseDigest parses a string in the form of "-" into a +// Digest. +func ParseDigest(s string) Digest { + typ, digest, ok := strings.Cut(s, "-") + if ok && isValidDigestType(typ) && isValidHex(digest) { + return Digest{typ: typ, digest: digest} + } + return Digest{} +} + +// isValidDigest returns true if the given string in the form of +// "-", and is in the form of [a-z0-9]+ +// and is a valid hex string. +// +// It does not check if the digest is a valid hash for the given digest +// type, or restrict the digest type to a known set of types. This is left +// up to ueers of this package. +func isValidDigest(s string) bool { + typ, digest, ok := strings.Cut(s, "-") + res := ok && isValidDigestType(typ) && isValidHex(digest) + fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res) + return res +} + +func isValidDigestType(s string) bool { + if len(s) == 0 { + return false + } + for _, r := range s { + if !unicode.IsLower(r) && !unicode.IsDigit(r) { + return false + } + } + return true +} + +func isValidHex(s string) bool { + if len(s) == 0 { + return false + } + for i := range s { + c := s[i] + if c < '0' || c > '9' && c < 'a' || c > 'f' { + return false + } + } + return true +} diff --git a/x/model/digest_test.go b/x/model/digest_test.go new file mode 100644 index 00000000..b6a25946 --- /dev/null +++ b/x/model/digest_test.go @@ -0,0 +1,53 @@ +package model + +import "testing" + +// - test scan +// - test marshal text +// - test unmarshal text +// - test log value +// - test string +// - test type +// - test digest +// - test valid +// - test driver valuer +// - test sql scanner +// - test parse digest + +var testDigests = map[string]Digest{ + "": {}, + "sha256-1234": {typ: "sha256", digest: "1234"}, + "sha256-5678": {typ: "sha256", digest: "5678"}, + "blake2-9abc": {typ: "blake2", digest: "9abc"}, + "-1234": {}, + "sha256-": {}, + "sha256-1234-5678": {}, + "sha256-P": {}, // invalid hex + "sha256-1234P": {}, + "---": {}, +} + +func TestDigestParse(t *testing.T) { + // Test cases. + for s, want := range testDigests { + got := ParseDigest(s) + t.Logf("ParseDigest(%q) = %#v", s, got) + if got != want { + t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want) + } + } +} + +func TestDigestString(t *testing.T) { + // Test cases. + for s, d := range testDigests { + want := s + if !d.Valid() { + want = "" + } + got := d.String() + if got != want { + t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want) + } + } +} diff --git a/x/model/name.go b/x/model/name.go index 14f0bad0..8070f49d 100644 --- a/x/model/name.go +++ b/x/model/name.go @@ -6,7 +6,6 @@ import ( "database/sql" "database/sql/driver" "errors" - "fmt" "hash/maphash" "io" "iter" @@ -14,7 +13,6 @@ import ( "slices" "strings" "sync" - "unicode" "github.com/ollama/ollama/x/types/structs" ) @@ -25,6 +23,7 @@ var ( // other packages do not need to invent their own error type when they // need to return an error for an invalid name. ErrIncompleteName = errors.New("incomplete model name") + ErrInvalidDigest = errors.New("invalid digest") ) const MaxNamePartLen = 128 @@ -592,42 +591,3 @@ func isValidByte(kind NamePart, c byte) bool { } return false } - -// isValidDigest returns true if the given string in the form of -// "-", and is in the form of [a-z0-9]+ -// and is a valid hex string. -// -// It does not check if the digest is a valid hash for the given digest -// type, or restrict the digest type to a known set of types. This is left -// up to ueers of this package. -func isValidDigest(s string) bool { - typ, digest, ok := strings.Cut(s, "-") - res := ok && isValidDigestType(typ) && isValidHex(digest) - fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res) - return res -} - -func isValidDigestType(s string) bool { - if len(s) == 0 { - return false - } - for _, r := range s { - if !unicode.IsLower(r) && !unicode.IsDigit(r) { - return false - } - } - return true -} - -func isValidHex(s string) bool { - if len(s) == 0 { - return false - } - for i := range s { - c := s[i] - if c < '0' || c > '9' && c < 'a' || c > 'f' { - return false - } - } - return true -} diff --git a/x/model/name_test.go b/x/model/name_test.go index 65b45557..54e0dcc5 100644 --- a/x/model/name_test.go +++ b/x/model/name_test.go @@ -117,53 +117,6 @@ func TestNamePartString(t *testing.T) { } } -func TestIsValidDigestType(t *testing.T) { - cases := []struct { - in string - want bool - }{ - {"sha256", true}, - {"blake2", true}, - - {"", false}, - {"-sha256", false}, - {"sha256-", false}, - {"Sha256", false}, - {"sha256(", false}, - {" sha256", false}, - } - for _, tt := range cases { - t.Run(tt.in, func(t *testing.T) { - if g := isValidDigestType(tt.in); g != tt.want { - t.Errorf("isValidDigestType(%q) = %v; want %v", tt.in, g, tt.want) - } - }) - } -} - -func TestIsValidDigest(t *testing.T) { - cases := []struct { - in string - want bool - }{ - {"", false}, - {"sha256-123", true}, - {"sha256-1234567890abcdef", true}, - {"sha256-1234567890abcdef1234567890abcdeffffffffffffffffffffffffffffffffffffffffff", true}, - {"!sha256-123", false}, - {"sha256-123!", false}, - {"sha256-", false}, - {"-123", false}, - } - for _, tt := range cases { - t.Run(tt.in, func(t *testing.T) { - if g := isValidDigest(tt.in); g != tt.want { - t.Errorf("isValidDigest(%q) = %v; want %v", tt.in, g, tt.want) - } - }) - } -} - func TestParseName(t *testing.T) { for baseName, want := range testNames { for _, prefix := range []string{"", "https://", "http://"} {