x/model: add Digest type
This commit is contained in:
parent
4eb7acf84b
commit
2100129e83
120
x/model/digest.go
Normal file
120
x/model/digest.go
Normal file
@ -0,0 +1,120 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Digest is an opaque reference to a model digest. It holds the digest type
|
||||
// and the digest itself.
|
||||
//
|
||||
// It is comparable with other Digests and can be used as a map key.
|
||||
type Digest struct {
|
||||
typ string
|
||||
digest string
|
||||
}
|
||||
|
||||
func (d Digest) Type() string { return d.typ }
|
||||
func (d Digest) Digest() string { return d.digest }
|
||||
func (d Digest) Valid() bool { return d != Digest{} }
|
||||
|
||||
func (d Digest) String() string {
|
||||
if !d.Valid() {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s-%s", d.typ, d.digest)
|
||||
}
|
||||
|
||||
func (d Digest) MarshalText() ([]byte, error) {
|
||||
return []byte(d.String()), nil
|
||||
}
|
||||
|
||||
func (d *Digest) UnmarshalText(text []byte) error {
|
||||
if d.Valid() {
|
||||
return errors.New("model.Digest: illegal UnmarshalText on valid Digest")
|
||||
}
|
||||
*d = ParseDigest(string(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d Digest) LogValue() slog.Value {
|
||||
return slog.StringValue(d.String())
|
||||
}
|
||||
|
||||
var (
|
||||
_ driver.Valuer = Digest{}
|
||||
_ sql.Scanner = (*Digest)(nil)
|
||||
)
|
||||
|
||||
func (d *Digest) Scan(src any) error {
|
||||
if d.Valid() {
|
||||
return errors.New("model.Digest: illegal Scan on valid Digest")
|
||||
}
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
*d = ParseDigest(v)
|
||||
return nil
|
||||
case []byte:
|
||||
*d = ParseDigest(string(v))
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("model.Digest: invalid Scan source %T", src)
|
||||
}
|
||||
|
||||
func (d Digest) Value() (driver.Value, error) {
|
||||
return d.String(), nil
|
||||
}
|
||||
|
||||
// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
|
||||
// Digest.
|
||||
func ParseDigest(s string) Digest {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
if ok && isValidDigestType(typ) && isValidHex(digest) {
|
||||
return Digest{typ: typ, digest: digest}
|
||||
}
|
||||
return Digest{}
|
||||
}
|
||||
|
||||
// isValidDigest returns true if the given string in the form of
|
||||
// "<digest-type>-<digest>", and <digest-type> is in the form of [a-z0-9]+
|
||||
// and <digest> is a valid hex string.
|
||||
//
|
||||
// It does not check if the digest is a valid hash for the given digest
|
||||
// type, or restrict the digest type to a known set of types. This is left
|
||||
// up to ueers of this package.
|
||||
func isValidDigest(s string) bool {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
res := ok && isValidDigestType(typ) && isValidHex(digest)
|
||||
fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res)
|
||||
return res
|
||||
}
|
||||
|
||||
func isValidDigestType(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if !unicode.IsLower(r) && !unicode.IsDigit(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidHex(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
for i := range s {
|
||||
c := s[i]
|
||||
if c < '0' || c > '9' && c < 'a' || c > 'f' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
53
x/model/digest_test.go
Normal file
53
x/model/digest_test.go
Normal file
@ -0,0 +1,53 @@
|
||||
package model
|
||||
|
||||
import "testing"
|
||||
|
||||
// - test scan
|
||||
// - test marshal text
|
||||
// - test unmarshal text
|
||||
// - test log value
|
||||
// - test string
|
||||
// - test type
|
||||
// - test digest
|
||||
// - test valid
|
||||
// - test driver valuer
|
||||
// - test sql scanner
|
||||
// - test parse digest
|
||||
|
||||
var testDigests = map[string]Digest{
|
||||
"": {},
|
||||
"sha256-1234": {typ: "sha256", digest: "1234"},
|
||||
"sha256-5678": {typ: "sha256", digest: "5678"},
|
||||
"blake2-9abc": {typ: "blake2", digest: "9abc"},
|
||||
"-1234": {},
|
||||
"sha256-": {},
|
||||
"sha256-1234-5678": {},
|
||||
"sha256-P": {}, // invalid hex
|
||||
"sha256-1234P": {},
|
||||
"---": {},
|
||||
}
|
||||
|
||||
func TestDigestParse(t *testing.T) {
|
||||
// Test cases.
|
||||
for s, want := range testDigests {
|
||||
got := ParseDigest(s)
|
||||
t.Logf("ParseDigest(%q) = %#v", s, got)
|
||||
if got != want {
|
||||
t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDigestString(t *testing.T) {
|
||||
// Test cases.
|
||||
for s, d := range testDigests {
|
||||
want := s
|
||||
if !d.Valid() {
|
||||
want = ""
|
||||
}
|
||||
got := d.String()
|
||||
if got != want {
|
||||
t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
|
||||
}
|
||||
}
|
||||
}
|
@ -6,7 +6,6 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"iter"
|
||||
@ -14,7 +13,6 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
@ -25,6 +23,7 @@ var (
|
||||
// other packages do not need to invent their own error type when they
|
||||
// need to return an error for an invalid name.
|
||||
ErrIncompleteName = errors.New("incomplete model name")
|
||||
ErrInvalidDigest = errors.New("invalid digest")
|
||||
)
|
||||
|
||||
const MaxNamePartLen = 128
|
||||
@ -592,42 +591,3 @@ func isValidByte(kind NamePart, c byte) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isValidDigest returns true if the given string in the form of
|
||||
// "<digest-type>-<digest>", and <digest-type> is in the form of [a-z0-9]+
|
||||
// and <digest> is a valid hex string.
|
||||
//
|
||||
// It does not check if the digest is a valid hash for the given digest
|
||||
// type, or restrict the digest type to a known set of types. This is left
|
||||
// up to ueers of this package.
|
||||
func isValidDigest(s string) bool {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
res := ok && isValidDigestType(typ) && isValidHex(digest)
|
||||
fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res)
|
||||
return res
|
||||
}
|
||||
|
||||
func isValidDigestType(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if !unicode.IsLower(r) && !unicode.IsDigit(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidHex(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
for i := range s {
|
||||
c := s[i]
|
||||
if c < '0' || c > '9' && c < 'a' || c > 'f' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@ -117,53 +117,6 @@ func TestNamePartString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidDigestType(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"sha256", true},
|
||||
{"blake2", true},
|
||||
|
||||
{"", false},
|
||||
{"-sha256", false},
|
||||
{"sha256-", false},
|
||||
{"Sha256", false},
|
||||
{"sha256(", false},
|
||||
{" sha256", false},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
if g := isValidDigestType(tt.in); g != tt.want {
|
||||
t.Errorf("isValidDigestType(%q) = %v; want %v", tt.in, g, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidDigest(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"", false},
|
||||
{"sha256-123", true},
|
||||
{"sha256-1234567890abcdef", true},
|
||||
{"sha256-1234567890abcdef1234567890abcdeffffffffffffffffffffffffffffffffffffffffff", true},
|
||||
{"!sha256-123", false},
|
||||
{"sha256-123!", false},
|
||||
{"sha256-", false},
|
||||
{"-123", false},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
if g := isValidDigest(tt.in); g != tt.want {
|
||||
t.Errorf("isValidDigest(%q) = %v; want %v", tt.in, g, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
for baseName, want := range testNames {
|
||||
for _, prefix := range []string{"", "https://", "http://"} {
|
||||
|
Loading…
x
Reference in New Issue
Block a user