x/model: add Digest type

This commit is contained in:
Blake Mizerany 2024-04-07 15:38:11 -07:00
parent 4eb7acf84b
commit 2100129e83
4 changed files with 174 additions and 88 deletions

120
x/model/digest.go Normal file
View File

@ -0,0 +1,120 @@
package model
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"log/slog"
"strings"
"unicode"
)
// Digest is an opaque reference to a model digest. It pairs a digest type
// (for example "sha256") with the digest value itself. The zero Digest is
// the invalid digest.
//
// Both fields are plain strings, so Digest values can be compared with ==
// and used as map keys.
type Digest struct {
	typ, digest string
}
// Type returns the digest type part of d, for example "sha256". It is
// empty for the zero Digest.
func (d Digest) Type() string {
	return d.typ
}
// Digest returns the digest value part of d, without the type prefix. It
// is empty for the zero Digest.
func (d Digest) Digest() string {
	return d.digest
}
// Valid reports whether d differs from the zero Digest.
func (d Digest) Valid() bool {
	var zero Digest
	return d != zero
}
// String returns d in the form "<digest-type>-<digest>", or the empty
// string when d is not valid.
func (d Digest) String() string {
	if d.Valid() {
		return d.typ + "-" + d.digest
	}
	return ""
}
// MarshalText returns the String form of d as bytes. The returned error is
// always nil.
func (d Digest) MarshalText() ([]byte, error) {
	text := d.String()
	return []byte(text), nil
}
// UnmarshalText parses text with ParseDigest and stores the result in d.
//
// It refuses to overwrite a Digest that is already valid and returns an
// error in that case. Text that ParseDigest rejects is not reported as an
// error; d is simply left as the zero Digest.
func (d *Digest) UnmarshalText(text []byte) error {
	if d.Valid() {
		return errors.New("model.Digest: illegal UnmarshalText on valid Digest")
	}
	parsed := ParseDigest(string(text))
	*d = parsed
	return nil
}
// LogValue returns the String form of d as a slog string value, so a
// Digest logs as "<digest-type>-<digest>" (or "" when not valid).
func (d Digest) LogValue() slog.Value {
	text := d.String()
	return slog.StringValue(text)
}
// Compile-time checks that Digest can be read from and written to SQL
// columns.
var (
	// *Digest is the scan destination, because Scan mutates its receiver.
	_ sql.Scanner = (*Digest)(nil)
	// Digest is stored by value.
	_ driver.Valuer = Digest{}
)
// Scan implements sql.Scanner. It accepts a string or []byte source, parses
// it with ParseDigest and stores the result in d.
//
// It returns an error if d is already valid or if src has any other type.
// A source that ParseDigest rejects is not reported as an error; d is left
// as the zero Digest.
func (d *Digest) Scan(src any) error {
	if d.Valid() {
		return errors.New("model.Digest: illegal Scan on valid Digest")
	}

	// Normalise the supported source types to a string so there is a single
	// parse site below.
	var text string
	switch v := src.(type) {
	case string:
		text = v
	case []byte:
		text = string(v)
	default:
		return fmt.Errorf("model.Digest: invalid Scan source %T", src)
	}

	*d = ParseDigest(text)
	return nil
}
// Value implements driver.Valuer. It yields the String form of d; the
// returned error is always nil.
func (d Digest) Value() (driver.Value, error) {
	var v driver.Value = d.String()
	return v, nil
}
// ParseDigest parses a string in the form of "<digest-type>-<digest>" into
// a Digest. The string is split at the first "-"; the part before it must
// pass isValidDigestType and the part after it must pass isValidHex.
//
// If s does not have that form, ParseDigest returns the zero Digest.
func ParseDigest(s string) Digest {
	typ, digest, found := strings.Cut(s, "-")
	if !found || !isValidDigestType(typ) || !isValidHex(digest) {
		return Digest{}
	}
	return Digest{typ: typ, digest: digest}
}
// isValidDigest reports whether the given string is in the form of
// "<digest-type>-<digest>", where <digest-type> passes isValidDigestType
// and <digest> passes isValidHex. The string is split at the first "-".
//
// It does not check if the digest is a valid hash for the given digest
// type, or restrict the digest type to a known set of types. This is left
// up to users of this package.
func isValidDigest(s string) bool {
	typ, digest, ok := strings.Cut(s, "-")
	return ok && isValidDigestType(typ) && isValidHex(digest)
}
// isValidDigestType reports whether s is a non-empty string made up solely
// of ASCII lowercase letters and ASCII digits, i.e. it matches [a-z0-9]+.
func isValidDigestType(s string) bool {
	if s == "" {
		return false
	}
	for _, r := range s {
		// unicode.IsLower and unicode.IsDigit also accept non-ASCII runes
		// (for example 'é' or Arabic-Indic digits), so anything above
		// unicode.MaxASCII is rejected first. Invalid UTF-8 decodes to
		// U+FFFD and is rejected by the same check.
		if r > unicode.MaxASCII || (!unicode.IsLower(r) && !unicode.IsDigit(r)) {
			return false
		}
	}
	return true
}
// isValidHex reports whether s is non-empty and every byte of s is a
// lowercase hexadecimal digit ('0'-'9' or 'a'-'f'). Uppercase hex digits
// are rejected.
func isValidHex(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		switch b := s[i]; {
		case '0' <= b && b <= '9':
			// decimal digit
		case 'a' <= b && b <= 'f':
			// lowercase hex letter
		default:
			return false
		}
	}
	return true
}

