build/blob: fix awkward Ref type

This commit is contained in:
Blake Mizerany 2024-04-01 21:19:58 -07:00
parent fd411b3cf6
commit 7cfc8a0838
8 changed files with 325 additions and 122 deletions

View File

@ -7,7 +7,6 @@ import (
"os"
"bllamo.com/build"
"bllamo.com/build/blob"
"bllamo.com/client/ollama/apitype"
"bllamo.com/oweb"
"bllamo.com/registry"
@ -56,12 +55,7 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
const registryURLTODO = "http://localhost:8888"
ref := blob.ParseRef(params.Name)
if !ref.FullyQualified() {
return errUnqualifiedRef
}
man, err := s.Build.Manifest(ref)
man, err := s.Build.Manifest(params.Name)
if err != nil {
if errors.Is(err, build.ErrNotFound) {
return errRefNotFound

View File

@ -2,33 +2,99 @@ package blob
import (
"cmp"
"path"
"path/filepath"
"fmt"
"slices"
"strings"
)
// Levels of concreteness
const (
domain = iota
namespace
name
tag
build
)
// Ref is an opaque reference to a blob.
//
// It is comparable and can be used as a map key.
//
// Users or Ref must check Valid before using it.
type Ref struct {
domain string
name string
tag string
build string
domain string
namespace string
name string
tag string
build string
}
// WithDomain returns a copy of r with the provided domain. If the provided
// domain is empty, it returns the short, unqualified copy of r.
func (r Ref) WithDomain(s string) Ref {
return with(r, domain, s)
}
// WithNamespace returns a copy of r with the provided namespace. If the
// provided namespace is empty, it returns the short, unqualified copy of r.
func (r Ref) WithNamespace(s string) Ref {
return with(r, namespace, s)
}
func (r Ref) WithTag(s string) Ref {
return with(r, tag, s)
}
// WithBuild returns a copy of r with the provided build. If the provided
// build is empty, it returns the short, unqualified copy of r.
func (r Ref) WithBuild(build string) Ref {
if build == "" {
return Ref{r.domain, r.name, r.tag, ""}
}
if !isValidPart(build) {
func (r Ref) WithBuild(s string) Ref {
return with(r, build, s)
}
func with(r Ref, part int, value string) Ref {
if value != "" && !isValidPart(value) {
return Ref{}
}
return makeRef(r.domain, r.name, r.tag, build)
switch part {
case domain:
r.domain = value
case namespace:
r.namespace = value
case name:
r.name = value
case tag:
r.tag = value
case build:
r.build = value
default:
panic(fmt.Sprintf("invalid completeness: %d", part))
}
return r
}
// Format returns a string representation of the ref with the given
// concreteness. If a part is missing, it is replaced with a loud
// placeholder.
func (r Ref) Full() string {
r.domain = cmp.Or(r.domain, "!(MISSING DOMAIN)")
r.namespace = cmp.Or(r.namespace, "!(MISSING NAMESPACE)")
r.name = cmp.Or(r.name, "!(MISSING NAME)")
r.tag = cmp.Or(r.tag, "!(MISSING TAG)")
r.build = cmp.Or(r.build, "!(MISSING BUILD)")
return r.String()
}
func (r Ref) NameAndTag() string {
r.domain = ""
r.namespace = ""
r.build = ""
return r.String()
}
func (r Ref) NameTagAndBuild() string {
r.domain = ""
r.namespace = ""
return r.String()
}
// String returns the fully qualified ref string.
@ -38,6 +104,10 @@ func (r Ref) String() string {
b.WriteString(r.domain)
b.WriteString("/")
}
if r.namespace != "" {
b.WriteString(r.namespace)
b.WriteString("/")
}
b.WriteString(r.name)
if r.tag != "" {
b.WriteString(":")
@ -50,40 +120,41 @@ func (r Ref) String() string {
return b.String()
}
// Full returns the fully qualified ref string, or a string indicating the
// build is missing, or an empty string if the ref is invalid.
func (r Ref) Full() string {
if !r.Valid() {
return ""
// Complete returns true if the ref is valid and has no empty parts.
func (r Ref) Complete() bool {
return r.Valid() && !slices.Contains(r.Parts(), "")
}
// Less returns true if r is less concrete than o; false otherwise.
func (r Ref) Less(o Ref) bool {
rp := r.Parts()
op := o.Parts()
for i := range rp {
if rp[i] < op[i] {
return true
}
}
return makeRef(r.domain, r.name, r.tag, cmp.Or(r.build, "!(MISSING BUILD)")).String()
return false
}
// Short returns the short ref string which does not include the build.
func (r Ref) Short() string {
return r.WithBuild("").String()
// Parts returns the parts of the ref in order of concreteness.
//
// The length of the returned slice is always 5.
func (r Ref) Parts() []string {
return []string{
domain: r.domain,
namespace: r.namespace,
name: r.name,
tag: r.tag,
build: r.build,
}
}
func (r Ref) Valid() bool {
return r.name != ""
}
func (r Ref) FullyQualified() bool {
return r.name != "" && r.tag != "" && r.build != ""
}
func (r Ref) Path() string {
return path.Join(r.domain, r.name, r.tag, r.build)
}
func (r Ref) Filepath() string {
return filepath.Join(r.domain, r.name, r.tag, r.build)
}
func (r Ref) Domain() string { return r.domain }
func (r Ref) Name() string { return r.name }
func (r Ref) Tag() string { return r.tag }
func (r Ref) Build() string { return r.build }
func (r Ref) Domain() string { return r.namespace }
func (r Ref) Namespace() string { return r.namespace }
func (r Ref) Name() string { return r.name }
func (r Ref) Tag() string { return r.tag }
func (r Ref) Build() string { return r.build }
// ParseRef parses a ref string into a Ref. A ref string is a name, an
// optional tag, and an optional build, separated by colons and pluses.
@ -112,25 +183,86 @@ func ParseRef(s string) Ref {
return Ref{}
}
nameAndTag, build, expectBuild := strings.Cut(s, "+")
name, tag, expectTag := strings.Cut(nameAndTag, ":")
if !isValidPart(name) {
return Ref{}
if strings.HasPrefix(s, "http://") {
s = s[len("http://"):]
}
if expectTag && !isValidPart(tag) {
return Ref{}
}
if expectBuild && !isValidPart(build) {
return Ref{}
if strings.HasPrefix(s, "https://") {
s = s[len("https://"):]
}
const TODO = "registry.ollama.ai"
return makeRef(TODO, name, tag, build)
var r Ref
state, j := build, len(s)
for i := len(s) - 1; i >= 0; i-- {
c := s[i]
switch c {
case '+':
switch state {
case build:
r.build = s[i+1 : j]
r.build = strings.ToUpper(r.build)
state, j = tag, i
default:
return Ref{}
}
case ':':
switch state {
case build, tag:
r.tag = s[i+1 : j]
state, j = name, i
default:
return Ref{}
}
case '/':
switch state {
case name, tag, build:
r.name = s[i+1 : j]
state, j = namespace, i
case namespace:
r.namespace = s[i+1 : j]
state, j = domain, i
default:
return Ref{}
}
}
}
// handle the first part based on final state
switch state {
case domain:
r.domain = s[:j]
case namespace:
r.namespace = s[:j]
default:
r.name = s[:j]
}
if !r.Valid() {
return Ref{}
}
return r
}
// makeRef makes a ref, skipping validation.
func makeRef(domain, name, tag, build string) Ref {
return Ref{domain, name, cmp.Or(tag, "latest"), strings.ToUpper(build)}
func (r Ref) Valid() bool {
// Name is required
if !isValidPart(r.name) {
return false
}
// Optional parts must be valid if present
if r.domain != "" && !isValidPart(r.domain) {
return false
}
if r.namespace != "" && !isValidPart(r.namespace) {
return false
}
if r.tag != "" && !isValidPart(r.tag) {
return false
}
if r.build != "" && !isValidPart(r.build) {
return false
}
return true
}
// isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-]

View File

@ -7,29 +7,63 @@ const (
refTooLong = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
)
func TestRefParts(t *testing.T) {
const wantNumParts = 5
var ref Ref
if len(ref.Parts()) != wantNumParts {
t.Errorf("Parts() = %d; want %d", len(ref.Parts()), wantNumParts)
}
}
func TestParseRef(t *testing.T) {
cases := []struct {
in string
want Ref
}{
{"mistral:latest", Ref{"registry.ollama.ai", "mistral", "latest", ""}},
{"mistral", Ref{"registry.ollama.ai", "mistral", "latest", ""}},
{"mistral:30B", Ref{"registry.ollama.ai", "mistral", "30B", ""}},
{"mistral:7b", Ref{"registry.ollama.ai", "mistral", "7b", ""}},
{"mistral:7b+Q4_0", Ref{"registry.ollama.ai", "mistral", "7b", "Q4_0"}},
{"mistral+KQED", Ref{"registry.ollama.ai", "mistral", "latest", "KQED"}},
{"mistral.x-3:7b+Q4_0", Ref{"registry.ollama.ai", "mistral.x-3", "7b", "Q4_0"}},
{"mistral:latest", Ref{
name: "mistral",
tag: "latest",
}},
{"mistral", Ref{
name: "mistral",
}},
{"mistral:30B", Ref{
name: "mistral",
tag: "30B",
}},
{"mistral:7b", Ref{
name: "mistral",
tag: "7b",
}},
{"mistral:7b+Q4_0", Ref{
name: "mistral",
tag: "7b",
build: "Q4_0",
}},
{"mistral+KQED", Ref{
name: "mistral",
build: "KQED",
}},
{"mistral.x-3:7b+Q4_0", Ref{
name: "mistral.x-3",
tag: "7b",
build: "Q4_0",
}},
// lowecase build
{"mistral:7b+q4_0", Ref{"registry.ollama.ai", "mistral", "7b", "Q4_0"}},
{"mistral:7b+q4_0", Ref{
name: "mistral",
tag: "7b",
build: "Q4_0",
}},
{"llama2:+", Ref{name: "llama2"}},
// Invalid
{"mistral:7b+Q4_0:latest", Ref{"", "", "", ""}},
{"mi tral", Ref{"", "", "", ""}},
{"llama2:+", Ref{"", "", "", ""}},
{"mistral:7b+Q4_0:latest", Ref{}},
{"mi tral", Ref{}},
// too long
{refTooLong, Ref{"", "", "", ""}},
{refTooLong, Ref{}},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
@ -42,25 +76,29 @@ func TestParseRef(t *testing.T) {
}
func TestRefFull(t *testing.T) {
const empty = "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/!(MISSING NAME):!(MISSING TAG)+!(MISSING BUILD)"
cases := []struct {
in string
wantShort string
wantFull string
in string
wantFull string
}{
{"", "", ""},
{"mistral:7b+x", "registry.ollama.ai/mistral:7b", "registry.ollama.ai/mistral:7b+X"},
{"mistral:7b+Q4_0", "registry.ollama.ai/mistral:7b", "registry.ollama.ai/mistral:7b+Q4_0"},
{"mistral:latest", "registry.ollama.ai/mistral:latest", "registry.ollama.ai/mistral:latest+!(MISSING BUILD)"},
{"mistral", "registry.ollama.ai/mistral:latest", "registry.ollama.ai/mistral:latest+!(MISSING BUILD)"},
{"mistral:30b", "registry.ollama.ai/mistral:30b", "registry.ollama.ai/mistral:30b+!(MISSING BUILD)"},
{"", empty},
{"example.com/mistral:7b+x", "!(MISSING DOMAIN)/example.com/mistral:7b+X"},
{"example.com/mistral:7b+Q4_0", "!(MISSING DOMAIN)/example.com/mistral:7b+Q4_0"},
{"example.com/x/mistral:latest", "example.com/x/mistral:latest+!(MISSING BUILD)"},
{"example.com/x/mistral:latest+Q4_0", "example.com/x/mistral:latest+Q4_0"},
{"mistral:7b+x", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+X"},
{"mistral:7b+Q4_0", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+Q4_0"},
{"mistral:latest", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:latest+!(MISSING BUILD)"},
{"mistral", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:!(MISSING TAG)+!(MISSING BUILD)"},
{"mistral:30b", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:30b+!(MISSING BUILD)"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
ref := ParseRef(tt.in)
if g := ref.Short(); g != tt.wantShort {
t.Errorf("Short(%q) = %q; want %q", tt.in, g, tt.wantShort)
}
t.Logf("ParseRef(%q) = %#v", tt.in, ref)
if g := ref.Full(); g != tt.wantFull {
t.Errorf("Full(%q) = %q; want %q", tt.in, g, tt.wantFull)
}

View File

@ -14,7 +14,8 @@ import (
// Errors
var (
ErrInvalidRef = errors.New("invalid ref")
ErrRefUnqualified = errors.New("unqualified ref")
ErrRefBuildPresent = errors.New("ref too long")
ErrUnsupportedModelFormat = errors.New("unsupported model format")
ErrMissingFileType = errors.New("missing 'general.file_type' key")
ErrNoSuchBlob = errors.New("no such blob")
@ -53,14 +54,12 @@ func Open(dir string) (*Server, error) {
func (s *Server) Build(ref string, f model.File) error {
br := blob.ParseRef(ref)
if !br.Valid() {
return invalidRef(ref)
if !br.Complete() {
return fmt.Errorf("%w: %q", ErrRefUnqualified, br.Full())
}
// 1. Resolve FROM
// a. If it's a local file (gguf), hash it and add it to the store.
// b. If it's a local dir (safetensor), convert to gguf and add to
// store.
// c. If it's a remote file (http), refuse.
// 2. Turn other pragmas into layers, and add them to the store.
// 3. Create a manifest from the layers.
@ -109,17 +108,22 @@ func (s *Server) LayerFile(digest string) (string, error) {
return fileName, nil
}
func (s *Server) Manifest(ref blob.Ref) ([]byte, error) {
data, _, err := s.getManifestData(ref)
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("%w: %q", ErrNotFound, ref)
func (s *Server) Manifest(ref string) ([]byte, error) {
br, err := parseFullRef(ref)
if err != nil {
return nil, err
}
data, _, err := s.getManifestData(br)
return data, err
}
// WeightFile returns the absolute path to the weights file for the given model ref.
func (s *Server) WeightsFile(ref blob.Ref) (string, error) {
m, err := s.getManifest(ref)
func (s *Server) WeightsFile(ref string) (string, error) {
br, err := parseFullRef(ref)
if err != nil {
return "", err
}
m, err := s.getManifest(br)
if err != nil {
return "", err
}
@ -157,9 +161,17 @@ func (s *Server) getManifest(ref blob.Ref) (manifestJSON, error) {
}
func (s *Server) getManifestData(ref blob.Ref) (data []byte, path string, err error) {
return s.st.Resolve(ref)
data, path, err = s.st.Resolve(ref)
if errors.Is(err, blobstore.ErrUnknownRef) {
return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref)
}
return data, path, err
}
func invalidRef(ref string) error {
return fmt.Errorf("%w: %q", ErrInvalidRef, ref)
func parseFullRef(ref string) (blob.Ref, error) {
br := blob.ParseRef(ref)
if !br.Complete() {
return blob.Ref{}, fmt.Errorf("%w: %q", ErrRefUnqualified, ref)
}
return br, nil
}

View File

@ -6,11 +6,12 @@ import (
"path/filepath"
"testing"
"bllamo.com/build/blob"
"bllamo.com/encoding/gguf"
"bllamo.com/model"
)
const qualifiedRef = "x/y/z:latest+Q4_0"
func TestServerBuildErrors(t *testing.T) {
dir := t.TempDir()
@ -19,8 +20,15 @@ func TestServerBuildErrors(t *testing.T) {
t.Fatal(err)
}
t.Run("unqualified ref", func(t *testing.T) {
err := s.Build("x", model.File{})
if !errors.Is(err, ErrRefUnqualified) {
t.Fatalf("Build() err = %v; want unqualified ref", err)
}
})
t.Run("FROM pragma missing", func(t *testing.T) {
err := s.Build("foo", model.File{})
err := s.Build(qualifiedRef, model.File{})
var e *model.Error
if !errors.As(err, &e) {
t.Fatalf("unexpected error: %v", err)
@ -34,7 +42,7 @@ func TestServerBuildErrors(t *testing.T) {
})
t.Run("FROM file not found", func(t *testing.T) {
err := s.Build("x", model.File{From: "bar"})
err := s.Build(qualifiedRef, model.File{From: "bar"})
if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("Build() err = %v; want file not found", err)
}
@ -51,7 +59,7 @@ func TestServerBuildErrors(t *testing.T) {
"",
)
err := s.Build("x", model.File{From: w.fileName("gguf")})
err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")})
if !errors.Is(err, ErrMissingFileType) {
t.Fatalf("Build() err = %#v; want missing file type", err)
}
@ -60,7 +68,7 @@ func TestServerBuildErrors(t *testing.T) {
t.Run("FROM obscure dir", func(t *testing.T) {
w := newWorkDir(t)
w.mkdirAll("unknown")
if err := s.Build("x", model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat {
if err := s.Build(qualifiedRef, model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat {
t.Fatalf("Build() err = %#v; want unsupported model type", err)
}
})
@ -68,7 +76,7 @@ func TestServerBuildErrors(t *testing.T) {
t.Run("FROM unsupported model type", func(t *testing.T) {
w := newWorkDir(t)
from := w.write("unknown", "unknown content")
err := s.Build("x", model.File{From: from})
err := s.Build(qualifiedRef, model.File{From: from})
if !errors.Is(err, ErrUnsupportedModelFormat) {
t.Fatalf("Build() err = %#v; want unsupported model type", err)
}
@ -96,7 +104,7 @@ func TestBuildBasicGGUF(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := s.Build("x", model.File{From: w.fileName("gguf")}); err != nil {
if err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")}); err != nil {
t.Fatal(err)
}
@ -105,7 +113,12 @@ func TestBuildBasicGGUF(t *testing.T) {
return nil
})
path, err := s.WeightsFile(blob.ParseRef("x+Q4_0"))
_, err = s.WeightsFile("unknown/y/z:latest+Q4_0")
if !errors.Is(err, ErrNotFound) {
t.Fatalf("WeightsFile() err = %v; want not found", err)
}
path, err := s.WeightsFile("x/y/z:latest+Q4_0")
if err != nil {
t.Fatal(err)
}

View File

@ -18,7 +18,8 @@ import (
)
var (
ErrInvalidID = errors.New("invalid ID")
ErrInvalidID = errors.New("invalid ID")
ErrUnknownRef = errors.New("unknown ref")
)
const HashSize = 32
@ -199,6 +200,9 @@ func (s *Store) Resolve(ref blob.Ref) (data []byte, path string, err error) {
return nil, "", err
}
data, err = os.ReadFile(path)
if errors.Is(err, fs.ErrNotExist) {
return nil, "", fmt.Errorf("%w: %q", ErrUnknownRef, ref)
}
if err != nil {
return nil, "", &entryNotFoundError{Err: err}
}
@ -221,10 +225,10 @@ func (s *Store) Set(ref blob.Ref, data []byte) error {
}
func (s *Store) refFileName(ref blob.Ref) (string, error) {
if !ref.FullyQualified() {
if !ref.Complete() {
return "", fmt.Errorf("ref not fully qualified: %q", ref)
}
return filepath.Join(s.dir, "manifests", ref.Domain(), ref.Name(), ref.Tag(), ref.Build()), nil
return filepath.Join(s.dir, "manifests", filepath.Join(ref.Parts()...)), nil
}
// Get looks up the blob ID in the store,

View File

@ -70,14 +70,13 @@ func TestStoreBasicBlob(t *testing.T) {
}
// Check tags
ref := blob.ParseRef("test+KQED")
ref := blob.ParseRef("registry.ollama.ai/library/test:latest+KQED")
t.Logf("resolving %s", ref)
t.Logf("RESOLVING: %q", ref.Parts())
data, _, err := st.Resolve(ref)
var e *entryNotFoundError
if !errors.As(err, &e) {
t.Fatal(err)
if !errors.Is(err, ErrUnknownRef) {
t.Fatalf("unexpected error: %v", err)
}
if data != nil {
t.Errorf("unexpected data: %q", data)
@ -119,6 +118,7 @@ func checkDir(t testing.TB, dir string, want []string) {
var matches []string
for path, err := range walkDir(dir) {
t.Helper()
if err != nil {
t.Fatal(err)
}

View File

@ -3,6 +3,8 @@ package registry
import (
"context"
"encoding/json"
"errors"
"io"
"net/http/httptest"
"os/exec"
"strings"
@ -32,7 +34,9 @@ func TestPush(t *testing.T) {
]
}`)
got, err := c.Push(context.Background(), "x+y", manifest)
const ref = "registry.ollama.ai/x/y:latest+Z"
got, err := c.Push(context.Background(), ref, manifest)
if err != nil {
t.Fatal(err)
}
@ -44,13 +48,13 @@ func TestPush(t *testing.T) {
}, diff.ZeroFields[apitype.Requirement]("URL"))
for _, r := range got {
body := strings.NewReader(strings.Repeat("x", int(r.Size)))
body := io.Reader(strings.NewReader(strings.Repeat("x", int(r.Size))))
if err := PushLayer(context.Background(), r.URL, r.Size, body); err != nil {
t.Fatal(err)
}
}
got, err = c.Push(context.Background(), "x+y", manifest)
got, err = c.Push(context.Background(), ref, manifest)
if err != nil {
t.Fatal(err)
}
@ -81,10 +85,10 @@ func TestPush(t *testing.T) {
"blobs/sha256-1",
"blobs/sha256-2",
"blobs/sha256-3",
"manifests/registry.ollama.ai/x/latest/Y",
"manifests/registry.ollama.ai/x/y/latest/Z",
})
obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/latest/Y", minio.GetObjectOptions{})
obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
if err != nil {
t.Fatal(err)
}
@ -117,7 +121,13 @@ func startMinio(t *testing.T) {
t.Cleanup(func() {
cmd.Process.Kill()
if err := cmd.Wait(); err != nil {
t.Log(err)
var e *exec.ExitError
if errors.As(err, &e) && e.Exited() {
t.Logf("minio stderr: %s", e.Stderr)
t.Logf("minio exit status: %v", e.ExitCode())
t.Logf("minio exited: %v", e.Exited())
t.Error(err)
}
}
})