53
x/model/digest_test.go Normal file
View File

@ -0,0 +1,53 @@
package model
import "testing"
// - test scan
// - test marshal text
// - test unmarshal text
// - test log value
// - test string
// - test type
// - test digest
// - test valid
// - test driver valuer
// - test sql scanner
// - test parse digest
// testDigests maps input strings to the Digest that ParseDigest is expected
// to return for them. Inputs that must be rejected map to the zero Digest.
var testDigests = map[string]Digest{
	// Well-formed "<digest-type>-<digest>" inputs.
	"sha256-1234": {typ: "sha256", digest: "1234"},
	"sha256-5678": {typ: "sha256", digest: "5678"},
	"blake2-9abc": {typ: "blake2", digest: "9abc"},

	// Missing or empty parts.
	"":        {},
	"-1234":   {},
	"sha256-": {},
	"---":     {},

	// Digest part is not a lowercase hex string.
	"sha256-1234-5678": {},
	"sha256-P":         {}, // invalid hex
	"sha256-1234P":     {},
}
// TestDigestParse checks that ParseDigest returns the expected Digest for
// every entry in testDigests, logging each parsed result along the way.
func TestDigestParse(t *testing.T) {
	for in, want := range testDigests {
		got := ParseDigest(in)
		t.Logf("ParseDigest(%q) = %#v", in, got)
		if got == want {
			continue
		}
		t.Errorf("ParseDigest(%q) = %q; want %q", in, got, want)
	}
}
// TestDigestString checks that String on each Digest in testDigests yields
// the map key it is stored under, or "" for digests that are not valid.
func TestDigestString(t *testing.T) {
	for s, d := range testDigests {
		// Digests that are not valid are expected to render as "".
		var want string
		if d.Valid() {
			want = s
		}
		if got := d.String(); got != want {
			t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
		}
	}
}

View File

@ -6,7 +6,6 @@ import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"hash/maphash"
"io"
"iter"
@ -14,7 +13,6 @@ import (
"slices"
"strings"
"sync"
"unicode"
"github.com/ollama/ollama/x/types/structs"
)
@ -25,6 +23,7 @@ var (
// other packages do not need to invent their own error type when they
// need to return an error for an invalid name.
ErrIncompleteName = errors.New("incomplete model name")
ErrInvalidDigest = errors.New("invalid digest")
)
const MaxNamePartLen = 128
@ -592,42 +591,3 @@ func isValidByte(kind NamePart, c byte) bool {
}
return false
}
// isValidDigest reports whether the given string is in the form of
// "<digest-type>-<digest>", where <digest-type> passes isValidDigestType
// and <digest> passes isValidHex.
//
// It does not check if the digest is a valid hash for the given digest
// type, or restrict the digest type to a known set of types. This is left
// up to users of this package.
//
// NOTE(review): every call also prints a DEBUG line to standard output via
// fmt.Printf; this looks like leftover debugging output — confirm and remove
// together with the fmt import if nothing else in the file needs it.
func isValidDigest(s string) bool {
// Split at the first "-" only; ok is false when s contains no "-".
typ, digest, ok := strings.Cut(s, "-")
res := ok && isValidDigestType(typ) && isValidHex(digest)
fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res)
return res
}
// isValidDigestType reports whether s is a non-empty string made up solely
// of ASCII lowercase letters and ASCII digits, i.e. it matches [a-z0-9]+.
func isValidDigestType(s string) bool {
	if s == "" {
		return false
	}
	for _, r := range s {
		// unicode.IsLower and unicode.IsDigit also accept non-ASCII runes
		// (for example 'é' or Arabic-Indic digits), so anything above
		// unicode.MaxASCII is rejected first. Invalid UTF-8 decodes to
		// U+FFFD and is rejected by the same check.
		if r > unicode.MaxASCII || (!unicode.IsLower(r) && !unicode.IsDigit(r)) {
			return false
		}
	}
	return true
}
// isValidHex reports whether s is non-empty and every byte of s is a
// lowercase hexadecimal digit ('0'-'9' or 'a'-'f'). Uppercase hex digits
// are rejected.
func isValidHex(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		switch b := s[i]; {
		case '0' <= b && b <= '9':
			// decimal digit
		case 'a' <= b && b <= 'f':
			// lowercase hex letter
		default:
			return false
		}
	}
	return true
}

View File

@ -117,53 +117,6 @@ func TestNamePartString(t *testing.T) {
}
}
// TestIsValidDigestType checks isValidDigestType against accepted and
// rejected digest type strings, running one subtest per input.
func TestIsValidDigestType(t *testing.T) {
	type testCase struct {
		in   string
		want bool
	}
	tests := []testCase{
		{in: "sha256", want: true},
		{in: "blake2", want: true},
		{in: "", want: false},
		{in: "-sha256", want: false},
		{in: "sha256-", want: false},
		{in: "Sha256", want: false},
		{in: "sha256(", want: false},
		{in: " sha256", want: false},
	}
	for _, tc := range tests {
		t.Run(tc.in, func(t *testing.T) {
			got := isValidDigestType(tc.in)
			if got != tc.want {
				t.Errorf("isValidDigestType(%q) = %v; want %v", tc.in, got, tc.want)
			}
		})
	}
}
// TestIsValidDigest checks isValidDigest against well-formed and malformed
// "<digest-type>-<digest>" strings, running one subtest per input.
func TestIsValidDigest(t *testing.T) {
	type testCase struct {
		in   string
		want bool
	}
	tests := []testCase{
		{in: "", want: false},
		{in: "sha256-123", want: true},
		{in: "sha256-1234567890abcdef", want: true},
		{in: "sha256-1234567890abcdef1234567890abcdeffffffffffffffffffffffffffffffffffffffffff", want: true},
		{in: "!sha256-123", want: false},
		{in: "sha256-123!", want: false},
		{in: "sha256-", want: false},
		{in: "-123", want: false},
	}
	for _, tc := range tests {
		t.Run(tc.in, func(t *testing.T) {
			got := isValidDigest(tc.in)
			if got != tc.want {
				t.Errorf("isValidDigest(%q) = %v; want %v", tc.in, got, tc.want)
			}
		})
	}
}
func TestParseName(t *testing.T) {
for baseName, want := range testNames {
for _, prefix := range []string{"", "https://", "http://"} {