From c49947dcf5600e147fa76a8289c08615b49ada36 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Thu, 7 Mar 2024 20:33:57 -0800 Subject: [PATCH 01/29] init --- api/api.go | 106 +++++ build/blob/ref.go | 135 +++++++ build/blob/ref_test.go | 69 ++++ build/build.go | 165 ++++++++ build/build_test.go | 150 +++++++ build/convert.go | 12 + build/default.go | 28 ++ build/import.go | 59 +++ build/internal/blobstore/blob.go | 378 ++++++++++++++++++ build/internal/blobstore/blob_test.go | 54 +++ build/internal/blobstore/store_test.go | 161 ++++++++ client/ollama/apitype/apitype.go | 31 ++ client/ollama/ollama.go | 70 ++++ cmd/bllamo/bllamo.go | 100 +++++ cmd/bllamo/flags.go | 59 +++ cmd/gguf/gguf.go | 97 +++++ encoding/gguf/gguf.go | 376 +++++++++++++++++ encoding/gguf/gguf_test.go | 345 ++++++++++++++++ encoding/gguf/ggufio.go | 195 +++++++++ encoding/gguf/reader.go | 70 ++++ .../fuzz/FuzzReadInfo/787da6e90e4be491 | 2 + .../fuzz/FuzzReadInfo/8b42c37d144cd2c6 | 2 + .../fuzz/FuzzReadInfo/92b890e394a77cfc | 2 + .../fuzz/FuzzReadInfo/9cfd6a48931a2753 | 2 + .../fuzz/FuzzReadInfo/a8c5454e2a164af2 | 2 + .../fuzz/FuzzReadInfo/a931e37cb6f932d4 | 2 + .../fuzz/FuzzReadInfo/bcd20fa73e7351a2 | 2 + .../fuzz/FuzzReadInfo/d29846a68e32052d | 2 + go.mod | 30 ++ go.sum | 63 +++ model/file.go | 126 ++++++ oweb/oweb.go | 143 +++++++ registry/apitypes.go | 27 ++ registry/client.go | 50 +++ registry/server.go | 117 ++++++ registry/server_test.go | 99 +++++ types/empty/message.go | 4 + types/structs/structs.go | 15 + types/they/want.go | 12 + 39 files changed, 3362 insertions(+) create mode 100644 api/api.go create mode 100644 build/blob/ref.go create mode 100644 build/blob/ref_test.go create mode 100644 build/build.go create mode 100644 build/build_test.go create mode 100644 build/convert.go create mode 100644 build/default.go create mode 100644 build/import.go create mode 100644 build/internal/blobstore/blob.go create mode 100644 build/internal/blobstore/blob_test.go create mode 100644 
build/internal/blobstore/store_test.go create mode 100644 client/ollama/apitype/apitype.go create mode 100644 client/ollama/ollama.go create mode 100644 cmd/bllamo/bllamo.go create mode 100644 cmd/bllamo/flags.go create mode 100644 cmd/gguf/gguf.go create mode 100644 encoding/gguf/gguf.go create mode 100644 encoding/gguf/gguf_test.go create mode 100644 encoding/gguf/ggufio.go create mode 100644 encoding/gguf/reader.go create mode 100644 encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491 create mode 100644 encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6 create mode 100644 encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc create mode 100644 encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753 create mode 100644 encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2 create mode 100644 encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4 create mode 100644 encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2 create mode 100644 encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d create mode 100644 go.mod create mode 100644 go.sum create mode 100644 model/file.go create mode 100644 oweb/oweb.go create mode 100644 registry/apitypes.go create mode 100644 registry/client.go create mode 100644 registry/server.go create mode 100644 registry/server_test.go create mode 100644 types/empty/message.go create mode 100644 types/structs/structs.go create mode 100644 types/they/want.go diff --git a/api/api.go b/api/api.go new file mode 100644 index 00000000..fd525937 --- /dev/null +++ b/api/api.go @@ -0,0 +1,106 @@ +package api + +import ( + "errors" + "fmt" + "net/http" + "os" + + "bllamo.com/build" + "bllamo.com/build/blob" + "bllamo.com/client/ollama/apitype" + "bllamo.com/oweb" + "bllamo.com/registry" +) + +// Common API Errors +var ( + errUnqualifiedRef = oweb.Mistake("invalid", "name", "must be fully qualified") + errRefNotFound = oweb.Mistake("not_found", "name", "no such model") +) + +type Server struct { + Build *build.Server 
+} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + oweb.Serve(s.serveHTTP, w, r) +} + +func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { + switch r.URL.Path { + case "/v1/push": + return s.handlePush(w, r) + default: + return oweb.ErrNotFound + } +} + +func want(r *http.Request, method, path string) bool { + return r.Method == method && r.URL.Path == path +} + +func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { + if r.Method != "POST" { + return oweb.ErrMethodNotAllowed + } + + params, err := oweb.DecodeJSON[apitype.PushRequest](r.Body) + if err != nil { + return err + } + + if params.Name == "" { + return oweb.Missing("name") + } + + const registryURLTODO = "http://localhost:8888" + + ref := blob.ParseRef(params.Name) + if !ref.FullyQualified() { + return errUnqualifiedRef + } + + man, err := s.Build.Manifest(ref) + if err != nil { + if errors.Is(err, build.ErrNotFound) { + return errRefNotFound + } + return err + } + + c := registry.Client{BaseURL: registryURLTODO} + requirements, err := c.Push(r.Context(), params.Name, man) + if err != nil { + return err + } + + for _, rq := range requirements { + l, err := s.Build.LayerFile(rq.Digest) + if err != nil { + return err + } + err = func() error { + f, err := os.Open(l) + if err != nil { + return err + } + defer f.Close() + return registry.PushLayer(r.Context(), rq.URL, rq.Size, f) + }() + if err != nil { + return err + } + } + + // commit the manifest to the registry + requirements, err = c.Push(r.Context(), params.Name, man) + if err != nil { + return err + } + for _, r := range requirements { + err = errors.Join(err, fmt.Errorf("push failed for %q", r.Digest)) + } + return err + +} diff --git a/build/blob/ref.go b/build/blob/ref.go new file mode 100644 index 00000000..9a033fcb --- /dev/null +++ b/build/blob/ref.go @@ -0,0 +1,135 @@ +package blob + +import ( + "cmp" + "strings" +) + +// Ref is an opaque reference to a blob. 
+// +// It is comparable and can be used as a map key. +// +// Users or Ref must check Valid before using it. +type Ref struct { + name string + tag string + build string +} + +// WithBuild returns a copy of r with the provided build. If the provided +// build is empty, it returns the short, unqualified copy of r. +func (r Ref) WithBuild(build string) Ref { + if build == "" { + return Ref{r.name, r.tag, ""} + } + if !isValidPart(build) { + return Ref{} + } + return makeRef(r.name, r.tag, build) +} + +// String returns the fully qualified ref string. +func (r Ref) String() string { + var b strings.Builder + b.WriteString(r.name) + if r.tag != "" { + b.WriteString(":") + b.WriteString(r.tag) + } + if r.build != "" { + b.WriteString("+") + b.WriteString(r.build) + } + return b.String() +} + +// Full returns the fully qualified ref string, or a string indicating the +// build is missing, or an empty string if the ref is invalid. +func (r Ref) Full() string { + if !r.Valid() { + return "" + } + return makeRef(r.name, r.tag, cmp.Or(r.build, "!(MISSING BUILD)")).String() +} + +// Short returns the short ref string which does not include the build. +func (r Ref) Short() string { + return r.WithBuild("").String() +} + +func (r Ref) Valid() bool { + return r.name != "" +} + +func (r Ref) FullyQualified() bool { + return r.name != "" && r.tag != "" && r.build != "" +} + +func (r Ref) Name() string { return r.name } +func (r Ref) Tag() string { return r.tag } +func (r Ref) Build() string { return r.build } + +// ParseRef parses a ref string into a Ref. A ref string is a name, an +// optional tag, and an optional build, separated by colons and pluses. +// +// The name must be valid ascii [a-zA-Z0-9_]. +// The tag must be valid ascii [a-zA-Z0-9_]. +// The build must be valid ascii [a-zA-Z0-9_]. +// +// It returns then zero value if the ref is invalid. 
+//
+// An omitted tag defaults to "latest", and the build is always upper-cased.
+//
+//	// Valid Examples:
+//	ParseRef("mistral:latest") returns ("mistral", "latest", "")
+//	ParseRef("mistral") returns ("mistral", "latest", "")
+//	ParseRef("mistral:30B") returns ("mistral", "30B", "")
+//	ParseRef("mistral:7b") returns ("mistral", "7b", "")
+//	ParseRef("mistral:7b+Q4_0") returns ("mistral", "7b", "Q4_0")
+//	ParseRef("mistral:7b+q4_0") returns ("mistral", "7b", "Q4_0")
+//	ParseRef("mistral+KQED") returns ("mistral", "latest", "KQED")
+//	ParseRef(".x.:7b+Q4_0") returns (".x.", "7b", "Q4_0")
+//	ParseRef("-grok-f.oo:7b+Q4_0") returns ("-grok-f.oo", "7b", "Q4_0")
+//
+//	// Invalid Examples:
+//	ParseRef("m stral") returns ("", "", "") // zero
+//	ParseRef("mistral:7b+Q4_0:latest") returns ("", "", "") // zero
+//	ParseRef("... 129 chars ...") returns ("", "", "") // zero
+func ParseRef(s string) Ref {
+	if len(s) > 128 {
+		return Ref{}
+	}
+
+	// Split off the build first so that a ':' appearing after the '+'
+	// ends up in the build part, where it is rejected as invalid.
+	nameAndTag, build, expectBuild := strings.Cut(s, "+")
+	name, tag, expectTag := strings.Cut(nameAndTag, ":")
+	if !isValidPart(name) {
+		return Ref{}
+	}
+	// A separator that is present must be followed by a non-empty,
+	// valid part: "name:" and "name+" are both invalid.
+	if expectTag && !isValidPart(tag) {
+		return Ref{}
+	}
+	if expectBuild && !isValidPart(build) {
+		return Ref{}
+	}
+	return makeRef(name, tag, build)
+}
+
+// makeRef makes a ref, skipping validation. An empty tag becomes "latest"
+// and the build is upper-cased.
+func makeRef(name, tag, build string) Ref {
+	return Ref{name, cmp.Or(tag, "latest"), strings.ToUpper(build)}
+}
+
+// isValidPart reports whether s is non-empty and consists only of the
+// ASCII bytes [a-zA-Z0-9_.-].
+func isValidPart(s string) bool {
+	if len(s) == 0 {
+		return false
+	}
+	for _, c := range []byte(s) {
+		switch {
+		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9':
+		case c == '_', c == '.', c == '-':
+			// Allowed punctuation. Keep scanning: every remaining byte
+			// must be checked too, so this must not return early.
+		default:
+			return false
+		}
+	}
+	return true
+}
diff --git a/build/blob/ref_test.go b/build/blob/ref_test.go
new file mode 100644
index 00000000..b49d39df
--- /dev/null
+++ b/build/blob/ref_test.go
@@ -0,0 +1,69 @@
+package blob
+
+import "testing"
+
+// test refs
+const (
+	refTooLong = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
+)
+
+func TestParseRef(t *testing.T) {
+	cases := []struct {
+		in   string
+		want Ref
+	}{
+		{"mistral:latest", Ref{"mistral", "latest", ""}},
+		{"mistral", Ref{"mistral", "latest", ""}},
+		{"mistral:30B", Ref{"mistral", "30B", ""}},
+		{"mistral:7b", Ref{"mistral", "7b", ""}},
+		{"mistral:7b+Q4_0", Ref{"mistral", "7b", "Q4_0"}},
+		{"mistral+KQED", Ref{"mistral", "latest", "KQED"}},
+		{"mistral.x-3:7b+Q4_0", Ref{"mistral.x-3", "7b", "Q4_0"}},
+
+		// lowercase build
+		{"mistral:7b+q4_0", Ref{"mistral", "7b", "Q4_0"}},
+
+		// Invalid
+		{"mistral:7b+Q4_0:latest", Ref{"", "", ""}},
+		{"mi tral", Ref{"", "", ""}},
+		{"llama2:+", Ref{"", "", ""}},
+
+		// too long
+		{refTooLong, Ref{"", "", ""}},
+	}
+	for _, tt := range cases {
+		t.Run(tt.in, func(t *testing.T) {
+			got := ParseRef(tt.in)
+			if got != tt.want {
+				t.Errorf("ParseRef(%q) = %q; want %q", tt.in, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestRefFull(t *testing.T) {
+	cases := []struct {
+		in        string
+		wantShort string
+		wantFull  string
+	}{
+		{"", "", ""},
+		{"mistral:7b+x", "mistral:7b", "mistral:7b+X"},
+		{"mistral:7b+Q4_0", "mistral:7b", "mistral:7b+Q4_0"},
+		{"mistral:latest", "mistral:latest", "mistral:latest+!(MISSING BUILD)"},
+		{"mistral", "mistral:latest", "mistral:latest+!(MISSING BUILD)"},
+		{"mistral:30b", "mistral:30b", "mistral:30b+!(MISSING BUILD)"},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.in, func(t *testing.T) {
+			ref := ParseRef(tt.in)
+			if g := 
ref.Short(); g != tt.wantShort { + t.Errorf("Short(%q) = %q; want %q", tt.in, g, tt.wantShort) + } + if g := ref.Full(); g != tt.wantFull { + t.Errorf("Full(%q) = %q; want %q", tt.in, g, tt.wantFull) + } + }) + } +} diff --git a/build/build.go b/build/build.go new file mode 100644 index 00000000..f66138db --- /dev/null +++ b/build/build.go @@ -0,0 +1,165 @@ +package build + +import ( + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + + "bllamo.com/build/blob" + "bllamo.com/build/internal/blobstore" + "bllamo.com/model" +) + +// Errors +var ( + ErrInvalidRef = errors.New("invalid ref") + ErrUnsupportedModelFormat = errors.New("unsupported model format") + ErrMissingFileType = errors.New("missing 'general.file_type' key") + ErrNoSuchBlob = errors.New("no such blob") + ErrNotFound = errors.New("not found") +) + +type mediaType string + +// Known media types +const ( + mediaTypeModel mediaType = "application/vnd.ollama.image.model" +) + +type Server struct { + st *blobstore.Store +} + +// Open starts a new build server that uses dir as the base directory for all +// build artifacts. If dir is empty, DefaultDir is used. +// +// It returns an error if the provided or default dir cannot be initialized. +func Open(dir string) (*Server, error) { + if dir == "" { + var err error + dir, err = DefaultDir() + if err != nil { + return nil, err + } + } + st, err := blobstore.Open(dir) + if err != nil { + return nil, err + } + return &Server{st: st}, nil +} + +func (s *Server) Build(ref string, f model.File) error { + br := blob.ParseRef(ref) + if !br.Valid() { + return invalidRef(ref) + } + + // 1. Resolve FROM + // a. If it's a local file (gguf), hash it and add it to the store. + // b. If it's a local dir (safetensor), convert to gguf and add to + // store. + // c. If it's a remote file (http), refuse. + // 2. Turn other pragmas into layers, and add them to the store. + // 3. Create a manifest from the layers. + // 4. Store the manifest in the manifest cache + // 5. Done. 
+ + if f.From == "" { + return &model.Error{Pragma: "FROM", Message: "missing"} + } + + var layers []layerJSON + + id, info, size, err := s.importModel(f.From) + if err != nil { + return err + } + layers = append(layers, layerJSON{ + ID: id, + MediaType: mediaTypeModel, + Size: size, + }) + + id, size, err = blobstore.PutString(s.st, f.License) + if err != nil { + return err + } + layers = append(layers, layerJSON{ + ID: id, + MediaType: "text/plain", + Size: size, + }) + + data, err := json.Marshal(manifestJSON{Layers: layers}) + if err != nil { + return err + } + return s.st.Set(br.WithBuild(info.FileType.String()), data) +} + +func (s *Server) LayerFile(digest string) (string, error) { + fileName := s.st.OutputFilename(blobstore.ParseID(digest)) + _, err := os.Stat(fileName) + if errors.Is(err, fs.ErrNotExist) { + return "", fmt.Errorf("%w: %q", ErrNoSuchBlob, digest) + } + return fileName, nil +} + +func (s *Server) Manifest(ref blob.Ref) ([]byte, error) { + data, _, err := s.getManifestData(ref) + if errors.Is(err, fs.ErrNotExist) { + return nil, fmt.Errorf("%w: %q", ErrNotFound, ref) + } + return data, err +} + +// WeightFile returns the absolute path to the weights file for the given model ref. +func (s *Server) WeightsFile(ref blob.Ref) (string, error) { + m, err := s.getManifest(ref) + if err != nil { + return "", err + } + for _, l := range m.Layers { + if l.MediaType == mediaTypeModel { + return s.st.OutputFilename(l.ID), nil + } + } + return "", fmt.Errorf("missing weights layer for %q", ref) +} + +type manifestJSON struct { + // Layers is the list of layers in the manifest. + Layers []layerJSON `json:"layers"` +} + +// Layer is a layer in a model manifest. +type layerJSON struct { + // ID is the ID of the layer. 
+ ID blobstore.ID `json:"digest"` + MediaType mediaType `json:"mediaType"` + Size int64 `json:"size"` +} + +func (s *Server) getManifest(ref blob.Ref) (manifestJSON, error) { + data, path, err := s.getManifestData(ref) + if err != nil { + return manifestJSON{}, err + } + var m manifestJSON + if err := json.Unmarshal(data, &m); err != nil { + return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err} + } + return m, nil +} + +func (s *Server) getManifestData(ref blob.Ref) (data []byte, path string, err error) { + return s.st.Resolve(ref) +} + +func invalidRef(ref string) error { + return fmt.Errorf("%w: %q", ErrInvalidRef, ref) +} diff --git a/build/build_test.go b/build/build_test.go new file mode 100644 index 00000000..c146717e --- /dev/null +++ b/build/build_test.go @@ -0,0 +1,150 @@ +package build + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "bllamo.com/build/blob" + "bllamo.com/encoding/gguf" + "bllamo.com/model" +) + +func TestServerBuildErrors(t *testing.T) { + dir := t.TempDir() + + s, err := Open(dir) + if err != nil { + t.Fatal(err) + } + + t.Run("FROM pragma missing", func(t *testing.T) { + err := s.Build("foo", model.File{}) + var e *model.Error + if !errors.As(err, &e) { + t.Fatalf("unexpected error: %v", err) + } + if e.Pragma != "FROM" { + t.Errorf("e.Pragma = %s; want FROM", e.Pragma) + } + if e.Message != "missing" { + t.Errorf("e.Message = %s; want missing", e.Message) + } + }) + + t.Run("FROM file not found", func(t *testing.T) { + err := s.Build("x", model.File{From: "bar"}) + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("Build() err = %v; want file not found", err) + } + }) + + t.Run("FROM gguf", func(t *testing.T) { + w := newWorkDir(t) + // Write a gguf file without general.file_type metadata. 
+ w.write("gguf", ""+ + "GGUF"+ // magic + "\x03\x00\x00\x00"+ // version + "\x00\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues + "\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors + "", + ) + + err := s.Build("x", model.File{From: w.fileName("gguf")}) + if !errors.Is(err, ErrMissingFileType) { + t.Fatalf("Build() err = %#v; want missing file type", err) + } + }) + + t.Run("FROM obscure dir", func(t *testing.T) { + w := newWorkDir(t) + w.mkdirAll("unknown") + if err := s.Build("x", model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat { + t.Fatalf("Build() err = %#v; want unsupported model type", err) + } + }) + + t.Run("FROM unsupported model type", func(t *testing.T) { + w := newWorkDir(t) + from := w.write("unknown", "unknown content") + err := s.Build("x", model.File{From: from}) + if !errors.Is(err, ErrUnsupportedModelFormat) { + t.Fatalf("Build() err = %#v; want unsupported model type", err) + } + }) +} + +func TestBuildBasicGGUF(t *testing.T) { + w := newWorkDir(t) + w.write("gguf", ""+ + "GGUF"+ // magic + "\x03\x00\x00\x00"+ // version + "\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors + "\x01\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues + + // general.file_type key + "\x11\x00\x00\x00\x00\x00\x00\x00"+ // key length + "general.file_type"+ // key + "\x04\x00\x00\x00"+ // type (uint32) + "\x02\x00\x00\x00\x00\x00\x00\x00"+ // uint32 value + "", + ) + + dir := t.TempDir() + s, err := Open(dir) + if err != nil { + t.Fatal(err) + } + if err := s.Build("x", model.File{From: w.fileName("gguf")}); err != nil { + t.Fatal(err) + } + + filepath.Walk(dir, func(p string, info os.FileInfo, err error) error { + t.Logf("file: %s", p) + return nil + }) + + path, err := s.WeightsFile(blob.ParseRef("x+Q4_0")) + if err != nil { + t.Fatal(err) + } + + info, err := gguf.Stat(path) + if err != nil { + t.Fatal(err) + } + if info.FileType != gguf.TypeQ4_0 { + t.Errorf("info.FileType = %d; want 1", info.FileType) + } +} + +type work struct { + t testing.TB 
+ dir string +} + +func newWorkDir(t *testing.T) work { + return work{t: t, dir: t.TempDir()} +} + +func (w work) write(name, content string) (path string) { + w.t.Helper() + path = w.fileName(name) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + w.t.Fatal(err) + } + return path +} + +func (w work) fileName(name string) string { + w.t.Helper() + return filepath.Join(w.dir, name) +} + +func (w work) mkdirAll(path string) { + w.t.Helper() + if err := os.MkdirAll(filepath.Join(w.dir, path), 0755); err != nil { + w.t.Fatal(err) + } +} diff --git a/build/convert.go b/build/convert.go new file mode 100644 index 00000000..939bb691 --- /dev/null +++ b/build/convert.go @@ -0,0 +1,12 @@ +package build + +func convertSafeTensorToGGUF(path string) (ggufPath string, err error) { + // TODO: decine on hueristic for converting safetensor to gguf and + // the errors that can be returned. For now, we just say + // "unsupported", however it may be intended to be a valid safe + // tensor but we hit an error in the conversion. + // + // I (bmizernay) think this will naturally evolve as we implement + // the conversion. + return "", ErrUnsupportedModelFormat +} diff --git a/build/default.go b/build/default.go new file mode 100644 index 00000000..927301d8 --- /dev/null +++ b/build/default.go @@ -0,0 +1,28 @@ +package build + +import ( + "os" + "path/filepath" + "sync" +) + +var ( + defaultDir = sync.OnceValues(func() (string, error) { + dir := os.Getenv("OLLAMA_MODELS") + if dir == "" { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + dir = filepath.Join(home, ".ollama", "models") + } + return dir, nil + }) +) + +// DefaultDir returns the default directory for models. It returns the value +// of the OLLAMA_MODELS environment variable if set; otherwise it returns +// "$HOME/.ollama/models". 
+func DefaultDir() (string, error) {
+	return defaultDir()
+}
diff --git a/build/import.go b/build/import.go
new file mode 100644
index 00000000..412989a6
--- /dev/null
+++ b/build/import.go
@@ -0,0 +1,59 @@
+package build
+
+import (
+	"errors"
+	"fmt"
+	"os"
+
+	"bllamo.com/build/internal/blobstore"
+	"bllamo.com/encoding/gguf"
+)
+
+// importError returns err together with zero values for the other results
+// of the import functions, so failure paths can be written as a single
+// return statement.
+func importError(err error) (blobstore.ID, gguf.Info, int64, error) {
+	return blobstore.ID{}, gguf.Info{}, 0, err
+}
+
+// importModel imports the model found at path into the store. A directory
+// is handed to importSafeTensor; anything else is handed to importGGUF.
+func (s *Server) importModel(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
+	fi, err := os.Stat(path)
+	if err != nil {
+		return importError(err)
+	}
+	if !fi.IsDir() {
+		return s.importGGUF(path)
+	}
+	return s.importSafeTensor(path)
+}
+
+// importGGUF reads the GGUF info of the file at path and, if it is
+// acceptable, copies the file into the store. It returns the ID and size
+// reported by the store along with the info that was read.
+//
+// A file with a bad magic number is reported as ErrUnsupportedModelFormat,
+// and a file whose FileType is zero as ErrMissingFileType.
+func (s *Server) importGGUF(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return importError(err)
+	}
+	defer f.Close()
+
+	info, err := gguf.StatReader(f)
+	switch {
+	case errors.Is(err, gguf.ErrBadMagic):
+		return importError(ErrUnsupportedModelFormat)
+	case err != nil:
+		return importError(err)
+	case info.FileType == 0:
+		// NOTE(review): zero is taken to mean "general.file_type not set";
+		// confirm the gguf package never uses 0 for a real file type.
+		return importError(fmt.Errorf("%w: %q", ErrMissingFileType, path))
+	}
+
+	id, size, err := s.st.Put(f)
+	if err != nil {
+		return importError(err)
+	}
+	return id, info, size, nil
+}
+
+// importSafeTensor converts the model in the directory at path to GGUF and
+// imports the converted file with importGGUF.
+func (s *Server) importSafeTensor(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
+	ggufPath, err := convertSafeTensorToGGUF(path)
+	if err != nil {
+		return importError(err)
+	}
+	return s.importGGUF(ggufPath)
+}
diff --git a/build/internal/blobstore/blob.go b/build/internal/blobstore/blob.go
new file mode 100644
index 00000000..e5981416
--- /dev/null
+++ b/build/internal/blobstore/blob.go
@@ -0,0 +1,378 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package blobstore implements a blob store.
+package blobstore + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "time" + + "bllamo.com/build/blob" + "bllamo.com/types/structs" +) + +var ( + ErrInvalidID = errors.New("invalid ID") +) + +const HashSize = 32 + +// An ID is a blob output key, the hash of an output of a computation. +type ID struct { + a [HashSize]byte +} + +func (id ID) MarshalText() ([]byte, error) { + return []byte(id.String()), nil +} + +func (id *ID) UnmarshalText(text []byte) error { + *id = ParseID(string(text)) + return nil +} + +func ParseID(s string) ID { + const prefix = "sha256-" + h, ok := strings.CutPrefix(s, prefix) + if !ok { + return ID{} + } + + if len(h) != HashSize*2 { + return ID{} + } + + var b []byte + _, err := fmt.Sscanf(h, "%x", &b) + if err != nil { + return ID{} + } + + var id ID + copy(id.a[:], b) + return id +} + +func (id ID) String() string { + if !id.Valid() { + return "" + } + return fmt.Sprintf("sha256-%x", id.a[:]) +} + +func (id ID) Valid() bool { + return id != ID{} +} + +func (id ID) Match(h [HashSize]byte) bool { + return id.a == h +} + +// A Store is a blob store, backed by a file system directory tree. +type Store struct { + dir string + now func() time.Time +} + +// Open opens and returns the store in the given directory. +// +// It is safe for multiple processes on a single machine to use the +// same store directory in a local file system simultaneously. +// They will coordinate using operating system file locks and may +// duplicate effort but will not corrupt the store. +// +// However, it is NOT safe for multiple processes on different machines +// to share a store directory (for example, if the directory were stored +// in a network file system). File locking is notoriously unreliable in +// network file systems and may not suffice to protect the store. 
+func Open(dir string) (*Store, error) { + info, err := os.Stat(dir) + if err != nil { + return nil, err + } + if !info.IsDir() { + return nil, &fs.PathError{Op: "open", Path: dir, Err: fmt.Errorf("not a directory")} + } + + for _, sub := range []string{"blobs", "manifests"} { + if err := os.MkdirAll(filepath.Join(dir, sub), 0777); err != nil { + return nil, err + } + } + + c := &Store{ + dir: dir, + now: time.Now, + } + return c, nil +} + +// fileName returns the name of the blob file corresponding to the given id. +func (s *Store) fileName(id ID) string { + return filepath.Join(s.dir, "blobs", fmt.Sprintf("sha256-%x", id.a[:])) +} + +// An entryNotFoundError indicates that a store entry was not found, with an +// optional underlying reason. +type entryNotFoundError struct { + Err error +} + +func (e *entryNotFoundError) Error() string { + if e.Err == nil { + return "store entry not found" + } + return fmt.Sprintf("store entry not found: %v", e.Err) +} + +func (e *entryNotFoundError) Unwrap() error { + return e.Err +} + +type Entry struct { + _ structs.Incomparable + + ID ID + Size int64 + Time time.Time // when added to store +} + +// GetFile looks up the blob ID in the store and returns +// the name of the corresponding data file. +func GetFile(s *Store, id ID) (file string, entry Entry, err error) { + entry, err = s.Get(id) + if err != nil { + return "", Entry{}, err + } + file = s.OutputFilename(entry.ID) + info, err := os.Stat(file) + if err != nil { + return "", Entry{}, &entryNotFoundError{Err: err} + } + if info.Size() != entry.Size { + return "", Entry{}, &entryNotFoundError{Err: errors.New("file incomplete")} + } + return file, entry, nil +} + +// GetBytes looks up the blob ID in the store and returns +// the corresponding output bytes. +// GetBytes should only be used for data that can be expected to fit in memory. 
+//
+// If the blob file cannot be read, or its content does not hash to id, the
+// returned error is an entryNotFoundError and no data is returned.
+func GetBytes(s *Store, id ID) ([]byte, Entry, error) {
+	entry, err := s.Get(id)
+	if err != nil {
+		return nil, entry, err
+	}
+	data, err := os.ReadFile(s.OutputFilename(entry.ID))
+	if err != nil {
+		// The file may have been removed since Get stat'ed it.
+		return nil, entry, &entryNotFoundError{Err: err}
+	}
+	// Only hand the data out when it hashes to the ID it is stored under.
+	if !entry.ID.Match(sha256.Sum256(data)) {
+		return nil, entry, &entryNotFoundError{Err: errors.New("bad checksum")}
+	}
+	return data, entry, nil
+}
+
+// OutputFilename returns the name of the blob file for the given ID.
+func (s *Store) OutputFilename(id ID) string {
+	file := s.fileName(id)
+	// TODO(bmizerany): touch as "used" for cache trimming. (See cache.go
+	// in cmd/go/internal/cache for the full reference implementation to go
+	// off of.)
+	return file
+}
+
+// Resolve returns the data for the given ref, if any, along with the path
+// of the file it was read from.
+//
+// It returns an error if ref is not fully qualified, and an
+// entryNotFoundError wrapping the read error if the file cannot be read.
+//
+// TODO: This should ideally return an ID, but the current on
+// disk layout is that the actual manifest is stored in the "ref" instead of
+// a pointer to a content-addressed blob. I (bmizerany) think we should
+// change the on-disk layout to store the manifest in a content-addressed
+// blob, and then have the ref point to that blob. This would simplify the
+// code, allow us to have integrity checks on the manifest, and clean up
+// this interface.
+func (s *Store) Resolve(ref blob.Ref) (data []byte, path string, err error) {
+	path, err = s.refFileName(ref)
+	if err != nil {
+		return nil, "", err
+	}
+	data, err = os.ReadFile(path)
+	if err != nil {
+		return nil, "", &entryNotFoundError{Err: err}
+	}
+	return data, path, nil
+}
+
+// Set sets the data for the given ref.
+func (s *Store) Set(ref blob.Ref, data []byte) error { + path, err := s.refFileName(ref) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil { + return err + } + if err := os.WriteFile(path, data, 0666); err != nil { + return err + } + return nil +} + +func (s *Store) refFileName(ref blob.Ref) (string, error) { + if !ref.FullyQualified() { + return "", fmt.Errorf("ref not fully qualified: %q", ref) + } + const cheatTODO = "registry.ollama.ai/library" + return filepath.Join(s.dir, "manifests", cheatTODO, ref.Name(), ref.Tag(), ref.Build()), nil +} + +// Get looks up the blob ID in the store, +// returning the corresponding output ID and file size, if any. +// Note that finding an output ID does not guarantee that the +// saved file for that output ID is still available. +func (s *Store) Get(id ID) (Entry, error) { + file := s.fileName(id) + info, err := os.Stat(file) + if err != nil { + return Entry{}, &entryNotFoundError{Err: err} + } + return Entry{ + ID: id, + Size: info.Size(), + Time: info.ModTime(), + }, nil +} + +func (s *Store) Close() error { + // TODO(bmizerany): return c.Trim() + return nil +} + +// Put stores the data read from the given file into the store as ID. +// +// It may read file twice. The content of file must not change between the +// two passes. +func (s *Store) Put(file io.ReadSeeker) (ID, int64, error) { + return s.put(file) +} + +func PutBytes(s *Store, data []byte) (ID, int64, error) { + return s.Put(bytes.NewReader(data)) +} + +func PutString(s *Store, data string) (ID, int64, error) { + return s.Put(strings.NewReader(data)) +} + +func (s *Store) put(file io.ReadSeeker) (ID, int64, error) { + // Compute output ID. + h := sha256.New() + if _, err := file.Seek(0, 0); err != nil { + return ID{}, 0, err + } + size, err := io.Copy(h, file) + if err != nil { + return ID{}, 0, err + } + var out ID + h.Sum(out.a[:0]) + + // Copy to blob file (if not already present). 
+ if err := s.copyFile(file, out, size); err != nil { + return out, size, err + } + + // TODO: Add to manifest index. + return out, size, nil +} + +// copyFile copies file into the store, expecting it to have the given +// output ID and size, if that file is not present already. +func (s *Store) copyFile(file io.ReadSeeker, out ID, size int64) error { + name := s.fileName(out) + println("name", name) + info, err := os.Stat(name) + if err == nil && info.Size() == size { + // Check hash. + if f, err := os.Open(name); err == nil { + h := sha256.New() + io.Copy(h, f) + f.Close() + var out2 ID + h.Sum(out2.a[:0]) + if out == out2 { + return nil + } + } + // Hash did not match. Fall through and rewrite file. + } + + // Copy file to blobs directory. + mode := os.O_RDWR | os.O_CREATE + if err == nil && info.Size() > size { // shouldn't happen but fix in case + mode |= os.O_TRUNC + } + f, err := os.OpenFile(name, mode, 0666) + if err != nil { + return err + } + defer f.Close() + if size == 0 { + // File now exists with correct size. + // Only one possible zero-length file, so contents are OK too. + // Early return here makes sure there's a "last byte" for code below. + return nil + } + + // From here on, if any of the I/O writing the file fails, + // we make a best-effort attempt to truncate the file f + // before returning, to avoid leaving bad bytes in the file. + + // Copy file to f, but also into h to double-check hash. + if _, err := file.Seek(0, 0); err != nil { + f.Truncate(0) + return err + } + h := sha256.New() + w := io.MultiWriter(f, h) + if _, err := io.CopyN(w, file, size-1); err != nil { + f.Truncate(0) + return err + } + // Check last byte before writing it; writing it will make the size match + // what other processes expect to find and might cause them to start + // using the file. 
+ buf := make([]byte, 1) + if _, err := file.Read(buf); err != nil { + f.Truncate(0) + return err + } + h.Write(buf) + sum := h.Sum(nil) + if !bytes.Equal(sum, out.a[:]) { + f.Truncate(0) + return fmt.Errorf("file content changed underfoot") + } + + // Commit manifest entry. + if _, err := f.Write(buf); err != nil { + f.Truncate(0) + return err + } + if err := f.Close(); err != nil { + // Data might not have been written, + // but file may look like it is the right size. + // To be extra careful, remove stored file. + os.Remove(name) + return err + } + os.Chtimes(name, s.now(), s.now()) // mainly for tests + + return nil +} diff --git a/build/internal/blobstore/blob_test.go b/build/internal/blobstore/blob_test.go new file mode 100644 index 00000000..f54684a2 --- /dev/null +++ b/build/internal/blobstore/blob_test.go @@ -0,0 +1,54 @@ +package blobstore + +import ( + "strings" + "testing" +) + +func TestParseID(t *testing.T) { + const valid = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + var invalid = strings.Repeat("\x00", HashSize*2) + + cases := []struct { + in string + want string + }{ + {"", invalid}, + {"sha256-", invalid}, + {"sha256-" + valid, valid}, + + {"" + valid, invalid}, // no prefix + {"sha123-" + valid, invalid}, // invalid prefix + {"sha256-" + valid[1:], invalid}, // too short + {"sha256-" + valid + "a", invalid}, // too long + {"sha256-!" 
+ valid[1:], invalid}, // invalid hex + } + + for _, tt := range cases { + t.Run("", func(t *testing.T) { + // sanity check + if len(tt.want) > HashSize*2 { + panic("invalid test") + } + + got := ParseID(tt.in) + + wantValid := tt.want != invalid + if wantValid { + if !got.Valid() { + t.Errorf("ParseID(%q).Valid() = false; want true", tt.in) + } + if got.String() != "sha256-"+tt.want { + t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "sha256-"+tt.want) + } + } else { + if got.Valid() { + t.Errorf("ParseID(%q).Valid() = true; want false", tt.in) + } + if got.String() != "" { + t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "") + } + } + }) + } +} diff --git a/build/internal/blobstore/store_test.go b/build/internal/blobstore/store_test.go new file mode 100644 index 00000000..ddcc05aa --- /dev/null +++ b/build/internal/blobstore/store_test.go @@ -0,0 +1,161 @@ +package blobstore + +import ( + "errors" + "iter" + "os" + "path/filepath" + "testing" + "time" + + "bllamo.com/build/blob" + "kr.dev/diff" +) + +const ( + blobNameHello = "sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824" +) + +func TestStoreBasicBlob(t *testing.T) { + dir := t.TempDir() + + checkDir(t, dir, nil) + + st, err := Open(dir) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + st.now = func() time.Time { return now } + + checkDir(t, dir, []string{ + "blobs/", + "manifests/", + }) + + id, size, err := PutBytes(st, []byte("hello")) + if err != nil { + t.Fatal(err) + } + + if id != ParseID(blobNameHello) { + t.Errorf("unexpected ID: %s", id) + } + if size != 5 { + t.Errorf("unexpected size: %d", size) + } + + checkDir(t, dir, []string{ + "blobs/", + "blobs/" + blobNameHello, + "manifests/", + }) + + got, err := st.Get(id) + if err != nil { + t.Fatal(err) + } + + diff.Test(t, t.Errorf, got, Entry{ + ID: id, + Size: 5, + Time: now, + }) + + file := st.OutputFilename(id) + wantFile := filepath.Join(dir, "blobs", blobNameHello) + if 
file != wantFile { + t.Errorf("unexpected file: %s", file) + } + + // Check tags + ref := blob.ParseRef("test+KQED") + + t.Logf("resolving %s", ref) + + data, _, err := st.Resolve(ref) + var e *entryNotFoundError + if !errors.As(err, &e) { + t.Fatal(err) + } + if data != nil { + t.Errorf("unexpected data: %q", data) + } + + if err := st.Set(ref, []byte("{}")); err != nil { + t.Fatal(err) + } + + data, _, err = st.Resolve(ref) + if err != nil { + t.Fatal(err) + } + + if g := string(data); g != "{}" { + t.Errorf("g = %q; want %q", g, "{}") + } + + checkDir(t, dir, []string{ + "blobs/", + "blobs/sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", + "manifests/", + "manifests/registry.ollama.ai/", + "manifests/registry.ollama.ai/library/", + "manifests/registry.ollama.ai/library/test/", + "manifests/registry.ollama.ai/library/test/latest/", + "manifests/registry.ollama.ai/library/test/latest/KQED", + }) +} + +// checkDir checks that the directory at dir contains the files in want. The +// files in want must be relative to dir. +// +// direcotories are suffixed with a slash (e.g. "foo/" instead of "foo"). +// +// want must be in lexicographic order. 
+func checkDir(t testing.TB, dir string, want []string) { + t.Helper() + + var matches []string + for path, err := range walkDir(dir) { + if err != nil { + t.Fatal(err) + } + t.Logf("found %s", path) + if path == "./" { + continue + } + path = filepath.ToSlash(path) + matches = append(matches, path) + } + + diff.Test(t, t.Errorf, matches, want) +} + +var errStop = errors.New("stop") + +func walkDir(dir string) iter.Seq2[string, error] { + return func(yield func(string, error) bool) { + err := filepath.WalkDir(dir, func(path string, info os.DirEntry, err error) error { + if err != nil { + return err + } + path, err = filepath.Rel(dir, path) + if err != nil { + return err + } + path = filepath.ToSlash(path) + if info.IsDir() { + path += "/" + } + if !yield(path, nil) { + return errStop + } + return nil + }) + if !errors.Is(err, errStop) && err != nil { + yield("", err) + } + } +} diff --git a/client/ollama/apitype/apitype.go b/client/ollama/apitype/apitype.go new file mode 100644 index 00000000..e2def7a2 --- /dev/null +++ b/client/ollama/apitype/apitype.go @@ -0,0 +1,31 @@ +package apitype + +import "time" + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type Model struct { + Ref string `json:"ref"` + Digest string `json:"digest"` + Size int64 `json:"size"` + ModifiedAt int64 `json:"modified"` +} + +func (m Model) Modifed() time.Time { + return time.Unix(0, m.ModifiedAt) +} + +type PushRequest struct { + Name string `json:"name"` // Ref is the official term, "name" is for backward compatibility with exiting clients. 
+ Insecure bool `json:"insecure"` + Stream bool `json:"stream"` +} + +type PushStatus struct { + Status string `json:"status"` + Digest string `json:"digest"` + Total int64 `json:"total"` +} diff --git a/client/ollama/ollama.go b/client/ollama/ollama.go new file mode 100644 index 00000000..2c89c4cd --- /dev/null +++ b/client/ollama/ollama.go @@ -0,0 +1,70 @@ +package ollama + +import ( + "cmp" + "context" + "io/fs" + "iter" + "os" + + "bllamo.com/client/ollama/apitype" + "bllamo.com/oweb" + "bllamo.com/types/empty" +) + +// TODO(bmizerany): PROGRESS INDICATORS!!!! + +const DefaultBaseURL = "http://localhost:11434" + +var envBaseURL = cmp.Or(os.Getenv("OLLAMA_BASE_URL"), DefaultBaseURL) + +// Default returns a new client with the default base URL. +func Default() *Client { + return &Client{BaseURL: envBaseURL} +} + +// I_Acknowledge_This_API_Is_Under_Development is a flag that must be set to +// true for any instance of Client to work. +var I_Acknowledge_This_API_Is_Under_Development bool + +// Client is a client for the Ollama API. +type Client struct { + // BaseURL is the base URL of the Ollama API. + BaseURL string +} + +// Build requests the remote Ollama service to build a model. It uploads any +// source files the server needs. +func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source fs.FS) error { + panic("not implemented") +} + +// Push requests the remote Ollama service to push a model to the server. 
+func (c *Client) Push(ctx context.Context, ref string) error { + _, err := oweb.Do[empty.Message](ctx, "POST", c.BaseURL+"/v1/push", apitype.PushRequest{Name: ref}) + return err +} + +func (c *Client) Pull(ctx context.Context, ref string) error { + panic("not implemented") +} + +func (c *Client) List(ctx context.Context) iter.Seq2[apitype.Model, error] { + panic("not implemented") +} + +func (c *Client) Show(ctx context.Context, ref string) (*apitype.Model, error) { + panic("not implemented") +} + +func (c *Client) Remove(ctx context.Context, ref string) error { + panic("not implemented") +} + +func (c *Client) Copy(ctx context.Context, dstRef, srcRef string) error { + panic("not implemented") +} + +func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message) error { + panic("not implemented") +} diff --git a/cmd/bllamo/bllamo.go b/cmd/bllamo/bllamo.go new file mode 100644 index 00000000..32dadfb7 --- /dev/null +++ b/cmd/bllamo/bllamo.go @@ -0,0 +1,100 @@ +// Bllamo is a (new) tool for managing Ollama models. 
+// +// Usage: +// +// bllamo [arguments] +// +// The commands are: +// +// build build a model from a Modelfile +// list list all models +// push push a model from an ollama registry +// pull pull a model from an ollama registry +// delete delete a model from an ollama registry +// help display help for a command +package main + +import ( + "cmp" + "context" + "flag" + "fmt" + "net/http" + "os" + + "bllamo.com/api" + "bllamo.com/build" + "bllamo.com/client/ollama" + "bllamo.com/registry" +) + +func main() { + flag.Parse() + args := flag.Args() + if len(args) < 1 { + fmt.Fprintln(os.Stderr, "bllamo: no command provided") + os.Exit(2) + } + if err := Main(flag.Args()...); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } +} + +var TODOUsage = fmt.Errorf("TODO: usage") + +var commands = map[string]func(ctx context.Context, args ...string) error{ + "build": cmdBuild, + "push": cmdPush, + "serve": cmdServe, + "registry": cmdRegistry, +} + +// Main is the entry point for the blammo command. +func Main(args ...string) error { + cmd := args[0] + args = args[1:] + if f, ok := commands[cmd]; ok { + ctx := context.TODO() + return f(ctx, args...) 
+ } + return fmt.Errorf("blammo: unknown command %q", cmd) +} + +func cmdBuild(ctx context.Context, args ...string) error { + var v struct { + Modelfile string `flag:"f,the Modelfile to use"` + } + + fs := readFlags("build", args, &v) + if fs.NArg() != 1 { + return TODOUsage + } + + modelfile, err := os.ReadFile(cmp.Or(v.Modelfile, "Modelfile")) + if err != nil { + return err + } + return ollama.Default().Build(ctx, args[0], modelfile, os.DirFS(".")) +} + +func cmdRegistry(_ context.Context, _ ...string) error { + var s registry.Server + return http.ListenAndServe(":8888", &s) +} + +func cmdServe(ctx context.Context, args ...string) error { + bs, err := build.Open("") + if err != nil { + return err + } + return http.ListenAndServe(":11434", &api.Server{Build: bs}) +} + +func cmdPush(ctx context.Context, args ...string) error { + fs := readFlags("push", args, nil) + if fs.NArg() != 1 { + return TODOUsage + } + return ollama.Default().Push(ctx, fs.Arg(0)) +} diff --git a/cmd/bllamo/flags.go b/cmd/bllamo/flags.go new file mode 100644 index 00000000..a781c7ce --- /dev/null +++ b/cmd/bllamo/flags.go @@ -0,0 +1,59 @@ +package main + +import ( + "flag" + "fmt" + "reflect" + "strings" +) + +// parseArgs parses the provided args using a flag.FlagSet that is +// dynamically build using reflection for the provided type. The type fields +// that have a "flag" tag are used to build the flags. The flag tag should +// include either a ('-'). 
Example usage: +// +// func main() { +// var flags struct { +// Modelfile string `flag:"f,path to the Modelfile"` +// } +// +// fs := readFlags(os.Args[1:], &flags) +// fs.Parse(os.Args[1:]) +// } +func readFlags(name string, args []string, v any) *flag.FlagSet { + fs := flag.NewFlagSet(name, flag.ExitOnError) + defer fs.Parse(args) + if v == nil { + return fs + } + + for i := 0; i < reflect.ValueOf(v).NumField(); i++ { + f := reflect.ValueOf(v).Field(i) + if !f.CanSet() { + continue + } + + tag := f.Type().Field(i).Tag.Get("flag") + if tag == "" { + continue + } + var name, usage string + if i := strings.Index(tag, ","); i != -1 { + name = tag[:i] + usage = tag[i+1:] + } else { + name = tag + } + + // TODO(bmizerany): add more types as needed + switch f.Kind() { + case reflect.String: + fs.StringVar(f.Addr().Interface().(*string), name, "", usage) + case reflect.Bool: + fs.BoolVar(f.Addr().Interface().(*bool), name, false, usage) + default: + panic(fmt.Sprintf("unsupported type %v", f.Kind())) + } + } + return fs +} diff --git a/cmd/gguf/gguf.go b/cmd/gguf/gguf.go new file mode 100644 index 00000000..0ea8bf55 --- /dev/null +++ b/cmd/gguf/gguf.go @@ -0,0 +1,97 @@ +// Gguf is a tool for learning about GGUF files. 
+// +// Usage: +// +// gguf [flags] +package main + +import ( + "flag" + "fmt" + "io" + "log" + "os" + "text/tabwriter" + + "bllamo.com/encoding/gguf" +) + +func main() { + if err := Main(os.Stdout, os.Args[1:]...); err != nil { + log.Fatal(err) + } +} + +func Main(stdout io.Writer, args ...string) error { + fs := flag.NewFlagSet("gguf", flag.ExitOnError) + flagGPU := fs.Uint64("gpu", 0, "use N bytes of GPU memory (default is 0)") + + fs.Usage = func() { + io.WriteString(stdout, "Gguf is a tool for learning about GGUF files.\n") + io.WriteString(stdout, "\n") + io.WriteString(stdout, "Usage:\n") + io.WriteString(stdout, "\n") + io.WriteString(stdout, "\tgguf [flags] \n") + io.WriteString(stdout, "\n") + var numFlags int + fs.VisitAll(func(*flag.Flag) { numFlags++ }) + if numFlags > 0 { + io.WriteString(stdout, "Flags:\n") + fs.PrintDefaults() + } + } + fs.Parse(args) + + if fs.NArg() != 1 { + fs.Usage() + os.Exit(2) + } + + file := fs.Arg(0) + f, err := os.Open(file) + if err != nil { + log.Fatal(err) + } + defer f.Close() + + g, err := gguf.ReadFile(f) + if err != nil { + log.Fatal(err) + } + + tw := tabwriter.NewWriter(stdout, 0, 2, 2, ' ', 0) + defer tw.Flush() + + fmt.Fprintf(tw, "version:\t%d\n", g.Version()) + + for m, err := range g.Metadata { + if err != nil { + log.Fatal(err) + } + if len(m.Values) > 5 { + fmt.Fprintf(tw, "meta:\t%q: ... 
(%d values)\n", m.Key, len(m.Values)) + } else { + fmt.Fprintf(tw, "meta:\t%q: %v\n", m.Key, m.Values) + } + } + + var i int + var totalLayerBytes uint64 + var offGPU bool + for t, err := range g.Tensors { + if err != nil { + log.Fatal(err) + } + + totalLayerBytes += t.Size + if totalLayerBytes > *flagGPU { + offGPU = true + } + + const msg = "tensor (layer %000d):\t%q\t%s\tdims=%v\toffset=%d\tsize=%d\tonGPU=%v\n" + fmt.Fprintf(tw, msg, i, t.Name, t.Type, t.Dimensions, t.Offset, t.Size, !offGPU) + + i++ + } + return nil +} diff --git a/encoding/gguf/gguf.go b/encoding/gguf/gguf.go new file mode 100644 index 00000000..79e7c98a --- /dev/null +++ b/encoding/gguf/gguf.go @@ -0,0 +1,376 @@ +package gguf + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "os" + "strconv" + "strings" + + "bllamo.com/types/structs" +) + +// TODO(bmizerany): determine a more reasonable value for MaxDimensions. + +// MaxDimensions is the maximum number of dimensions a tensor can have. +const MaxDimensions uint32 = 1e6 + +// Errors +var ( + // ErrBadMagic is returned when the magic bytes at the start of the + // file. This is useful for detecting if the file is not a gguf + // file. 
+ ErrBadMagic = errors.New("gguf: bad magic") + + ErrUnsupportedVersion = errors.New("gguf: unsupported version") + ErrMangled = errors.New("gguf: mangled data") +) + +type Type uint32 + +const ( + TypeF32 Type = 0 + TypeF16 Type = 1 + TypeQ4_0 Type = 2 + TypeQ4_1 Type = 3 + TypeQ5_0 Type = 6 + TypeQ5_1 Type = 7 + TypeQ8_0 Type = 8 + TypeQ8_1 Type = 9 + TypeQ2_K Type = 10 + TypeQ3_K Type = 11 + TypeQ4_K Type = 12 + TypeQ5_K Type = 13 + TypeQ6_K Type = 14 + TypeQ8_K Type = 15 + TypeI8 Type = 16 + TypeI16 Type = 17 + TypeI32 Type = 18 + TypeCount Type = 19 +) + +var typeNames = map[Type]string{ + TypeF32: "F32", + TypeF16: "F16", + TypeQ4_0: "Q4_0", + TypeQ4_1: "Q4_1", + TypeQ5_0: "Q5_0", + TypeQ5_1: "Q5_1", + TypeQ8_0: "Q8_0", + TypeQ8_1: "Q8_1", + TypeQ2_K: "Q2_K", + TypeQ3_K: "Q3_K", + TypeQ4_K: "Q4_K", + TypeQ5_K: "Q5_K", + TypeQ6_K: "Q6_K", + TypeQ8_K: "Q8_K", + TypeI8: "I8", + TypeI16: "I16", + TypeI32: "I32", + TypeCount: "COUNT", +} + +func (t Type) String() string { + if name := typeNames[t]; name != "" { + return name + } + return fmt.Sprintf("(!unknown_type %d!)", t) +} + +// ValueType is the type of a metadata value. 
+type ValueType uint32 + +func (t ValueType) String() string { + if name := metaTypeNames[t]; name != "" { + return name + } + return fmt.Sprintf("(!unknown_value_type %d!)", t) +} + +const ( + ValueTypeUint8 ValueType = 0 + ValueTypeInt8 ValueType = 1 + ValueTypeUint16 ValueType = 2 + ValueTypeInt16 ValueType = 3 + ValueTypeUint32 ValueType = 4 + ValueTypeInt32 ValueType = 5 + ValueTypeFloat32 ValueType = 6 + ValueTypeBool ValueType = 7 + ValueTypeString ValueType = 8 + ValueTypeArray ValueType = 9 + ValueTypeUint64 ValueType = 10 + ValueTypeInt64 ValueType = 11 + ValueTypeFloat64 ValueType = 12 +) + +var metaTypeNames = map[ValueType]string{ + ValueTypeUint8: "uint8", + ValueTypeInt8: "int8", + ValueTypeUint16: "uint16", + ValueTypeInt16: "int16", + ValueTypeUint32: "uint32", + ValueTypeInt32: "int32", + ValueTypeFloat32: "float32", + ValueTypeBool: "bool", + ValueTypeString: "string", + ValueTypeArray: "array", + ValueTypeUint64: "uint64", + ValueTypeInt64: "int64", + ValueTypeFloat64: "float64", +} + +type TensorInfo struct { + Name string + Dimensions []uint64 + Type Type + Offset uint64 + Size uint64 +} + +type MetaValue struct { + Type ValueType + Value []byte +} + +func (v MetaValue) String() string { + var b strings.Builder + b.WriteString(v.Type.String()) + b.WriteString("(") + switch v.Type { + case ValueTypeArray: + b.WriteString("[...]") + case ValueTypeString: + b.WriteString(strconv.Quote(string(v.Value))) + case ValueTypeBool: + if len(v.Value) == 0 { + b.WriteString("(!invalid bool)") + } + switch v.Value[0] { + case 0: + b.WriteString("false") + case 1: + b.WriteString("true") + default: + b.WriteString("!invalid bool") + } + case ValueTypeUint8, ValueTypeInt8, ValueTypeUint16, ValueTypeInt16, ValueTypeUint32, ValueTypeInt32, ValueTypeUint64, ValueTypeInt64, ValueTypeFloat32, ValueTypeFloat64: + var buf [8]byte + if len(v.Value) < 8 { + copy(buf[:], v.Value) + } + fmt.Fprintf(&b, "%v", binary.LittleEndian.Uint64(buf[:])) + default: + 
fmt.Fprintf(&b, "%v", v.Value) + } + b.WriteString(")") + return b.String() +} + +type MetaEntry struct { + Key string + Type ValueType + Values []MetaValue +} + +func (e MetaEntry) String() string { + if len(e.Values) == 0 { + return "" + } + return string(e.Values[0].Value) +} + +func (e MetaEntry) Uint32() uint32 { + if len(e.Values) == 0 { + return 0 + } + return binary.LittleEndian.Uint32(e.Values[0].Value) +} + +func (e MetaEntry) FileType() Type { + if len(e.Values) == 0 { + return TypeCount + } + return Type(e.Uint32()) +} + +func (e MetaEntry) GoString() string { + var b strings.Builder + b.WriteString(e.Key) + b.WriteString(": ") + b.WriteString(e.Type.String()) + b.WriteString("(") + for i, v := range e.Values { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(v.String()) + } + b.WriteString(")") + return b.String() +} + +type Info struct { + _ structs.Incomparable // prevent comparison of Info values so we can change the implementation later + + Version int + FileType Type +} + +func Stat(path string) (Info, error) { + f, err := os.Open(path) + if err != nil { + return Info{}, err + } + defer f.Close() + return StatReader(f) +} + +// StatReader reads the header information from r and returns an Info +// struct with the version and file type. +// +// It returns an error if any. +// +// As a special case, it returns ErrBadMagic if the file does not start with +// the magic bytes. This can be used to detect if the file is not a GGUF +// file. 
+func StatReader(r io.ReadSeeker) (Info, error) { + if _, err := r.Seek(0, 0); err != nil { + return Info{}, err + } + f, err := ReadFile(r) + if err != nil { + return Info{}, err + } + info := Info{Version: f.Version()} + for m, err := range f.Metadata { + if err != nil { + return Info{}, err + } + if m.Key == "general.file_type" { + if m.Type != ValueTypeUint32 { + return Info{}, fmt.Errorf("unexpected type for metadata key %q: %v, want %v", m.Key, m.Type, ValueTypeUint32) + } + info.FileType = m.FileType() + } + } + return info, nil +} + +type File struct { + version uint32 + numMetaValues uint64 + numTensors uint64 + + gr *ggufReader +} + +// ReadFile reads header information from r and returns a File, ready for +// iteration over Metadata and Tensors. +func ReadFile(r io.Reader) (*File, error) { + f, err := readFile(r) + if err != nil { + return nil, err + } + return f, nil +} + +func (f *File) Version() int { + return int(f.version) +} + +// Metadata iterates over the metadata in the file. It must be exhausted +// before calling Tensors. +// +// It is not resumable. +func (f *File) Metadata(yield func(MetaEntry, error) bool) { + var n int + for range f.numMetaValues { + meta, err := f.gr.readMetaEntry() + if err != nil { + err = fmt.Errorf("error reading metadata entry %d: %w", n, err) + yield(MetaEntry{}, err) + return + } + if !yield(meta, nil) { + return + } + n++ + } +} + +// Tensors iterates over the tensors in the file. It must only be called +// after exhausting the metadata iterator. +// +// It is not resumable. +func (f *File) Tensors(yield func(TensorInfo, error) bool) { + var last TensorInfo + for range f.numTensors { + info, err := f.gr.readTensorInfo() + + // If the last tensor had a valid offset, yield it. 
+ // + // NOTE: No tensor should have an offset of 0 because the + // offset is the start of the tensor data which is always + // afer the magic bytes, version, numMetaValues, and + // numTensors, which MUST all be non-zero bytes as per the + // GGUF spec. + if last.Offset > 0 { + if !yield(last, err) { + return + } + } + if err != nil { + yield(TensorInfo{}, err) + return + } + + // Tensor data does not include size, so we need to + // calculate it based on the offset of the previous tensor + // offset to the current. + offset0 := last.Offset + last = info + last.Size = info.Offset - offset0 + } + if last.Offset > 0 { + yield(last, nil) + } +} + +var magicBytes = []byte{0x47, 0x47, 0x55, 0x46} + +func readFile(r io.Reader) (*File, error) { + gr := &ggufReader{r: &reader{r: r}} + magic, err := gr.next(4) + if err != nil { + return nil, errors.Join(err, ErrBadMagic) + } + if !bytes.Equal(magic, magicBytes) { + return nil, ErrBadMagic + } + version, err := gr.readUint32() + if err != nil { + return nil, err + } + if version != 3 { + return nil, fmt.Errorf("%w: %d", ErrUnsupportedVersion, version) + } + numTensors, err := gr.readUint64() + if err != nil { + return nil, err + } + numMetaValues, err := gr.readUint64() + if err != nil { + return nil, err + } + info := &File{ + version: version, + + numMetaValues: numMetaValues, + numTensors: numTensors, + gr: gr, + } + return info, nil +} diff --git a/encoding/gguf/gguf_test.go b/encoding/gguf/gguf_test.go new file mode 100644 index 00000000..2f6aa4f0 --- /dev/null +++ b/encoding/gguf/gguf_test.go @@ -0,0 +1,345 @@ +package gguf + +import ( + "errors" + "io" + "strings" + "testing" + + "kr.dev/diff" +) + +func TestStat(t *testing.T) { + cases := []struct { + name string + data string + wantInfo Info + wantErr error + }{ + { + name: "empty", + wantErr: ErrBadMagic, + }, + { + name: "bad magic", + data: "\xBB\xAA\xDD\x00", + wantErr: ErrBadMagic, + }, + { + name: "bad version", + data: string(magicBytes) + + 
"\x02\x00\x00\x00" + // version + "", + wantErr: ErrUnsupportedVersion, + }, + { + name: "valid general.file_type", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + + // general.file_type key + "\x11\x00\x00\x00\x00\x00\x00\x00" + // key length + "general.file_type" + // key + "\x04\x00\x00\x00" + // type (uint32) + "\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value + "", + wantInfo: Info{ + Version: 3, + FileType: 1, + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + info, err := StatReader(strings.NewReader(tt.data)) + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Fatalf("err = %v; want %q", err, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + diff.Test(t, t.Errorf, info, tt.wantInfo) + }) + } +} + +func TestReadInfo(t *testing.T) { + cases := []struct { + name string + data string + + wantMeta []MetaEntry + wantTensor []TensorInfo + wantReadErr error + wantMetaErr error + wantTensorErr error + wantInfo Info + }{ + { + name: "empty", + wantReadErr: io.ErrUnexpectedEOF, + }, + { + name: "bad magic", + data: "\xBB\xAA\xDD\x00", + wantReadErr: ErrBadMagic, + }, + { + name: "bad version", + data: string(magicBytes) + + "\x02\x00\x00\x00" + // version + "", + wantReadErr: ErrUnsupportedVersion, + }, + { + name: "no metadata or tensors", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "", + wantReadErr: nil, + }, + { + name: "good metadata", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + "\x01\x00\x00\x00\x00\x00\x00\x00" + // key length + "K" + // key + "\x08\x00\x00\x00" + 
// type (string) + "\x02\x00\x00\x00\x00\x00\x00\x00" + // string length + "VV" + // string value + "", + wantMeta: []MetaEntry{ + {Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}}, + }, + }, + { + name: "good metadata with multiple values", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + + // MetaEntry 1 + "\x01\x00\x00\x00\x00\x00\x00\x00" + // key length + "x" + // key + "\x08\x00\x00\x00" + // type (string) + "\x02\x00\x00\x00\x00\x00\x00\x00" + // string length + "XX" + // string value + + // MetaEntry 2 + "\x01\x00\x00\x00\x00\x00\x00\x00" + // key length + "y" + // key + "\x04\x00\x00\x00" + // type (uint32) + "\x99\x88\x77\x66" + // uint32 value + "", + wantMeta: []MetaEntry{ + {Key: "x", Type: ValueTypeString, Values: []MetaValue{{ + Type: ValueTypeString, + Value: []byte("XX"), + }}}, + {Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{ + Type: ValueTypeUint32, + Value: []byte{0x99, 0x88, 0x77, 0x66}, + }}}, + }, + }, + { + name: "negative string length in meta key", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length + "K" + // key + "\x08\x00\x00\x00" + // type (string) + "\x02\x00\x00\x00\x00\x00\x00\x00" + // string length + "VV" + // string value + "", + wantMetaErr: ErrMangled, + }, + + // Tensor tests + { + name: "good tensor", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + + // Tensor 1 + "\x01\x00\x00\x00\x00\x00\x00\x00" + // name length + "t" + + + // dimensions + "\x01\x00\x00\x00" + // dimensions length + "\x01\x00\x00\x00\x00\x00\x00\x00" + 
// dimension[0] + + "\x03\x00\x00\x00" + // type (i8) + "\x00\x01\x00\x00\x00\x00\x00\x00" + // offset + "", + wantTensor: []TensorInfo{ + { + Name: "t", + Dimensions: []uint64{1}, + Type: TypeQ4_1, + Offset: 256, + Size: 256, + }, + }, + }, + { + name: "too many dimensions", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + + // Tensor 1 + "\x01\x00\x00\x00\x00\x00\x00\x00" + // name length + "t" + + + "\x00\x00\x00\x01" + // dimensions length + "", + wantTensorErr: ErrMangled, + }, + { + name: "size computed", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + + // Tensor 1 + "\x01\x00\x00\x00\x00\x00\x00\x00" + // name length + "t" + + "\x01\x00\x00\x00" + // dimensions length + "\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0] + "\x03\x00\x00\x00" + // type (i8) + "\x00\x01\x00\x00\x00\x00\x00\x00" + // offset + + // Tensor 2 + "\x01\x00\x00\x00\x00\x00\x00\x00" + // name length + "t" + + "\x01\x00\x00\x00" + // dimensions length + "\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0] + "\x03\x00\x00\x00" + // type (i8) + "\x00\x03\x00\x00\x00\x00\x00\x00" + // offset + "", + wantTensor: []TensorInfo{ + { + Name: "t", + Dimensions: []uint64{1}, + Type: TypeQ4_1, + Offset: 256, + Size: 256, + }, + { + Name: "t", + Dimensions: []uint64{1}, + Type: TypeQ4_1, + Offset: 768, + Size: 512, + }, + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + f, err := ReadFile(strings.NewReader(tt.data)) + if err != nil { + if !errors.Is(err, tt.wantReadErr) { + t.Fatalf("unexpected ReadFile error: %v", err) + } + return + } + + var got []MetaEntry + for meta, err := range f.Metadata { + if !errors.Is(err, tt.wantMetaErr) { + t.Fatalf("err = %v; want %v", err, ErrMangled) + } + if err 
!= nil { + return + } + got = append(got, meta) + } + diff.Test(t, t.Errorf, got, tt.wantMeta) + + var gotT []TensorInfo + for tinfo, err := range f.Tensors { + if !errors.Is(err, tt.wantTensorErr) { + t.Fatalf("err = %v; want %v", err, tt.wantTensorErr) + } + if err != nil { + return + } + gotT = append(gotT, tinfo) + } + diff.Test(t, t.Errorf, gotT, tt.wantTensor) + }) + } +} + +func FuzzReadInfo(f *testing.F) { + f.Add(string(magicBytes)) + f.Add(string(magicBytes) + + "\x03\x00\x00\x00" + // version + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "") + f.Add(string(magicBytes) + + "\x03\x00\x00\x00" + // version + "\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + "\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "\x01\x00\x00\x00\x00\x00\x00\x00" + // key length + "K" + // key + "\x08\x00\x00\x00" + // type (string) + "\x02\x00\x00\x00\x00\x00\x00\x00" + // string length + "VV" + // string value + "\x01\x00\x00\x00\x00\x00\x00\x00" + // name length + "t" + + "\x01\x00\x00\x00" + // dimensions length + "\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0] + "\x03\x00\x00\x00" + // type (i8) + "\x05\x00\x00\x00\x00\x00\x00\x00" + // offset + "") + + f.Fuzz(func(t *testing.T, data string) { + gf, err := ReadFile(strings.NewReader(data)) + if err != nil { + t.Logf("ReadFile error: %v", err) + t.Skip() + } + for _, err := range gf.Metadata { + if err != nil { + t.Logf("metadata error: %v", err) + t.Skip() + } + } + for tinfo, err := range gf.Tensors { + if err != nil { + t.Logf("tensor error: %v", err) + t.Skip() + } + if tinfo.Offset <= 0 { + t.Logf("invalid tensor offset: %+v", t) + t.Skip() + } + if tinfo.Size <= 0 { + t.Logf("invalid tensor size: %+v", t) + t.Skip() + } + } + }) +} diff --git a/encoding/gguf/ggufio.go b/encoding/gguf/ggufio.go new file mode 100644 index 00000000..5179800b --- /dev/null +++ b/encoding/gguf/ggufio.go @@ -0,0 +1,195 @@ +package gguf + +import ( + "bytes" + 
"encoding/binary" + "errors" + "fmt" + "io" + "iter" +) + +type ggufReader struct { + r *reader + n int +} + +func (r *ggufReader) readMetaEntry() (MetaEntry, error) { + key, err := r.readString() + if err != nil { + return MetaEntry{}, err + } + typ, err := r.readValueType() + if err != nil { + return MetaEntry{}, err + } + var values []MetaValue + for v, err := range r.readMetaValues(typ) { + if err != nil { + err = fmt.Errorf("(key=%q type=%s): %w", key, typ, err) + return MetaEntry{}, err + } + values = append(values, v) + } + return MetaEntry{ + Key: string(key), + Type: typ, + Values: values, + }, nil +} + +func (r *ggufReader) readMetaValue(typ ValueType) (MetaValue, error) { + var value []byte + var err error + switch typ { + case ValueTypeUint8, ValueTypeInt8: + value, err = r.next(1) + case ValueTypeUint16, ValueTypeInt16: + value, err = r.next(2) + case ValueTypeUint32, ValueTypeInt32, ValueTypeFloat32: + value, err = r.next(4) + case ValueTypeUint64, ValueTypeInt64, ValueTypeFloat64: + value, err = r.next(8) + case ValueTypeBool: + value, err = r.next(1) + case ValueTypeString: + value, err = r.readString() + case ValueTypeArray: + err = fmt.Errorf("nested arrays are not supported") + default: + err = fmt.Errorf("unsupported metadata type: %d", typ) + } + if err != nil { + return MetaValue{}, err + } + return MetaValue{ + Type: typ, + Value: bytes.Clone(value), + }, nil +} + +func (r *ggufReader) readMetaValues(typ ValueType) iter.Seq2[MetaValue, error] { + return func(yield func(MetaValue, error) bool) { + if typ == ValueTypeArray { + atyp, err := r.readValueType() + if err != nil { + err = fmt.Errorf("invalid type: %w", err) + yield(MetaValue{}, err) + return + } + n, err := r.readUint64() + if err != nil { + err = fmt.Errorf("invalid length: %w", err) + yield(MetaValue{}, err) + return + } + for i := range n { + v, err := r.readMetaValue(atyp) + if err != nil { + err = fmt.Errorf("invalid entry (type=%s) %d: %w", atyp, i, err) + yield(MetaValue{}, 
err) + return + } + if !yield(v, nil) { + return + } + } + } else { + v, err := r.readMetaValue(typ) + if err != nil { + err = fmt.Errorf("error reading metadata value: %w", err) + yield(MetaValue{}, err) + return + } + yield(v, nil) + } + } +} + +func (r *ggufReader) readValueType() (ValueType, error) { + typ, err := r.readUint32() + return ValueType(typ), err +} + +func (r *ggufReader) readTensorInfo() (TensorInfo, error) { + name, err := r.readString() + if err != nil { + return TensorInfo{}, err + } + + numDimensions, err := r.readUint32() + if err != nil { + return TensorInfo{}, err + } + if numDimensions > MaxDimensions { + return TensorInfo{}, fmt.Errorf("%w: dimensions length (%d) exceeds %d", ErrMangled, numDimensions, MaxDimensions) + } + + dims := make([]uint64, numDimensions) + for i := range dims { + d, err := r.readUint64() + if err != nil { + return TensorInfo{}, err + } + dims[i] = d + } + typ, err := r.readUint32() + if err != nil { + return TensorInfo{}, err + } + offset, err := r.readUint64() + if err != nil { + return TensorInfo{}, err + } + + // TODO(bmizerany): check offset is multiple of ALIGNMENT + return TensorInfo{ + Name: string(name), + Dimensions: dims, + Type: Type(typ), + Offset: offset, + }, nil +} + +func (r *ggufReader) next(n int) ([]byte, error) { + if n < 0 { + return nil, errors.Join(fmt.Errorf("invalid read length: %d", n), ErrMangled) + } + w := r.r.window() + for len(w) < n { + if r.r.extend() == 0 { + return nil, io.ErrUnexpectedEOF + } + w = r.r.window() + } + r.r.release(n) + r.n += n + return w[:n], nil +} + +func (r *ggufReader) readString() ([]byte, error) { + n, err := r.readUint64() + if err != nil { + return nil, err + } + // TODO(bmizerany): limit max string length + return r.next(int(n)) +} + +func (r *ggufReader) readUint32() (uint32, error) { + b, err := r.next(4) + if err != nil { + return 0, err + } + n := binary.LittleEndian.Uint32(b) + return n, nil +} + +func (r *ggufReader) readUint64() (uint64, error) { + 
b, err := r.next(8) + if err != nil { + return 0, err + } + n := binary.LittleEndian.Uint64(b) + return n, nil +} diff --git a/encoding/gguf/reader.go b/encoding/gguf/reader.go new file mode 100644 index 00000000..7dadc469 --- /dev/null +++ b/encoding/gguf/reader.go @@ -0,0 +1,70 @@ +package gguf + +import "io" + +// A reader implements a sliding window over an io.Reader. +type reader struct { + data []byte + offset int + r io.Reader + err error +} + +// release discards n bytes from the front of the window. +func (b *reader) release(n int) { + b.offset += n +} + +// window returns the current window. +// The window is invalidated by calls to release or extend. +func (b *reader) window() []byte { + return b.data[b.offset:] +} + +// tuning constants for byteReader.extend. +const ( + newBufferSize = 8 << 10 + minReadSize = newBufferSize >> 2 +) + +// extend extends the window with data from the underlying reader. +func (b *reader) extend() int { + if b.err != nil { + return 0 + } + + remaining := len(b.data) - b.offset + if remaining == 0 { + b.data = b.data[:0] + b.offset = 0 + } + if cap(b.data)-len(b.data) >= minReadSize { + // nothing to do, enough space exists between len and cap. + } else if cap(b.data)-remaining >= minReadSize { + // buffer has enough space if we move the data to the front. + b.compact() + } else { + // otherwise, we must allocate/extend a new buffer + b.grow() + } + remaining += b.offset + n, err := b.r.Read(b.data[remaining:cap(b.data)]) + // reduce length to the existing plus the data we read. + b.data = b.data[:remaining+n] + b.err = err + return n +} + +// grow grows the buffer, moving the active data to the front. +func (b *reader) grow() { + buf := make([]byte, max(cap(b.data)*2, newBufferSize)) + copy(buf, b.data[b.offset:]) + b.data = buf + b.offset = 0 +} + +// compact moves the active data to the front of the buffer. 
+func (b *reader) compact() { + copy(b.data, b.data[b.offset:]) + b.offset = 0 +} diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491 new file mode 100644 index 00000000..e9bc63cd --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x02\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6 new file mode 100644 index 00000000..161b2fbb --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc b/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc new file mode 100644 index 00000000..e33f4f37 --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xfd\xff\xff\xff\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753 
b/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753 new file mode 100644 index 00000000..42e7b8fd --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00K\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00VV\x01\x00\x00\x00\x00\\x00\\x00\\x00\\x00") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2 new file mode 100644 index 00000000..05177332 --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4 new file mode 100644 index 00000000..50588528 --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2 new file mode 100644 index 00000000..cbc68bd2 --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2 @@ -0,0 +1,2 @@ +go test fuzz v1 
+string("GGUF\x03\x00\x00\x0000000000000000000000000\xe5") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d b/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d new file mode 100644 index 00000000..d6cd186f --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x0000000000\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x000\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x0000\x01\x00\x00\x00\x00\x00\x00\x000\x01\x00\x001\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\a") diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..54c56120 --- /dev/null +++ b/go.mod @@ -0,0 +1,30 @@ +module bllamo.com + +go 1.22.1 + +require kr.dev/diff v0.3.0 + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/compress v1.17.6 // indirect + github.com/klauspost/cpuid/v2 v2.2.6 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/minio/md5-simd v1.1.2 // indirect + github.com/minio/minio-go/v7 v7.0.69 // indirect + github.com/minio/sha256-simd v1.0.1 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect + github.com/rs/xid v1.5.0 // indirect + golang.org/x/crypto v0.19.0 // indirect + golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e // indirect + golang.org/x/net v0.21.0 // indirect + golang.org/x/sync v0.6.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..40dd784c --- /dev/null +++ b/go.sum @@ 
-0,0 +1,63 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= +github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= +github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= 
+github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= +github.com/minio/minio-go/v7 v7.0.69 h1:l8AnsQFyY1xiwa/DaQskY4NXSLA2yrGsW5iD9nRPVS0= +github.com/minio/minio-go/v7 v7.0.69/go.mod h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ= +github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM= +github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= +github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 
+golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e h1:iWVPgObh6F4UDtjBLK51zsy5UHTPLQwCmsNjCsbKhQ0= +golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +kr.dev/diff v0.3.0 h1:o/T8/tkAq9IuRIuFqCupyKPC5iSY3WXpVZ2p6ZK3Emw= +kr.dev/diff v0.3.0/go.mod h1:XiTaLOg2/PD0cmXY7WQXUR8RAF3RwWpqIQEj910J2NY= diff --git a/model/file.go b/model/file.go new file mode 100644 index 00000000..2f112784 --- /dev/null +++ b/model/file.go @@ -0,0 +1,126 @@ +package model + +import ( + "bufio" + "io" + "iter" + "strings" +) + +type Param struct { + Key string + Value string +} + +type Message struct { + Role string + Content string +} + +type File struct { + // From is a required pragma that specifies the source of the model, + // either on disk, or by 
// Pragma is a single directive parsed from a model file: a name followed by
// zero or more whitespace-separated arguments.
type Pragma struct {
	// The pragma name
	Name string

	// Args contains the user-defined arguments for the pragma. If no
	// arguments were provided, it is nil.
	Args []string
}

// Arg returns the i'th argument of the pragma, or the empty string if i is
// out of range (negative, or not less than len(p.Args)). It never panics.
func (p Pragma) Arg(i int) string {
	if i < 0 || i >= len(p.Args) {
		return ""
	}
	return p.Args[i]
}
/dev/null +++ b/oweb/oweb.go @@ -0,0 +1,143 @@ +package oweb + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "strings" +) + +type Error struct { + Status int `json:"-"` + Code string `json:"code"` + Message string `json:"message"` + Field string `json:"field,omitempty"` + RawBody []byte `json:"-"` +} + +func Missing(field string) error { + return &Error{ + Status: 400, + Code: "missing", + Field: field, + Message: fmt.Sprintf("%s is required", field), + } +} + +func Mistake(code, field, message string) error { + return &Error{ + Status: 400, + Code: code, + Field: field, + Message: fmt.Sprintf("%s: %s", field, message), + } +} + +func Fault(code, message string) error { + return &Error{ + Status: 500, + Code: "fault", + Message: message, + } +} + +func (e *Error) Error() string { + var b strings.Builder + b.WriteString("ollama: ") + b.WriteString(e.Code) + if e.Message != "" { + b.WriteString(": ") + b.WriteString(e.Message) + } + return b.String() +} + +// Convinience errors +var ( + ErrNotFound = &Error{Status: 404, Code: "not_found"} + ErrInternal = &Error{Status: 500, Code: "internal_error"} + ErrMethodNotAllowed = &Error{Status: 405, Code: "method_not_allowed"} +) + +type HandlerFunc func(w http.ResponseWriter, r *http.Request) error + +func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) { + if err := h(w, r); err != nil { + // TODO: take a slog.Logger + log.Printf("error: %v", err) + var e *Error + if !errors.As(err, &e) { + e = ErrInternal + } + w.WriteHeader(cmp.Or(e.Status, 400)) + if err := EncodeJSON(w, e); err != nil { + log.Printf("error encoding error: %v", err) + } + } +} + +func DecodeUserJSON[T any](r io.Reader) (*T, error) { + v, err := DecodeJSON[T](r) + var e *json.SyntaxError + if errors.As(err, &e) { + return nil, &Error{Code: "invalid_json", Message: e.Error()} + } + var se *json.UnmarshalTypeError + if errors.As(err, &se) { + return nil, &Error{ + Code: 
// DecodeJSON decodes a single JSON value from r into a new T and returns a
// pointer to it.
//
// The returned pointer is never nil. On a decode error it points to a zero
// T and the error is returned alongside it. A JSON null body also yields a
// zero T (with a nil error) rather than a nil pointer, so callers can
// dereference the result without a nil check.
func DecodeJSON[T any](r io.Reader) (*T, error) {
	// Decode into *T rather than **T: with a **T target, a JSON null sets
	// the inner pointer to nil and the caller receives (nil, nil).
	v := new(T)
	if err := json.NewDecoder(r).Decode(v); err != nil {
		// Hand back a clean zero value rather than a partially
		// populated one.
		return new(T), err
	}
	return v, nil
}
+ Requirements []Requirement `json:"requirements,omitempty"` +} diff --git a/registry/client.go b/registry/client.go new file mode 100644 index 00000000..2336ab19 --- /dev/null +++ b/registry/client.go @@ -0,0 +1,50 @@ +package registry + +import ( + "context" + "encoding/json" + "io" + "net/http" + + "bllamo.com/oweb" +) + +type Client struct { + BaseURL string +} + +// Push pushes a manifest to the server. +func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]Requirement, error) { + // TODO(bmizerany): backoff + v, err := oweb.Do[PushResponse](ctx, "POST", c.BaseURL+"/v1/push/"+ref, struct { + Manifest json.RawMessage `json:"manifest"` + }{manifest}) + if err != nil { + return nil, err + } + return v.Requirements, nil +} + +func PushLayer(ctx context.Context, dstURL string, size int64, file io.Reader) error { + req, err := http.NewRequest("PUT", dstURL, file) + if err != nil { + return err + } + req.ContentLength = size + + res, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != 200 { + e := &oweb.Error{Status: res.StatusCode} + msg, err := io.ReadAll(res.Body) + if err != nil { + return err + } + // TODO(bmizerany): format error message + e.Message = string(msg) + } + return nil +} diff --git a/registry/server.go b/registry/server.go new file mode 100644 index 00000000..2da28c35 --- /dev/null +++ b/registry/server.go @@ -0,0 +1,117 @@ +// Package implements an Ollama registry client and server +package registry + +import ( + "cmp" + "context" + "errors" + "log" + "net/http" + "strings" + "time" + + "bllamo.com/oweb" + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" +) + +type Server struct{} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if err := s.serveHTTP(w, r); err != nil { + log.Printf("error: %v", err) + var e *oweb.Error + if !errors.As(err, &e) { + e = oweb.ErrInternal + } + w.WriteHeader(cmp.Or(e.Status, 
400)) + if err := oweb.EncodeJSON(w, e); err != nil { + log.Printf("error encoding error: %v", err) + } + } +} + +func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { + switch { + case strings.HasPrefix(r.URL.Path, "/v1/push/"): + return s.handlePush(w, r) + case strings.HasPrefix(r.URL.Path, "/v1/pull/"): + return s.handlePull(w, r) + } + return oweb.ErrNotFound +} + +func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { + pr, err := oweb.DecodeUserJSON[PushRequest](r.Body) + if err != nil { + return err + } + + mc, err := minio.New("localhost:9000", &minio.Options{ + Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), + Secure: false, + }) + + // TODO(bmizerany): parallelize + var requirements []Requirement + for _, l := range pr.Manifest.Layers { + if l.Size == 0 { + continue + } + + pushed, err := s.statObject(r.Context(), l.Digest) + if err != nil { + return err + } + if !pushed { + const expires = 1 * time.Hour + signedURL, err := mc.PresignedPutObject(r.Context(), "test", l.Digest, expires) + if err != nil { + return err + } + requirements = append(requirements, Requirement{ + Digest: l.Digest, + Size: l.Size, + + // TODO(bmizerany): use signed+temp urls + URL: signedURL.String(), + }) + } + } + + // TODO(bmizerany): commit to db + // ref, _ := strings.CutPrefix(r.URL.Path, "/v1/push/") + + return oweb.EncodeJSON(w, &PushResponse{Requirements: requirements}) +} + +func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error { + // lookup manifest + panic("TODO") +} + +func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, err error) { + // TODO(bmizerany): hold client on *Server (hack for now) + mc, err := minio.New("localhost:9000", &minio.Options{ + Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), + Secure: false, + }) + if err != nil { + return false, err + } + + // HEAD the object + _, err = mc.StatObject(ctx, "test", digest, 
minio.StatObjectOptions{}) + if err != nil { + if isNoSuchKey(err) { + err = nil + } + return false, err + } + return true, nil +} + +func isNoSuchKey(err error) bool { + var e minio.ErrorResponse + return errors.As(err, &e) && e.Code == "NoSuchKey" +} diff --git a/registry/server_test.go b/registry/server_test.go new file mode 100644 index 00000000..1457a0d8 --- /dev/null +++ b/registry/server_test.go @@ -0,0 +1,99 @@ +package registry + +import ( + "context" + "net/http/httptest" + "os/exec" + "strings" + "testing" + "time" + + "github.com/kr/pretty" + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + "kr.dev/diff" +) + +func TestPush(t *testing.T) { + startMinio(t) + + s := &Server{} + hs := httptest.NewServer(s) + t.Cleanup(hs.Close) + c := &Client{BaseURL: hs.URL} + + manifest := []byte(`{ + "layers": [ + {"digest": "sha256-1", "size": 1}, + {"digest": "sha256-2", "size": 2}, + {"digest": "sha256-3", "size": 3} + ] + }`) + + got, err := c.Push(context.Background(), "x+y", manifest) + if err != nil { + t.Fatal(err) + } + + diff.Test(t, t.Errorf, got, []Requirement{ + {Digest: "sha256-1", Size: 1}, + {Digest: "sha256-2", Size: 2}, + {Digest: "sha256-3", Size: 3}, + }, diff.ZeroFields[Requirement]("URL")) + + for _, r := range got { + body := strings.NewReader(strings.Repeat("x", int(r.Size))) + if err := PushLayer(context.Background(), r.URL, r.Size, body); err != nil { + t.Fatal(err) + } + } + + got, err = c.Push(context.Background(), "x+y", manifest) + if err != nil { + t.Fatal(err) + } + + if len(got) != 0 { + t.Fatalf("unexpected requirements: % #v", pretty.Formatter(got)) + } +} + +func startMinio(t *testing.T) { + t.Helper() + + dir := t.TempDir() + cmd := exec.Command("minio", "server", "--address", "localhost:9000", dir) + + // TODO(bmizerany): wait delay etc... 
+ if err := cmd.Start(); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + cmd.Process.Kill() + if err := cmd.Wait(); err != nil { + t.Log(err) + } + }) + + mc, err := minio.New("localhost:9000", &minio.Options{ + Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), + Secure: false, + }) + if err != nil { + t.Fatal(err) + } + + // wait for server to start + // TODO(bmizerany): use backoff + for { + _, err := mc.ListBuckets(context.Background()) + if err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil { + t.Fatal(err) + } +} diff --git a/types/empty/message.go b/types/empty/message.go new file mode 100644 index 00000000..ab0f1022 --- /dev/null +++ b/types/empty/message.go @@ -0,0 +1,4 @@ +package empty + +// Message is a placeholder type used when encoding json messages. +type Message struct{} diff --git a/types/structs/structs.go b/types/structs/structs.go new file mode 100644 index 00000000..52929ebf --- /dev/null +++ b/types/structs/structs.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package structs contains the Incomparable type. +package structs + +// Incomparable is a zero-width incomparable type. If added as the +// first field in a struct, it marks that struct as not comparable +// (can't do == or be a map key) and usually doesn't add any width to +// the struct (unless the struct has only small fields). +// +// By making a struct incomparable, you can prevent misuse (prevent +// people from using ==), but also you can shrink generated binaries, +// as the compiler can omit equality funcs from the binary. 
+type Incomparable [0]func() diff --git a/types/they/want.go b/types/they/want.go new file mode 100644 index 00000000..4a601cc6 --- /dev/null +++ b/types/they/want.go @@ -0,0 +1,12 @@ +package they + +import ( + "net/http" + "strings" +) + +// Want returns true if the request method is method and the request path +// starts with pathPrefix. +func Want(r *http.Request, method string, pathPrefix string) bool { + return r.Method == method && strings.HasPrefix(r.URL.Path, pathPrefix) +} From 112ffed18988a94336b8e3753d2aae5ec0e998ac Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 09:20:40 -0700 Subject: [PATCH 02/29] oweb: move Error and Do to client/ollama This allows users of the ollama client library to need only import the client/ollama package, rather than the oweb package as well when inspecting errors. --- client/ollama/ollama.go | 78 ++++++++++++++++++++++++++++++++++++++- oweb/oweb.go | 81 ++++++++--------------------------------- registry/client.go | 6 +-- registry/server.go | 5 ++- 4 files changed, 97 insertions(+), 73 deletions(-) diff --git a/client/ollama/ollama.go b/client/ollama/ollama.go index 2c89c4cd..45b82126 100644 --- a/client/ollama/ollama.go +++ b/client/ollama/ollama.go @@ -1,14 +1,18 @@ package ollama import ( + "bytes" "cmp" "context" + "encoding/json" + "io" "io/fs" "iter" + "net/http" "os" + "strings" "bllamo.com/client/ollama/apitype" - "bllamo.com/oweb" "bllamo.com/types/empty" ) @@ -41,7 +45,7 @@ func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source // Push requests the remote Ollama service to push a model to the server. 
func (c *Client) Push(ctx context.Context, ref string) error { - _, err := oweb.Do[empty.Message](ctx, "POST", c.BaseURL+"/v1/push", apitype.PushRequest{Name: ref}) + _, err := Do[empty.Message](ctx, "POST", c.BaseURL+"/v1/push", apitype.PushRequest{Name: ref}) return err } @@ -68,3 +72,73 @@ func (c *Client) Copy(ctx context.Context, dstRef, srcRef string) error { func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message) error { panic("not implemented") } + +type Error struct { + Status int `json:"-"` + Code string `json:"code"` + Message string `json:"message"` + Field string `json:"field,omitempty"` + RawBody []byte `json:"-"` +} + +func (e *Error) Error() string { + var b strings.Builder + b.WriteString("ollama: ") + b.WriteString(e.Code) + if e.Message != "" { + b.WriteString(": ") + b.WriteString(e.Message) + } + return b.String() +} + +func Do[Res any](ctx context.Context, method, urlStr string, in any) (*Res, error) { + var body bytes.Buffer + // TODO(bmizerany): pool and reuse this buffer AND the encoder + if err := encodeJSON(&body, in); err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, method, urlStr, &body) + if err != nil { + return nil, err + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode/100 != 2 { + var b bytes.Buffer + body := io.TeeReader(res.Body, &b) + e, err := decodeJSON[Error](body) + if err != nil { + return nil, err + } + e.RawBody = b.Bytes() + return nil, e + } + + return decodeJSON[Res](res.Body) +} + +// decodeJSON decodes JSON from r into a new value of type T. +// +// NOTE: This is (and encodeJSON) are copies and paste from oweb.go, please +// do not try and consolidate so we can keep ollama/client free from +// dependencies which are moving targets and not pulling enough weight to +// justify their inclusion. 
+func decodeJSON[T any](r io.Reader) (*T, error) { + var v T + if err := json.NewDecoder(r).Decode(&v); err != nil { + return nil, err + } + return &v, nil +} + +// NOTE: see NOTE above decodeJSON +func encodeJSON(w io.Writer, v any) error { + // TODO(bmizerany): pool and reuse encoder + return json.NewEncoder(w).Encode(v) +} diff --git a/oweb/oweb.go b/oweb/oweb.go index a5cb499c..f48e0496 100644 --- a/oweb/oweb.go +++ b/oweb/oweb.go @@ -1,28 +1,19 @@ package oweb import ( - "bytes" "cmp" - "context" "encoding/json" "errors" "fmt" "io" "log" "net/http" - "strings" + + "bllamo.com/client/ollama" ) -type Error struct { - Status int `json:"-"` - Code string `json:"code"` - Message string `json:"message"` - Field string `json:"field,omitempty"` - RawBody []byte `json:"-"` -} - func Missing(field string) error { - return &Error{ + return &ollama.Error{ Status: 400, Code: "missing", Field: field, @@ -31,7 +22,7 @@ func Missing(field string) error { } func Mistake(code, field, message string) error { - return &Error{ + return &ollama.Error{ Status: 400, Code: code, Field: field, @@ -40,29 +31,18 @@ func Mistake(code, field, message string) error { } func Fault(code, message string) error { - return &Error{ + return &ollama.Error{ Status: 500, Code: "fault", Message: message, } } -func (e *Error) Error() string { - var b strings.Builder - b.WriteString("ollama: ") - b.WriteString(e.Code) - if e.Message != "" { - b.WriteString(": ") - b.WriteString(e.Message) - } - return b.String() -} - // Convinience errors var ( - ErrNotFound = &Error{Status: 404, Code: "not_found"} - ErrInternal = &Error{Status: 500, Code: "internal_error"} - ErrMethodNotAllowed = &Error{Status: 405, Code: "method_not_allowed"} + ErrNotFound = &ollama.Error{Status: 404, Code: "not_found"} + ErrInternal = &ollama.Error{Status: 500, Code: "internal_error"} + ErrMethodNotAllowed = &ollama.Error{Status: 405, Code: "method_not_allowed"} ) type HandlerFunc func(w http.ResponseWriter, r *http.Request) error @@ 
-71,12 +51,12 @@ func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) { if err := h(w, r); err != nil { // TODO: take a slog.Logger log.Printf("error: %v", err) - var e *Error - if !errors.As(err, &e) { - e = ErrInternal + var oe *ollama.Error + if !errors.As(err, &oe) { + oe = ErrInternal } - w.WriteHeader(cmp.Or(e.Status, 400)) - if err := EncodeJSON(w, e); err != nil { + w.WriteHeader(cmp.Or(oe.Status, 400)) + if err := EncodeJSON(w, oe); err != nil { log.Printf("error encoding error: %v", err) } } @@ -86,11 +66,11 @@ func DecodeUserJSON[T any](r io.Reader) (*T, error) { v, err := DecodeJSON[T](r) var e *json.SyntaxError if errors.As(err, &e) { - return nil, &Error{Code: "invalid_json", Message: e.Error()} + return nil, &ollama.Error{Code: "invalid_json", Message: e.Error()} } var se *json.UnmarshalTypeError if errors.As(err, &se) { - return nil, &Error{ + return nil, &ollama.Error{ Code: "invalid_json", Message: fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type), } @@ -110,34 +90,3 @@ func DecodeJSON[T any](r io.Reader) (*T, error) { func EncodeJSON(w io.Writer, v any) error { return json.NewEncoder(w).Encode(v) } - -func Do[Res any](ctx context.Context, method, urlStr string, in any) (*Res, error) { - var body bytes.Buffer - if err := EncodeJSON(&body, in); err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, method, urlStr, &body) - if err != nil { - return nil, err - } - - res, err := http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - defer res.Body.Close() - - if res.StatusCode/100 != 2 { - var b bytes.Buffer - body := io.TeeReader(res.Body, &b) - e, err := DecodeJSON[Error](body) - if err != nil { - return nil, err - } - e.RawBody = b.Bytes() - return nil, e - } - - return DecodeJSON[Res](res.Body) -} diff --git a/registry/client.go b/registry/client.go index 2336ab19..b26be554 100644 --- a/registry/client.go +++ b/registry/client.go @@ -6,7 +6,7 @@ import ( "io" "net/http" - 
"bllamo.com/oweb" + "bllamo.com/client/ollama" ) type Client struct { @@ -16,7 +16,7 @@ type Client struct { // Push pushes a manifest to the server. func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]Requirement, error) { // TODO(bmizerany): backoff - v, err := oweb.Do[PushResponse](ctx, "POST", c.BaseURL+"/v1/push/"+ref, struct { + v, err := ollama.Do[PushResponse](ctx, "POST", c.BaseURL+"/v1/push/"+ref, struct { Manifest json.RawMessage `json:"manifest"` }{manifest}) if err != nil { @@ -38,7 +38,7 @@ func PushLayer(ctx context.Context, dstURL string, size int64, file io.Reader) e } defer res.Body.Close() if res.StatusCode != 200 { - e := &oweb.Error{Status: res.StatusCode} + e := &ollama.Error{Status: res.StatusCode} msg, err := io.ReadAll(res.Body) if err != nil { return err diff --git a/registry/server.go b/registry/server.go index 2da28c35..991a1240 100644 --- a/registry/server.go +++ b/registry/server.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "bllamo.com/client/ollama" "bllamo.com/oweb" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" @@ -19,8 +20,8 @@ type Server struct{} func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err := s.serveHTTP(w, r); err != nil { - log.Printf("error: %v", err) - var e *oweb.Error + log.Printf("error: %v", err) // TODO(bmizerany): take a slog.Logger + var e *ollama.Error if !errors.As(err, &e) { e = oweb.ErrInternal } From cd5df121a5c8f0d985f9ec4cab7ffc9164d2840a Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 09:30:01 -0700 Subject: [PATCH 03/29] client: include Status in json Error response for symmetry. Also, remove RawBody from error, which was previously used for debugging. 
--- client/ollama/ollama.go | 8 ++------ oweb/oweb.go | 3 ++- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/client/ollama/ollama.go b/client/ollama/ollama.go index 45b82126..06ee7fa2 100644 --- a/client/ollama/ollama.go +++ b/client/ollama/ollama.go @@ -74,11 +74,10 @@ func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message } type Error struct { - Status int `json:"-"` + Status int `json:"status"` Code string `json:"code"` Message string `json:"message"` Field string `json:"field,omitempty"` - RawBody []byte `json:"-"` } func (e *Error) Error() string { @@ -110,13 +109,10 @@ func Do[Res any](ctx context.Context, method, urlStr string, in any) (*Res, erro defer res.Body.Close() if res.StatusCode/100 != 2 { - var b bytes.Buffer - body := io.TeeReader(res.Body, &b) - e, err := decodeJSON[Error](body) + e, err := decodeJSON[Error](res.Body) if err != nil { return nil, err } - e.RawBody = b.Bytes() return nil, e } diff --git a/oweb/oweb.go b/oweb/oweb.go index f48e0496..352ba0c6 100644 --- a/oweb/oweb.go +++ b/oweb/oweb.go @@ -55,7 +55,8 @@ func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) { if !errors.As(err, &oe) { oe = ErrInternal } - w.WriteHeader(cmp.Or(oe.Status, 400)) + oe.Status = cmp.Or(oe.Status, 400) + w.WriteHeader(oe.Status) if err := EncodeJSON(w, oe); err != nil { log.Printf("error encoding error: %v", err) } From e1d457c73ee80fe9bb4044d5b7106bc6bf4301c7 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 09:34:58 -0700 Subject: [PATCH 04/29] client/ollama: report invalid server error response with raw bytes --- client/ollama/ollama.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/client/ollama/ollama.go b/client/ollama/ollama.go index 06ee7fa2..1101b5b9 100644 --- a/client/ollama/ollama.go +++ b/client/ollama/ollama.go @@ -5,6 +5,7 @@ import ( "cmp" "context" "encoding/json" + "fmt" "io" "io/fs" "iter" @@ -109,8 +110,11 @@ func Do[Res any](ctx 
context.Context, method, urlStr string, in any) (*Res, erro defer res.Body.Close() if res.StatusCode/100 != 2 { - e, err := decodeJSON[Error](res.Body) + var buf bytes.Buffer + body := io.TeeReader(res.Body, &buf) + e, err := decodeJSON[Error](body) if err != nil { + err := fmt.Errorf("ollama: invalid error response from server (status %d): %q", res.StatusCode, buf.String()) return nil, err } return nil, e From f6e02d4bc7a5b5a33d6a4e0809598a8b68bb0a05 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 10:52:56 -0700 Subject: [PATCH 05/29] client/ollama: Do take a *Client --- client/ollama/ollama.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/client/ollama/ollama.go b/client/ollama/ollama.go index 1101b5b9..631d9417 100644 --- a/client/ollama/ollama.go +++ b/client/ollama/ollama.go @@ -46,7 +46,7 @@ func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source // Push requests the remote Ollama service to push a model to the server. 
func (c *Client) Push(ctx context.Context, ref string) error { - _, err := Do[empty.Message](ctx, "POST", c.BaseURL+"/v1/push", apitype.PushRequest{Name: ref}) + _, err := Do[empty.Message](ctx, c, "POST", "/v1/push", apitype.PushRequest{Name: ref}) return err } @@ -92,12 +92,13 @@ func (e *Error) Error() string { return b.String() } -func Do[Res any](ctx context.Context, method, urlStr string, in any) (*Res, error) { +func Do[Res any](ctx context.Context, c *Client, method, path string, in any) (*Res, error) { var body bytes.Buffer // TODO(bmizerany): pool and reuse this buffer AND the encoder if err := encodeJSON(&body, in); err != nil { return nil, err } + urlStr := c.BaseURL + path req, err := http.NewRequestWithContext(ctx, method, urlStr, &body) if err != nil { return nil, err From 6acc205de045527b83d12c0aa1d2bb598de067fa Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 10:54:17 -0700 Subject: [PATCH 06/29] client/ollama: install and use (*Client).HTTPClient --- client/ollama/ollama.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client/ollama/ollama.go b/client/ollama/ollama.go index 631d9417..b491cda5 100644 --- a/client/ollama/ollama.go +++ b/client/ollama/ollama.go @@ -36,6 +36,8 @@ var I_Acknowledge_This_API_Is_Under_Development bool type Client struct { // BaseURL is the base URL of the Ollama API. BaseURL string + + HTTPClient *http.Client // The HTTP client to use. If nil, http.DefaultClient is used. } // Build requests the remote Ollama service to build a model. 
It uploads any @@ -104,7 +106,8 @@ func Do[Res any](ctx context.Context, c *Client, method, path string, in any) (* return nil, err } - res, err := http.DefaultClient.Do(req) + hc := cmp.Or(c.HTTPClient, http.DefaultClient) + res, err := hc.Do(req) if err != nil { return nil, err } From a32e7857b221e2eba0f98fd1fa47ed24ab8c02c4 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 11:00:07 -0700 Subject: [PATCH 07/29] client/ollama: docs for Error type --- client/ollama/ollama.go | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/client/ollama/ollama.go b/client/ollama/ollama.go index b491cda5..6e8b6f94 100644 --- a/client/ollama/ollama.go +++ b/client/ollama/ollama.go @@ -77,10 +77,21 @@ func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message } type Error struct { - Status int `json:"status"` - Code string `json:"code"` + // Status is the HTTP status code returned by the server. + Status int `json:"status"` + + // Code specifies a machine readable code indicating the class of + // error this error is. See http://docs.ollama.com/errors for a full + // list of error codes. + Code string `json:"code"` + + // Message is a humage readable message that describes the error. It + // may change across versions of the API, so it should not be used for + // programmatic decisions. Message string `json:"message"` - Field string `json:"field,omitempty"` + + // Field is the field in the request that caused the error, if any. 
+ Field string `json:"field,omitempty"` } func (e *Error) Error() string { From 5182a1dfb1407d8e47b6fb19e6b10b66baec1e14 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 11:04:20 -0700 Subject: [PATCH 08/29] client/ollama: document Do --- client/ollama/ollama.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/ollama/ollama.go b/client/ollama/ollama.go index 6e8b6f94..9edfc73c 100644 --- a/client/ollama/ollama.go +++ b/client/ollama/ollama.go @@ -105,6 +105,9 @@ func (e *Error) Error() string { return b.String() } +// Do encodes in and sends it in a request to the Ollama server and decodes +// the response into Res, or an error response (non-2xx) into an *Error, or +// any error encountered decoding the response. func Do[Res any](ctx context.Context, c *Client, method, path string, in any) (*Res, error) { var body bytes.Buffer // TODO(bmizerany): pool and reuse this buffer AND the encoder From c87fe7df48a766e2a0ed74d54a2cffb4ea46623c Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 11:12:50 -0700 Subject: [PATCH 09/29] client/ollama: make Error.Message optional --- client/ollama/ollama.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/ollama/ollama.go b/client/ollama/ollama.go index 9edfc73c..c87a5656 100644 --- a/client/ollama/ollama.go +++ b/client/ollama/ollama.go @@ -88,7 +88,7 @@ type Error struct { // Message is a humage readable message that describes the error. It // may change across versions of the API, so it should not be used for // programmatic decisions. - Message string `json:"message"` + Message string `json:"message,omitempty"` // Field is the field in the request that caused the error, if any. 
Field string `json:"field,omitempty"` From eb2c442a015741fc37af6d22109dd3adea594105 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 11:36:51 -0700 Subject: [PATCH 10/29] oweb: make DecodeUserJSON take a field name This allows for better error messages when decoding fails. For example, instead of: {"code":"invalid_json","message":"unexpected end of JSON input"} We now get: {"code":"invalid_json","field":"manifest","message":"unexpected end of JSON input"} --- api/api.go | 4 ++++ oweb/oweb.go | 15 ++++++++------- registry/apitypes.go | 5 ++++- registry/client.go | 15 ++++++++++----- registry/server.go | 14 ++++++++++---- 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/api/api.go b/api/api.go index fd525937..d8fad5dd 100644 --- a/api/api.go +++ b/api/api.go @@ -104,3 +104,7 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { return err } + +func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error { + return oweb.ErrNotFound +} diff --git a/oweb/oweb.go b/oweb/oweb.go index 352ba0c6..6c586c14 100644 --- a/oweb/oweb.go +++ b/oweb/oweb.go @@ -63,20 +63,21 @@ func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) { } } -func DecodeUserJSON[T any](r io.Reader) (*T, error) { +func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) { v, err := DecodeJSON[T](r) + if err == nil { + return v, nil + } + var msg string var e *json.SyntaxError if errors.As(err, &e) { - return nil, &ollama.Error{Code: "invalid_json", Message: e.Error()} + msg = e.Error() } var se *json.UnmarshalTypeError if errors.As(err, &se) { - return nil, &ollama.Error{ - Code: "invalid_json", - Message: fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type), - } + msg = fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type) } - return v, err + return nil, Mistake("invalid_json", field, msg) } func DecodeJSON[T any](r io.Reader) (*T, error) { diff --git a/registry/apitypes.go 
b/registry/apitypes.go index dc3e22d9..599bfe9a 100644 --- a/registry/apitypes.go +++ b/registry/apitypes.go @@ -1,5 +1,7 @@ package registry +import "encoding/json" + type Manifest struct { Layers []Layer `json:"layers"` } @@ -11,7 +13,8 @@ type Layer struct { } type PushRequest struct { - Manifest Manifest `json:"manifest"` + Ref string `json:"ref"` + Manifest json.RawMessage } type Requirement struct { diff --git a/registry/client.go b/registry/client.go index b26be554..1e3e9c88 100644 --- a/registry/client.go +++ b/registry/client.go @@ -2,7 +2,6 @@ package registry import ( "context" - "encoding/json" "io" "net/http" @@ -10,15 +9,21 @@ import ( ) type Client struct { - BaseURL string + BaseURL string + HTTPClient *http.Client +} + +func (c *Client) oclient() *ollama.Client { + return (*ollama.Client)(c) } // Push pushes a manifest to the server. func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]Requirement, error) { // TODO(bmizerany): backoff - v, err := ollama.Do[PushResponse](ctx, "POST", c.BaseURL+"/v1/push/"+ref, struct { - Manifest json.RawMessage `json:"manifest"` - }{manifest}) + v, err := ollama.Do[PushResponse](ctx, c.oclient(), "POST", "/v1/push", &PushRequest{ + Ref: ref, + Manifest: manifest, + }) if err != nil { return nil, err } diff --git a/registry/server.go b/registry/server.go index 991a1240..f1d03cc5 100644 --- a/registry/server.go +++ b/registry/server.go @@ -2,6 +2,7 @@ package registry import ( + "bytes" "cmp" "context" "errors" @@ -34,16 +35,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { switch { - case strings.HasPrefix(r.URL.Path, "/v1/push/"): + case strings.HasPrefix(r.URL.Path, "/v1/push"): return s.handlePush(w, r) - case strings.HasPrefix(r.URL.Path, "/v1/pull/"): + case strings.HasPrefix(r.URL.Path, "/v1/pull"): return s.handlePull(w, r) } return oweb.ErrNotFound } func (s *Server) handlePush(w 
http.ResponseWriter, r *http.Request) error { - pr, err := oweb.DecodeUserJSON[PushRequest](r.Body) + pr, err := oweb.DecodeUserJSON[PushRequest]("", r.Body) if err != nil { return err } @@ -53,9 +54,14 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { Secure: false, }) + m, err := oweb.DecodeUserJSON[Manifest]("manifest", bytes.NewReader(pr.Manifest)) + if err != nil { + return err + } + // TODO(bmizerany): parallelize var requirements []Requirement - for _, l := range pr.Manifest.Layers { + for _, l := range m.Layers { if l.Size == 0 { continue } From 48c60c01e2186c061ab1a55842c9087e310fe299 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 12:23:10 -0700 Subject: [PATCH 11/29] registry: move req/resp types to registry/apitype --- registry/{apitypes.go => apitype/apitype.go} | 2 +- registry/client.go | 5 +++-- registry/server.go | 11 ++++++----- registry/server_test.go | 5 +++-- 4 files changed, 13 insertions(+), 10 deletions(-) rename registry/{apitypes.go => apitype/apitype.go} (97%) diff --git a/registry/apitypes.go b/registry/apitype/apitype.go similarity index 97% rename from registry/apitypes.go rename to registry/apitype/apitype.go index 599bfe9a..e33dfe34 100644 --- a/registry/apitypes.go +++ b/registry/apitype/apitype.go @@ -1,4 +1,4 @@ -package registry +package apitype import "encoding/json" diff --git a/registry/client.go b/registry/client.go index 1e3e9c88..82616380 100644 --- a/registry/client.go +++ b/registry/client.go @@ -6,6 +6,7 @@ import ( "net/http" "bllamo.com/client/ollama" + "bllamo.com/registry/apitype" ) type Client struct { @@ -18,9 +19,9 @@ func (c *Client) oclient() *ollama.Client { } // Push pushes a manifest to the server. 
-func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]Requirement, error) { +func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]apitype.Requirement, error) { // TODO(bmizerany): backoff - v, err := ollama.Do[PushResponse](ctx, c.oclient(), "POST", "/v1/push", &PushRequest{ + v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{ Ref: ref, Manifest: manifest, }) diff --git a/registry/server.go b/registry/server.go index f1d03cc5..a12e210c 100644 --- a/registry/server.go +++ b/registry/server.go @@ -13,6 +13,7 @@ import ( "bllamo.com/client/ollama" "bllamo.com/oweb" + "bllamo.com/registry/apitype" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" ) @@ -44,7 +45,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { } func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { - pr, err := oweb.DecodeUserJSON[PushRequest]("", r.Body) + pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body) if err != nil { return err } @@ -54,13 +55,13 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { Secure: false, }) - m, err := oweb.DecodeUserJSON[Manifest]("manifest", bytes.NewReader(pr.Manifest)) + m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest)) if err != nil { return err } // TODO(bmizerany): parallelize - var requirements []Requirement + var requirements []apitype.Requirement for _, l := range m.Layers { if l.Size == 0 { continue @@ -76,7 +77,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - requirements = append(requirements, Requirement{ + requirements = append(requirements, apitype.Requirement{ Digest: l.Digest, Size: l.Size, @@ -89,7 +90,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { // TODO(bmizerany): commit to db // ref, _ := 
strings.CutPrefix(r.URL.Path, "/v1/push/") - return oweb.EncodeJSON(w, &PushResponse{Requirements: requirements}) + return oweb.EncodeJSON(w, &apitype.PushResponse{Requirements: requirements}) } func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error { diff --git a/registry/server_test.go b/registry/server_test.go index 1457a0d8..79c1875b 100644 --- a/registry/server_test.go +++ b/registry/server_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "bllamo.com/registry/apitype" "github.com/kr/pretty" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" @@ -35,11 +36,11 @@ func TestPush(t *testing.T) { t.Fatal(err) } - diff.Test(t, t.Errorf, got, []Requirement{ + diff.Test(t, t.Errorf, got, []apitype.Requirement{ {Digest: "sha256-1", Size: 1}, {Digest: "sha256-2", Size: 2}, {Digest: "sha256-3", Size: 3}, - }, diff.ZeroFields[Requirement]("URL")) + }, diff.ZeroFields[apitype.Requirement]("URL")) for _, r := range got { body := strings.NewReader(strings.Repeat("x", int(r.Size))) From 60ef0e6b4a4ca13be78eb2b0ef9cacbaed87f396 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 15:00:25 -0700 Subject: [PATCH 12/29] oweb: remove Fault Also, fix typo in the comment. 
--- oweb/oweb.go | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/oweb/oweb.go b/oweb/oweb.go index 6c586c14..21da688f 100644 --- a/oweb/oweb.go +++ b/oweb/oweb.go @@ -30,15 +30,7 @@ func Mistake(code, field, message string) error { } } -func Fault(code, message string) error { - return &ollama.Error{ - Status: 500, - Code: "fault", - Message: message, - } -} - -// Convinience errors +// Convenience errors var ( ErrNotFound = &ollama.Error{Status: 404, Code: "not_found"} ErrInternal = &ollama.Error{Status: 500, Code: "internal_error"} From c0eddb10fd98714583b2197dbd9f8ce84f328241 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 15:01:26 -0700 Subject: [PATCH 13/29] registry: use exact match on path --- registry/server.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/registry/server.go b/registry/server.go index a12e210c..60605fc2 100644 --- a/registry/server.go +++ b/registry/server.go @@ -8,7 +8,6 @@ import ( "errors" "log" "net/http" - "strings" "time" "bllamo.com/client/ollama" @@ -35,13 +34,14 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { - switch { - case strings.HasPrefix(r.URL.Path, "/v1/push"): + switch r.URL.Path { + case "/v1/push": return s.handlePush(w, r) - case strings.HasPrefix(r.URL.Path, "/v1/pull"): + case "/v1/pull": return s.handlePull(w, r) + default: + return oweb.ErrNotFound } - return oweb.ErrNotFound } func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { From 04f38cf3f4b4f730c0e814b2b7c2e9af0f25f45b Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 15:09:04 -0700 Subject: [PATCH 14/29] registry: commit manifest on successful /v1/push --- build/build.go | 5 +++++ registry/server.go | 17 +++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/build/build.go b/build/build.go index f66138db..21491102 
100644 --- a/build/build.go +++ b/build/build.go @@ -6,6 +6,7 @@ import ( "fmt" "io/fs" "os" + "path" "bllamo.com/build/blob" "bllamo.com/build/internal/blobstore" @@ -21,6 +22,10 @@ var ( ErrNotFound = errors.New("not found") ) +func ManifestKey(domain string, ref blob.Ref) string { + return path.Join("manifests", domain, ref.Name(), ref.Tag(), ref.Build()) +} + type mediaType string // Known media types diff --git a/registry/server.go b/registry/server.go index 60605fc2..b22042bb 100644 --- a/registry/server.go +++ b/registry/server.go @@ -10,6 +10,8 @@ import ( "net/http" "time" + "bllamo.com/build" + "bllamo.com/build/blob" "bllamo.com/client/ollama" "bllamo.com/oweb" "bllamo.com/registry/apitype" @@ -50,6 +52,11 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { return err } + ref := blob.ParseRef(pr.Ref) + if !ref.FullyQualified() { + return oweb.Mistake("invalid", "name", "must be fully qualified") + } + mc, err := minio.New("localhost:9000", &minio.Options{ Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), Secure: false, @@ -87,8 +94,14 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { } } - // TODO(bmizerany): commit to db - // ref, _ := strings.CutPrefix(r.URL.Path, "/v1/push/") + if len(requirements) == 0 { + const cheatTODO = "registry.ollama.ai/library" + key := build.ManifestKey(cheatTODO, ref) + _, err := mc.PutObject(r.Context(), "test", key, bytes.NewReader(pr.Manifest), int64(len(pr.Manifest)), minio.PutObjectOptions{}) + if err != nil { + return err + } + } return oweb.EncodeJSON(w, &apitype.PushResponse{Requirements: requirements}) } From fd411b3cf685235a04115b47898a296f11ce4050 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 18:20:19 -0700 Subject: [PATCH 15/29] registry: commit Manifest --- build/blob/ref.go | 42 ++++++++++++++++++++--------- build/blob/ref_test.go | 34 ++++++++++++------------ build/build.go | 5 ---- build/internal/blobstore/blob.go | 
7 +---- registry/server.go | 25 +++++++++++++----- registry/server_test.go | 45 ++++++++++++++++++++++++++++++++ 6 files changed, 111 insertions(+), 47 deletions(-) diff --git a/build/blob/ref.go b/build/blob/ref.go index 9a033fcb..dc7bdcef 100644 --- a/build/blob/ref.go +++ b/build/blob/ref.go @@ -2,6 +2,8 @@ package blob import ( "cmp" + "path" + "path/filepath" "strings" ) @@ -11,26 +13,31 @@ import ( // // Users or Ref must check Valid before using it. type Ref struct { - name string - tag string - build string + domain string + name string + tag string + build string } // WithBuild returns a copy of r with the provided build. If the provided // build is empty, it returns the short, unqualified copy of r. func (r Ref) WithBuild(build string) Ref { if build == "" { - return Ref{r.name, r.tag, ""} + return Ref{r.domain, r.name, r.tag, ""} } if !isValidPart(build) { return Ref{} } - return makeRef(r.name, r.tag, build) + return makeRef(r.domain, r.name, r.tag, build) } // String returns the fully qualified ref string. func (r Ref) String() string { var b strings.Builder + if r.domain != "" { + b.WriteString(r.domain) + b.WriteString("/") + } b.WriteString(r.name) if r.tag != "" { b.WriteString(":") @@ -49,7 +56,7 @@ func (r Ref) Full() string { if !r.Valid() { return "" } - return makeRef(r.name, r.tag, cmp.Or(r.build, "!(MISSING BUILD)")).String() + return makeRef(r.domain, r.name, r.tag, cmp.Or(r.build, "!(MISSING BUILD)")).String() } // Short returns the short ref string which does not include the build. 
@@ -65,9 +72,18 @@ func (r Ref) FullyQualified() bool { return r.name != "" && r.tag != "" && r.build != "" } -func (r Ref) Name() string { return r.name } -func (r Ref) Tag() string { return r.tag } -func (r Ref) Build() string { return r.build } +func (r Ref) Path() string { + return path.Join(r.domain, r.name, r.tag, r.build) +} + +func (r Ref) Filepath() string { + return filepath.Join(r.domain, r.name, r.tag, r.build) +} + +func (r Ref) Domain() string { return r.domain } +func (r Ref) Name() string { return r.name } +func (r Ref) Tag() string { return r.tag } +func (r Ref) Build() string { return r.build } // ParseRef parses a ref string into a Ref. A ref string is a name, an // optional tag, and an optional build, separated by colons and pluses. @@ -107,12 +123,14 @@ func ParseRef(s string) Ref { if expectBuild && !isValidPart(build) { return Ref{} } - return makeRef(name, tag, build) + + const TODO = "registry.ollama.ai" + return makeRef(TODO, name, tag, build) } // makeRef makes a ref, skipping validation. 
-func makeRef(name, tag, build string) Ref { - return Ref{name, cmp.Or(tag, "latest"), strings.ToUpper(build)} +func makeRef(domain, name, tag, build string) Ref { + return Ref{domain, name, cmp.Or(tag, "latest"), strings.ToUpper(build)} } // isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-] diff --git a/build/blob/ref_test.go b/build/blob/ref_test.go index b49d39df..1322022f 100644 --- a/build/blob/ref_test.go +++ b/build/blob/ref_test.go @@ -12,24 +12,24 @@ func TestParseRef(t *testing.T) { in string want Ref }{ - {"mistral:latest", Ref{"mistral", "latest", ""}}, - {"mistral", Ref{"mistral", "latest", ""}}, - {"mistral:30B", Ref{"mistral", "30B", ""}}, - {"mistral:7b", Ref{"mistral", "7b", ""}}, - {"mistral:7b+Q4_0", Ref{"mistral", "7b", "Q4_0"}}, - {"mistral+KQED", Ref{"mistral", "latest", "KQED"}}, - {"mistral.x-3:7b+Q4_0", Ref{"mistral.x-3", "7b", "Q4_0"}}, + {"mistral:latest", Ref{"registry.ollama.ai", "mistral", "latest", ""}}, + {"mistral", Ref{"registry.ollama.ai", "mistral", "latest", ""}}, + {"mistral:30B", Ref{"registry.ollama.ai", "mistral", "30B", ""}}, + {"mistral:7b", Ref{"registry.ollama.ai", "mistral", "7b", ""}}, + {"mistral:7b+Q4_0", Ref{"registry.ollama.ai", "mistral", "7b", "Q4_0"}}, + {"mistral+KQED", Ref{"registry.ollama.ai", "mistral", "latest", "KQED"}}, + {"mistral.x-3:7b+Q4_0", Ref{"registry.ollama.ai", "mistral.x-3", "7b", "Q4_0"}}, // lowecase build - {"mistral:7b+q4_0", Ref{"mistral", "7b", "Q4_0"}}, + {"mistral:7b+q4_0", Ref{"registry.ollama.ai", "mistral", "7b", "Q4_0"}}, // Invalid - {"mistral:7b+Q4_0:latest", Ref{"", "", ""}}, - {"mi tral", Ref{"", "", ""}}, - {"llama2:+", Ref{"", "", ""}}, + {"mistral:7b+Q4_0:latest", Ref{"", "", "", ""}}, + {"mi tral", Ref{"", "", "", ""}}, + {"llama2:+", Ref{"", "", "", ""}}, // too long - {refTooLong, Ref{"", "", ""}}, + {refTooLong, Ref{"", "", "", ""}}, } for _, tt := range cases { t.Run(tt.in, func(t *testing.T) { @@ -48,11 +48,11 @@ func TestRefFull(t *testing.T) { 
wantFull string }{ {"", "", ""}, - {"mistral:7b+x", "mistral:7b", "mistral:7b+X"}, - {"mistral:7b+Q4_0", "mistral:7b", "mistral:7b+Q4_0"}, - {"mistral:latest", "mistral:latest", "mistral:latest+!(MISSING BUILD)"}, - {"mistral", "mistral:latest", "mistral:latest+!(MISSING BUILD)"}, - {"mistral:30b", "mistral:30b", "mistral:30b+!(MISSING BUILD)"}, + {"mistral:7b+x", "registry.ollama.ai/mistral:7b", "registry.ollama.ai/mistral:7b+X"}, + {"mistral:7b+Q4_0", "registry.ollama.ai/mistral:7b", "registry.ollama.ai/mistral:7b+Q4_0"}, + {"mistral:latest", "registry.ollama.ai/mistral:latest", "registry.ollama.ai/mistral:latest+!(MISSING BUILD)"}, + {"mistral", "registry.ollama.ai/mistral:latest", "registry.ollama.ai/mistral:latest+!(MISSING BUILD)"}, + {"mistral:30b", "registry.ollama.ai/mistral:30b", "registry.ollama.ai/mistral:30b+!(MISSING BUILD)"}, } for _, tt := range cases { diff --git a/build/build.go b/build/build.go index 21491102..f66138db 100644 --- a/build/build.go +++ b/build/build.go @@ -6,7 +6,6 @@ import ( "fmt" "io/fs" "os" - "path" "bllamo.com/build/blob" "bllamo.com/build/internal/blobstore" @@ -22,10 +21,6 @@ var ( ErrNotFound = errors.New("not found") ) -func ManifestKey(domain string, ref blob.Ref) string { - return path.Join("manifests", domain, ref.Name(), ref.Tag(), ref.Build()) -} - type mediaType string // Known media types diff --git a/build/internal/blobstore/blob.go b/build/internal/blobstore/blob.go index e5981416..3c29538a 100644 --- a/build/internal/blobstore/blob.go +++ b/build/internal/blobstore/blob.go @@ -1,7 +1,3 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - // Package blobstore implements a blob store. 
package blobstore @@ -228,8 +224,7 @@ func (s *Store) refFileName(ref blob.Ref) (string, error) { if !ref.FullyQualified() { return "", fmt.Errorf("ref not fully qualified: %q", ref) } - const cheatTODO = "registry.ollama.ai/library" - return filepath.Join(s.dir, "manifests", cheatTODO, ref.Name(), ref.Tag(), ref.Build()), nil + return filepath.Join(s.dir, "manifests", ref.Domain(), ref.Name(), ref.Tag(), ref.Build()), nil } // Get looks up the blob ID in the store, diff --git a/registry/server.go b/registry/server.go index b22042bb..f22d0b1f 100644 --- a/registry/server.go +++ b/registry/server.go @@ -1,4 +1,4 @@ -// Package implements an Ollama registry client and server +// Package implements an Ollama registry client and server package registry package registry import ( @@ -8,9 +8,10 @@ import ( "errors" "log" "net/http" + "os" + "path" "time" - "bllamo.com/build" "bllamo.com/build/blob" "bllamo.com/client/ollama" "bllamo.com/oweb" @@ -19,6 +20,13 @@ import ( "github.com/minio/minio-go/v7/pkg/credentials" ) +// TODO(bmizerany): move all env things to package envkobs? 
+var defaultLibrary = cmp.Or(os.Getenv("OLLAMA_REGISTRY"), "registry.ollama.ai/library") + +func DefaultLibrary() string { + return defaultLibrary +} + type Server struct{} func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -80,7 +88,8 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { } if !pushed { const expires = 1 * time.Hour - signedURL, err := mc.PresignedPutObject(r.Context(), "test", l.Digest, expires) + key := path.Join("blobs", l.Digest) + signedURL, err := mc.PresignedPutObject(r.Context(), "test", key, expires) if err != nil { return err } @@ -95,9 +104,10 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { } if len(requirements) == 0 { - const cheatTODO = "registry.ollama.ai/library" - key := build.ManifestKey(cheatTODO, ref) - _, err := mc.PutObject(r.Context(), "test", key, bytes.NewReader(pr.Manifest), int64(len(pr.Manifest)), minio.PutObjectOptions{}) + // Commit the manifest + body := bytes.NewReader(pr.Manifest) + path := path.Join("manifests", ref.Path()) + _, err := mc.PutObject(r.Context(), "test", path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{}) if err != nil { return err } @@ -122,7 +132,8 @@ func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, er } // HEAD the object - _, err = mc.StatObject(ctx, "test", digest, minio.StatObjectOptions{}) + path := path.Join("blobs", digest) + _, err = mc.StatObject(ctx, "test", path, minio.StatObjectOptions{}) if err != nil { if isNoSuchKey(err) { err = nil diff --git a/registry/server_test.go b/registry/server_test.go index 79c1875b..0e268c76 100644 --- a/registry/server_test.go +++ b/registry/server_test.go @@ -2,6 +2,7 @@ package registry import ( "context" + "encoding/json" "net/http/httptest" "os/exec" "strings" @@ -57,6 +58,50 @@ func TestPush(t *testing.T) { if len(got) != 0 { t.Fatalf("unexpected requirements: % #v", pretty.Formatter(got)) } + + mc, err := minio.New("localhost:9000", 
&minio.Options{ + Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), + Secure: false, + }) + if err != nil { + t.Fatal(err) + } + + var paths []string + keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{ + Recursive: true, + }) + for k := range keys { + paths = append(paths, k.Key) + } + + t.Logf("paths: %v", paths) + + diff.Test(t, t.Errorf, paths, []string{ + "blobs/sha256-1", + "blobs/sha256-2", + "blobs/sha256-3", + "manifests/registry.ollama.ai/x/latest/Y", + }) + + obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/latest/Y", minio.GetObjectOptions{}) + if err != nil { + t.Fatal(err) + } + defer obj.Close() + + var gotM apitype.Manifest + if err := json.NewDecoder(obj).Decode(&gotM); err != nil { + t.Fatal(err) + } + + diff.Test(t, t.Errorf, gotM, apitype.Manifest{ + Layers: []apitype.Layer{ + {Digest: "sha256-1", Size: 1}, + {Digest: "sha256-2", Size: 2}, + {Digest: "sha256-3", Size: 3}, + }, + }) } func startMinio(t *testing.T) { From 7cfc8a0838d88a94815446bb300cd650614963d8 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 1 Apr 2024 21:19:58 -0700 Subject: [PATCH 16/29] build/blob: fix awkward Ref type --- api/api.go | 8 +- build/blob/ref.go | 242 +++++++++++++++++++------ build/blob/ref_test.go | 86 ++++++--- build/build.go | 40 ++-- build/build_test.go | 29 ++- build/internal/blobstore/blob.go | 10 +- build/internal/blobstore/store_test.go | 10 +- registry/server_test.go | 22 ++- 8 files changed, 325 insertions(+), 122 deletions(-) diff --git a/api/api.go b/api/api.go index d8fad5dd..1db6b933 100644 --- a/api/api.go +++ b/api/api.go @@ -7,7 +7,6 @@ import ( "os" "bllamo.com/build" - "bllamo.com/build/blob" "bllamo.com/client/ollama/apitype" "bllamo.com/oweb" "bllamo.com/registry" @@ -56,12 +55,7 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { const registryURLTODO = "http://localhost:8888" - ref := blob.ParseRef(params.Name) - if 
!ref.FullyQualified() { - return errUnqualifiedRef - } - - man, err := s.Build.Manifest(ref) + man, err := s.Build.Manifest(params.Name) if err != nil { if errors.Is(err, build.ErrNotFound) { return errRefNotFound diff --git a/build/blob/ref.go b/build/blob/ref.go index dc7bdcef..b8ad5203 100644 --- a/build/blob/ref.go +++ b/build/blob/ref.go @@ -2,33 +2,99 @@ package blob import ( "cmp" - "path" - "path/filepath" + "fmt" + "slices" "strings" ) +// Levels of concreteness +const ( + domain = iota + namespace + name + tag + build +) + // Ref is an opaque reference to a blob. // // It is comparable and can be used as a map key. // // Users or Ref must check Valid before using it. type Ref struct { - domain string - name string - tag string - build string + domain string + namespace string + name string + tag string + build string +} + +// WithDomain returns a copy of r with the provided domain. If the provided +// domain is empty, it returns the short, unqualified copy of r. +func (r Ref) WithDomain(s string) Ref { + return with(r, domain, s) +} + +// WithNamespace returns a copy of r with the provided namespace. If the +// provided namespace is empty, it returns the short, unqualified copy of r. +func (r Ref) WithNamespace(s string) Ref { + return with(r, namespace, s) +} + +func (r Ref) WithTag(s string) Ref { + return with(r, tag, s) } // WithBuild returns a copy of r with the provided build. If the provided // build is empty, it returns the short, unqualified copy of r. 
-func (r Ref) WithBuild(build string) Ref { - if build == "" { - return Ref{r.domain, r.name, r.tag, ""} - } - if !isValidPart(build) { +func (r Ref) WithBuild(s string) Ref { + return with(r, build, s) +} + +func with(r Ref, part int, value string) Ref { + if value != "" && !isValidPart(value) { return Ref{} } - return makeRef(r.domain, r.name, r.tag, build) + switch part { + case domain: + r.domain = value + case namespace: + r.namespace = value + case name: + r.name = value + case tag: + r.tag = value + case build: + r.build = value + default: + panic(fmt.Sprintf("invalid completeness: %d", part)) + } + return r +} + +// Format returns a string representation of the ref with the given +// concreteness. If a part is missing, it is replaced with a loud +// placeholder. +func (r Ref) Full() string { + r.domain = cmp.Or(r.domain, "!(MISSING DOMAIN)") + r.namespace = cmp.Or(r.namespace, "!(MISSING NAMESPACE)") + r.name = cmp.Or(r.name, "!(MISSING NAME)") + r.tag = cmp.Or(r.tag, "!(MISSING TAG)") + r.build = cmp.Or(r.build, "!(MISSING BUILD)") + return r.String() +} + +func (r Ref) NameAndTag() string { + r.domain = "" + r.namespace = "" + r.build = "" + return r.String() +} + +func (r Ref) NameTagAndBuild() string { + r.domain = "" + r.namespace = "" + return r.String() } // String returns the fully qualified ref string. @@ -38,6 +104,10 @@ func (r Ref) String() string { b.WriteString(r.domain) b.WriteString("/") } + if r.namespace != "" { + b.WriteString(r.namespace) + b.WriteString("/") + } b.WriteString(r.name) if r.tag != "" { b.WriteString(":") @@ -50,40 +120,41 @@ func (r Ref) String() string { return b.String() } -// Full returns the fully qualified ref string, or a string indicating the -// build is missing, or an empty string if the ref is invalid. -func (r Ref) Full() string { - if !r.Valid() { - return "" +// Complete returns true if the ref is valid and has no empty parts. 
+func (r Ref) Complete() bool { + return r.Valid() && !slices.Contains(r.Parts(), "") +} + +// Less returns true if r is less concrete than o; false otherwise. +func (r Ref) Less(o Ref) bool { + rp := r.Parts() + op := o.Parts() + for i := range rp { + if rp[i] < op[i] { + return true + } } - return makeRef(r.domain, r.name, r.tag, cmp.Or(r.build, "!(MISSING BUILD)")).String() + return false } -// Short returns the short ref string which does not include the build. -func (r Ref) Short() string { - return r.WithBuild("").String() +// Parts returns the parts of the ref in order of concreteness. +// +// The length of the returned slice is always 5. +func (r Ref) Parts() []string { + return []string{ + domain: r.domain, + namespace: r.namespace, + name: r.name, + tag: r.tag, + build: r.build, + } } -func (r Ref) Valid() bool { - return r.name != "" -} - -func (r Ref) FullyQualified() bool { - return r.name != "" && r.tag != "" && r.build != "" -} - -func (r Ref) Path() string { - return path.Join(r.domain, r.name, r.tag, r.build) -} - -func (r Ref) Filepath() string { - return filepath.Join(r.domain, r.name, r.tag, r.build) -} - -func (r Ref) Domain() string { return r.domain } -func (r Ref) Name() string { return r.name } -func (r Ref) Tag() string { return r.tag } -func (r Ref) Build() string { return r.build } +func (r Ref) Domain() string { return r.namespace } +func (r Ref) Namespace() string { return r.namespace } +func (r Ref) Name() string { return r.name } +func (r Ref) Tag() string { return r.tag } +func (r Ref) Build() string { return r.build } // ParseRef parses a ref string into a Ref. A ref string is a name, an // optional tag, and an optional build, separated by colons and pluses. 
@@ -112,25 +183,86 @@ func ParseRef(s string) Ref { return Ref{} } - nameAndTag, build, expectBuild := strings.Cut(s, "+") - name, tag, expectTag := strings.Cut(nameAndTag, ":") - if !isValidPart(name) { - return Ref{} + if strings.HasPrefix(s, "http://") { + s = s[len("http://"):] } - if expectTag && !isValidPart(tag) { - return Ref{} - } - if expectBuild && !isValidPart(build) { - return Ref{} + if strings.HasPrefix(s, "https://") { + s = s[len("https://"):] } - const TODO = "registry.ollama.ai" - return makeRef(TODO, name, tag, build) + var r Ref + + state, j := build, len(s) + for i := len(s) - 1; i >= 0; i-- { + c := s[i] + switch c { + case '+': + switch state { + case build: + r.build = s[i+1 : j] + r.build = strings.ToUpper(r.build) + state, j = tag, i + default: + return Ref{} + } + case ':': + switch state { + case build, tag: + r.tag = s[i+1 : j] + state, j = name, i + default: + return Ref{} + } + case '/': + switch state { + case name, tag, build: + r.name = s[i+1 : j] + state, j = namespace, i + case namespace: + r.namespace = s[i+1 : j] + state, j = domain, i + default: + return Ref{} + } + } + } + + // handle the first part based on final state + switch state { + case domain: + r.domain = s[:j] + case namespace: + r.namespace = s[:j] + default: + r.name = s[:j] + } + + if !r.Valid() { + return Ref{} + } + return r } -// makeRef makes a ref, skipping validation. 
-func makeRef(domain, name, tag, build string) Ref { - return Ref{domain, name, cmp.Or(tag, "latest"), strings.ToUpper(build)} +func (r Ref) Valid() bool { + // Name is required + if !isValidPart(r.name) { + return false + } + + // Optional parts must be valid if present + if r.domain != "" && !isValidPart(r.domain) { + return false + } + if r.namespace != "" && !isValidPart(r.namespace) { + return false + } + if r.tag != "" && !isValidPart(r.tag) { + return false + } + if r.build != "" && !isValidPart(r.build) { + return false + } + return true } // isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-] diff --git a/build/blob/ref_test.go b/build/blob/ref_test.go index 1322022f..1b6161c6 100644 --- a/build/blob/ref_test.go +++ b/build/blob/ref_test.go @@ -7,29 +7,63 @@ const ( refTooLong = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" ) +func TestRefParts(t *testing.T) { + const wantNumParts = 5 + var ref Ref + if len(ref.Parts()) != wantNumParts { + t.Errorf("Parts() = %d; want %d", len(ref.Parts()), wantNumParts) + } +} + func TestParseRef(t *testing.T) { cases := []struct { in string want Ref }{ - {"mistral:latest", Ref{"registry.ollama.ai", "mistral", "latest", ""}}, - {"mistral", Ref{"registry.ollama.ai", "mistral", "latest", ""}}, - {"mistral:30B", Ref{"registry.ollama.ai", "mistral", "30B", ""}}, - {"mistral:7b", Ref{"registry.ollama.ai", "mistral", "7b", ""}}, - {"mistral:7b+Q4_0", Ref{"registry.ollama.ai", "mistral", "7b", "Q4_0"}}, - {"mistral+KQED", Ref{"registry.ollama.ai", "mistral", "latest", "KQED"}}, - {"mistral.x-3:7b+Q4_0", Ref{"registry.ollama.ai", "mistral.x-3", "7b", "Q4_0"}}, + {"mistral:latest", Ref{ + name: "mistral", + tag: "latest", + }}, + {"mistral", Ref{ + name: "mistral", + }}, + {"mistral:30B", Ref{ + name: "mistral", + tag: "30B", + }}, + {"mistral:7b", Ref{ + name: "mistral", + tag: "7b", + }}, + {"mistral:7b+Q4_0", Ref{ + name: 
"mistral", + tag: "7b", + build: "Q4_0", + }}, + {"mistral+KQED", Ref{ + name: "mistral", + build: "KQED", + }}, + {"mistral.x-3:7b+Q4_0", Ref{ + name: "mistral.x-3", + tag: "7b", + build: "Q4_0", + }}, // lowecase build - {"mistral:7b+q4_0", Ref{"registry.ollama.ai", "mistral", "7b", "Q4_0"}}, + {"mistral:7b+q4_0", Ref{ + name: "mistral", + tag: "7b", + build: "Q4_0", + }}, + {"llama2:+", Ref{name: "llama2"}}, // Invalid - {"mistral:7b+Q4_0:latest", Ref{"", "", "", ""}}, - {"mi tral", Ref{"", "", "", ""}}, - {"llama2:+", Ref{"", "", "", ""}}, + {"mistral:7b+Q4_0:latest", Ref{}}, + {"mi tral", Ref{}}, // too long - {refTooLong, Ref{"", "", "", ""}}, + {refTooLong, Ref{}}, } for _, tt := range cases { t.Run(tt.in, func(t *testing.T) { @@ -42,25 +76,29 @@ func TestParseRef(t *testing.T) { } func TestRefFull(t *testing.T) { + const empty = "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/!(MISSING NAME):!(MISSING TAG)+!(MISSING BUILD)" + cases := []struct { - in string - wantShort string - wantFull string + in string + wantFull string }{ - {"", "", ""}, - {"mistral:7b+x", "registry.ollama.ai/mistral:7b", "registry.ollama.ai/mistral:7b+X"}, - {"mistral:7b+Q4_0", "registry.ollama.ai/mistral:7b", "registry.ollama.ai/mistral:7b+Q4_0"}, - {"mistral:latest", "registry.ollama.ai/mistral:latest", "registry.ollama.ai/mistral:latest+!(MISSING BUILD)"}, - {"mistral", "registry.ollama.ai/mistral:latest", "registry.ollama.ai/mistral:latest+!(MISSING BUILD)"}, - {"mistral:30b", "registry.ollama.ai/mistral:30b", "registry.ollama.ai/mistral:30b+!(MISSING BUILD)"}, + {"", empty}, + {"example.com/mistral:7b+x", "!(MISSING DOMAIN)/example.com/mistral:7b+X"}, + {"example.com/mistral:7b+Q4_0", "!(MISSING DOMAIN)/example.com/mistral:7b+Q4_0"}, + {"example.com/x/mistral:latest", "example.com/x/mistral:latest+!(MISSING BUILD)"}, + {"example.com/x/mistral:latest+Q4_0", "example.com/x/mistral:latest+Q4_0"}, + + {"mistral:7b+x", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+X"}, + 
{"mistral:7b+Q4_0", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+Q4_0"}, + {"mistral:latest", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:latest+!(MISSING BUILD)"}, + {"mistral", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:!(MISSING TAG)+!(MISSING BUILD)"}, + {"mistral:30b", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:30b+!(MISSING BUILD)"}, } for _, tt := range cases { t.Run(tt.in, func(t *testing.T) { ref := ParseRef(tt.in) - if g := ref.Short(); g != tt.wantShort { - t.Errorf("Short(%q) = %q; want %q", tt.in, g, tt.wantShort) - } + t.Logf("ParseRef(%q) = %#v", tt.in, ref) if g := ref.Full(); g != tt.wantFull { t.Errorf("Full(%q) = %q; want %q", tt.in, g, tt.wantFull) } diff --git a/build/build.go b/build/build.go index f66138db..374ba827 100644 --- a/build/build.go +++ b/build/build.go @@ -14,7 +14,8 @@ import ( // Errors var ( - ErrInvalidRef = errors.New("invalid ref") + ErrRefUnqualified = errors.New("unqualified ref") + ErrRefBuildPresent = errors.New("ref too long") ErrUnsupportedModelFormat = errors.New("unsupported model format") ErrMissingFileType = errors.New("missing 'general.file_type' key") ErrNoSuchBlob = errors.New("no such blob") @@ -53,14 +54,12 @@ func Open(dir string) (*Server, error) { func (s *Server) Build(ref string, f model.File) error { br := blob.ParseRef(ref) - if !br.Valid() { - return invalidRef(ref) + if !br.Complete() { + return fmt.Errorf("%w: %q", ErrRefUnqualified, br.Full()) } // 1. Resolve FROM // a. If it's a local file (gguf), hash it and add it to the store. - // b. If it's a local dir (safetensor), convert to gguf and add to - // store. // c. If it's a remote file (http), refuse. // 2. Turn other pragmas into layers, and add them to the store. // 3. Create a manifest from the layers. 
@@ -109,17 +108,22 @@ func (s *Server) LayerFile(digest string) (string, error) { return fileName, nil } -func (s *Server) Manifest(ref blob.Ref) ([]byte, error) { - data, _, err := s.getManifestData(ref) - if errors.Is(err, fs.ErrNotExist) { - return nil, fmt.Errorf("%w: %q", ErrNotFound, ref) +func (s *Server) Manifest(ref string) ([]byte, error) { + br, err := parseFullRef(ref) + if err != nil { + return nil, err } + data, _, err := s.getManifestData(br) return data, err } // WeightFile returns the absolute path to the weights file for the given model ref. -func (s *Server) WeightsFile(ref blob.Ref) (string, error) { - m, err := s.getManifest(ref) +func (s *Server) WeightsFile(ref string) (string, error) { + br, err := parseFullRef(ref) + if err != nil { + return "", err + } + m, err := s.getManifest(br) if err != nil { return "", err } @@ -157,9 +161,17 @@ func (s *Server) getManifest(ref blob.Ref) (manifestJSON, error) { } func (s *Server) getManifestData(ref blob.Ref) (data []byte, path string, err error) { - return s.st.Resolve(ref) + data, path, err = s.st.Resolve(ref) + if errors.Is(err, blobstore.ErrUnknownRef) { + return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref) + } + return data, path, err } -func invalidRef(ref string) error { - return fmt.Errorf("%w: %q", ErrInvalidRef, ref) +func parseFullRef(ref string) (blob.Ref, error) { + br := blob.ParseRef(ref) + if !br.Complete() { + return blob.Ref{}, fmt.Errorf("%w: %q", ErrRefUnqualified, ref) + } + return br, nil } diff --git a/build/build_test.go b/build/build_test.go index c146717e..2872f12c 100644 --- a/build/build_test.go +++ b/build/build_test.go @@ -6,11 +6,12 @@ import ( "path/filepath" "testing" - "bllamo.com/build/blob" "bllamo.com/encoding/gguf" "bllamo.com/model" ) +const qualifiedRef = "x/y/z:latest+Q4_0" + func TestServerBuildErrors(t *testing.T) { dir := t.TempDir() @@ -19,8 +20,15 @@ func TestServerBuildErrors(t *testing.T) { t.Fatal(err) } + t.Run("unqualified ref", func(t *testing.T) 
{ + err := s.Build("x", model.File{}) + if !errors.Is(err, ErrRefUnqualified) { + t.Fatalf("Build() err = %v; want unqualified ref", err) + } + }) + t.Run("FROM pragma missing", func(t *testing.T) { - err := s.Build("foo", model.File{}) + err := s.Build(qualifiedRef, model.File{}) var e *model.Error if !errors.As(err, &e) { t.Fatalf("unexpected error: %v", err) @@ -34,7 +42,7 @@ func TestServerBuildErrors(t *testing.T) { }) t.Run("FROM file not found", func(t *testing.T) { - err := s.Build("x", model.File{From: "bar"}) + err := s.Build(qualifiedRef, model.File{From: "bar"}) if !errors.Is(err, os.ErrNotExist) { t.Fatalf("Build() err = %v; want file not found", err) } @@ -51,7 +59,7 @@ func TestServerBuildErrors(t *testing.T) { "", ) - err := s.Build("x", model.File{From: w.fileName("gguf")}) + err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")}) if !errors.Is(err, ErrMissingFileType) { t.Fatalf("Build() err = %#v; want missing file type", err) } @@ -60,7 +68,7 @@ func TestServerBuildErrors(t *testing.T) { t.Run("FROM obscure dir", func(t *testing.T) { w := newWorkDir(t) w.mkdirAll("unknown") - if err := s.Build("x", model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat { + if err := s.Build(qualifiedRef, model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat { t.Fatalf("Build() err = %#v; want unsupported model type", err) } }) @@ -68,7 +76,7 @@ func TestServerBuildErrors(t *testing.T) { t.Run("FROM unsupported model type", func(t *testing.T) { w := newWorkDir(t) from := w.write("unknown", "unknown content") - err := s.Build("x", model.File{From: from}) + err := s.Build(qualifiedRef, model.File{From: from}) if !errors.Is(err, ErrUnsupportedModelFormat) { t.Fatalf("Build() err = %#v; want unsupported model type", err) } @@ -96,7 +104,7 @@ func TestBuildBasicGGUF(t *testing.T) { if err != nil { t.Fatal(err) } - if err := s.Build("x", model.File{From: w.fileName("gguf")}); err != nil { + if err := 
s.Build(qualifiedRef, model.File{From: w.fileName("gguf")}); err != nil { t.Fatal(err) } @@ -105,7 +113,12 @@ func TestBuildBasicGGUF(t *testing.T) { return nil }) - path, err := s.WeightsFile(blob.ParseRef("x+Q4_0")) + _, err = s.WeightsFile("unknown/y/z:latest+Q4_0") + if !errors.Is(err, ErrNotFound) { + t.Fatalf("WeightsFile() err = %v; want not found", err) + } + + path, err := s.WeightsFile("x/y/z:latest+Q4_0") if err != nil { t.Fatal(err) } diff --git a/build/internal/blobstore/blob.go b/build/internal/blobstore/blob.go index 3c29538a..18664abe 100644 --- a/build/internal/blobstore/blob.go +++ b/build/internal/blobstore/blob.go @@ -18,7 +18,8 @@ import ( ) var ( - ErrInvalidID = errors.New("invalid ID") + ErrInvalidID = errors.New("invalid ID") + ErrUnknownRef = errors.New("unknown ref") ) const HashSize = 32 @@ -199,6 +200,9 @@ func (s *Store) Resolve(ref blob.Ref) (data []byte, path string, err error) { return nil, "", err } data, err = os.ReadFile(path) + if errors.Is(err, fs.ErrNotExist) { + return nil, "", fmt.Errorf("%w: %q", ErrUnknownRef, ref) + } if err != nil { return nil, "", &entryNotFoundError{Err: err} } @@ -221,10 +225,10 @@ func (s *Store) Set(ref blob.Ref, data []byte) error { } func (s *Store) refFileName(ref blob.Ref) (string, error) { - if !ref.FullyQualified() { + if !ref.Complete() { return "", fmt.Errorf("ref not fully qualified: %q", ref) } - return filepath.Join(s.dir, "manifests", ref.Domain(), ref.Name(), ref.Tag(), ref.Build()), nil + return filepath.Join(s.dir, "manifests", filepath.Join(ref.Parts()...)), nil } // Get looks up the blob ID in the store, diff --git a/build/internal/blobstore/store_test.go b/build/internal/blobstore/store_test.go index ddcc05aa..6f698f9f 100644 --- a/build/internal/blobstore/store_test.go +++ b/build/internal/blobstore/store_test.go @@ -70,14 +70,13 @@ func TestStoreBasicBlob(t *testing.T) { } // Check tags - ref := blob.ParseRef("test+KQED") + ref := 
blob.ParseRef("registry.ollama.ai/library/test:latest+KQED") - t.Logf("resolving %s", ref) + t.Logf("RESOLVING: %q", ref.Parts()) data, _, err := st.Resolve(ref) - var e *entryNotFoundError - if !errors.As(err, &e) { - t.Fatal(err) + if !errors.Is(err, ErrUnknownRef) { + t.Fatalf("unexpected error: %v", err) } if data != nil { t.Errorf("unexpected data: %q", data) @@ -119,6 +118,7 @@ func checkDir(t testing.TB, dir string, want []string) { var matches []string for path, err := range walkDir(dir) { + t.Helper() if err != nil { t.Fatal(err) } diff --git a/registry/server_test.go b/registry/server_test.go index 0e268c76..466fd787 100644 --- a/registry/server_test.go +++ b/registry/server_test.go @@ -3,6 +3,8 @@ package registry import ( "context" "encoding/json" + "errors" + "io" "net/http/httptest" "os/exec" "strings" @@ -32,7 +34,9 @@ func TestPush(t *testing.T) { ] }`) - got, err := c.Push(context.Background(), "x+y", manifest) + const ref = "registry.ollama.ai/x/y:latest+Z" + + got, err := c.Push(context.Background(), ref, manifest) if err != nil { t.Fatal(err) } @@ -44,13 +48,13 @@ func TestPush(t *testing.T) { }, diff.ZeroFields[apitype.Requirement]("URL")) for _, r := range got { - body := strings.NewReader(strings.Repeat("x", int(r.Size))) + body := io.Reader(strings.NewReader(strings.Repeat("x", int(r.Size)))) if err := PushLayer(context.Background(), r.URL, r.Size, body); err != nil { t.Fatal(err) } } - got, err = c.Push(context.Background(), "x+y", manifest) + got, err = c.Push(context.Background(), ref, manifest) if err != nil { t.Fatal(err) } @@ -81,10 +85,10 @@ func TestPush(t *testing.T) { "blobs/sha256-1", "blobs/sha256-2", "blobs/sha256-3", - "manifests/registry.ollama.ai/x/latest/Y", + "manifests/registry.ollama.ai/x/y/latest/Z", }) - obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/latest/Y", minio.GetObjectOptions{}) + obj, err := mc.GetObject(context.Background(), "test", 
"manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{}) if err != nil { t.Fatal(err) } @@ -117,7 +121,13 @@ func startMinio(t *testing.T) { t.Cleanup(func() { cmd.Process.Kill() if err := cmd.Wait(); err != nil { - t.Log(err) + var e *exec.ExitError + if errors.As(err, &e) && e.Exited() { + t.Logf("minio stderr: %s", e.Stderr) + t.Logf("minio exit status: %v", e.ExitCode()) + t.Logf("minio exited: %v", e.Exited()) + t.Error(err) + } } }) From 876f7eab812449fbae8d3f24594f67acd3c1bc9a Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 1 Apr 2024 21:43:30 -0700 Subject: [PATCH 17/29] build: move Manifest from internal/blobstore to build It was getting confusing to have the arbitrary handling of manifests in the blobstore. It also prevented us from using model.Ref in the blobstore because of cyclic dependencies. This is much easier to grok now. --- build/build.go | 57 +++++++++++++++++++++-- build/internal/blobstore/blob.go | 62 +++----------------------- build/internal/blobstore/store_test.go | 33 -------------- registry/server.go | 4 +- 4 files changed, 63 insertions(+), 93 deletions(-) diff --git a/build/build.go b/build/build.go index 374ba827..f668bb73 100644 --- a/build/build.go +++ b/build/build.go @@ -6,6 +6,7 @@ import ( "fmt" "io/fs" "os" + "path/filepath" "bllamo.com/build/blob" "bllamo.com/build/internal/blobstore" @@ -20,6 +21,7 @@ var ( ErrMissingFileType = errors.New("missing 'general.file_type' key") ErrNoSuchBlob = errors.New("no such blob") ErrNotFound = errors.New("not found") + ErrUnknownRef = errors.New("unknown ref") ) type mediaType string @@ -96,7 +98,7 @@ func (s *Server) Build(ref string, f model.File) error { if err != nil { return err } - return s.st.Set(br.WithBuild(info.FileType.String()), data) + return s.Set(br.WithBuild(info.FileType.String()), data) } func (s *Server) LayerFile(digest string) (string, error) { @@ -135,6 +137,55 @@ func (s *Server) WeightsFile(ref string) (string, error) { return "",
fmt.Errorf("missing weights layer for %q", ref) } +// resolve returns the data for the given ref, if any. +// +// TODO: This should ideally return an ID, but the current on +// disk layout is that the actual manifest is stored in the "ref" instead of +// a pointer to a content-addressed blob. I (bmizerany) think we should +// change the on-disk layout to store the manifest in a content-addressed +// blob, and then have the ref point to that blob. This would simplify the +// code, allow us to have integrity checks on the manifest, and clean up +// this interface. +func (s *Server) resolve(ref blob.Ref) (data []byte, path string, err error) { + path, err = s.refFileName(ref) + if err != nil { + return nil, "", err + } + data, err = os.ReadFile(path) + if errors.Is(err, fs.ErrNotExist) { + return nil, "", fmt.Errorf("%w: %q", ErrUnknownRef, ref) + } + if err != nil { + // do not wrap the error here, as it is likely an I/O error + // and we want to preserve the absraction since we may not + // be on disk later. + return nil, "", fmt.Errorf("manifest read error: %v", err) + } + return data, path, nil +} + +// Set sets the data for the given ref. +func (s *Server) Set(ref blob.Ref, data []byte) error { + path, err := s.refFileName(ref) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil { + return err + } + if err := os.WriteFile(path, data, 0666); err != nil { + return err + } + return nil +} + +func (s *Server) refFileName(ref blob.Ref) (string, error) { + if !ref.Complete() { + return "", fmt.Errorf("ref not fully qualified: %q", ref) + } + return filepath.Join(s.st.Dir(), "manifests", filepath.Join(ref.Parts()...)), nil +} + type manifestJSON struct { // Layers is the list of layers in the manifest. 
Layers []layerJSON `json:"layers"` @@ -161,8 +212,8 @@ func (s *Server) getManifest(ref blob.Ref) (manifestJSON, error) { } func (s *Server) getManifestData(ref blob.Ref) (data []byte, path string, err error) { - data, path, err = s.st.Resolve(ref) - if errors.Is(err, blobstore.ErrUnknownRef) { + data, path, err = s.resolve(ref) + if errors.Is(err, ErrUnknownRef) { return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref) } return data, path, err diff --git a/build/internal/blobstore/blob.go b/build/internal/blobstore/blob.go index 18664abe..24dbad6f 100644 --- a/build/internal/blobstore/blob.go +++ b/build/internal/blobstore/blob.go @@ -13,13 +13,11 @@ import ( "strings" "time" - "bllamo.com/build/blob" "bllamo.com/types/structs" ) var ( - ErrInvalidID = errors.New("invalid ID") - ErrUnknownRef = errors.New("unknown ref") + ErrInvalidID = errors.New("invalid ID") ) const HashSize = 32 @@ -100,13 +98,9 @@ func Open(dir string) (*Store, error) { if !info.IsDir() { return nil, &fs.PathError{Op: "open", Path: dir, Err: fmt.Errorf("not a directory")} } - - for _, sub := range []string{"blobs", "manifests"} { - if err := os.MkdirAll(filepath.Join(dir, sub), 0777); err != nil { - return nil, err - } + if err := os.MkdirAll(filepath.Join(dir, "blobs"), 0777); err != nil { + return nil, err } - c := &Store{ dir: dir, now: time.Now, @@ -114,6 +108,10 @@ func Open(dir string) (*Store, error) { return c, nil } +func (s *Store) Dir() string { + return s.dir +} + // fileName returns the name of the blob file corresponding to the given id. func (s *Store) fileName(id ID) string { return filepath.Join(s.dir, "blobs", fmt.Sprintf("sha256-%x", id.a[:])) @@ -185,52 +183,6 @@ func (s *Store) OutputFilename(id ID) string { return file } -// Resolve returns the data for the given ref, if any. -// -// TODO: This should ideally return an ID, but the current on -// disk layout is that the actual manifest is stored in the "ref" instead of -// a pointer to a content-addressed blob. 
I (bmizerany) think we should -// change the on-disk layout to store the manifest in a content-addressed -// blob, and then have the ref point to that blob. This would simplify the -// code, allow us to have integrity checks on the manifest, and clean up -// this interface. -func (s *Store) Resolve(ref blob.Ref) (data []byte, path string, err error) { - path, err = s.refFileName(ref) - if err != nil { - return nil, "", err - } - data, err = os.ReadFile(path) - if errors.Is(err, fs.ErrNotExist) { - return nil, "", fmt.Errorf("%w: %q", ErrUnknownRef, ref) - } - if err != nil { - return nil, "", &entryNotFoundError{Err: err} - } - return data, path, nil -} - -// Set sets the data for the given ref. -func (s *Store) Set(ref blob.Ref, data []byte) error { - path, err := s.refFileName(ref) - if err != nil { - return err - } - if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil { - return err - } - if err := os.WriteFile(path, data, 0666); err != nil { - return err - } - return nil -} - -func (s *Store) refFileName(ref blob.Ref) (string, error) { - if !ref.Complete() { - return "", fmt.Errorf("ref not fully qualified: %q", ref) - } - return filepath.Join(s.dir, "manifests", filepath.Join(ref.Parts()...)), nil -} - // Get looks up the blob ID in the store, // returning the corresponding output ID and file size, if any. 
// Note that finding an output ID does not guarantee that the diff --git a/build/internal/blobstore/store_test.go b/build/internal/blobstore/store_test.go index 6f698f9f..0e64fc1a 100644 --- a/build/internal/blobstore/store_test.go +++ b/build/internal/blobstore/store_test.go @@ -31,7 +31,6 @@ func TestStoreBasicBlob(t *testing.T) { checkDir(t, dir, []string{ "blobs/", - "manifests/", }) id, size, err := PutBytes(st, []byte("hello")) @@ -49,7 +48,6 @@ func TestStoreBasicBlob(t *testing.T) { checkDir(t, dir, []string{ "blobs/", "blobs/" + blobNameHello, - "manifests/", }) got, err := st.Get(id) @@ -74,37 +72,6 @@ func TestStoreBasicBlob(t *testing.T) { t.Logf("RESOLVING: %q", ref.Parts()) - data, _, err := st.Resolve(ref) - if !errors.Is(err, ErrUnknownRef) { - t.Fatalf("unexpected error: %v", err) - } - if data != nil { - t.Errorf("unexpected data: %q", data) - } - - if err := st.Set(ref, []byte("{}")); err != nil { - t.Fatal(err) - } - - data, _, err = st.Resolve(ref) - if err != nil { - t.Fatal(err) - } - - if g := string(data); g != "{}" { - t.Errorf("g = %q; want %q", g, "{}") - } - - checkDir(t, dir, []string{ - "blobs/", - "blobs/sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", - "manifests/", - "manifests/registry.ollama.ai/", - "manifests/registry.ollama.ai/library/", - "manifests/registry.ollama.ai/library/test/", - "manifests/registry.ollama.ai/library/test/latest/", - "manifests/registry.ollama.ai/library/test/latest/KQED", - }) } // checkDir checks that the directory at dir contains the files in want. 
The diff --git a/registry/server.go b/registry/server.go index f22d0b1f..1dbc6392 100644 --- a/registry/server.go +++ b/registry/server.go @@ -61,7 +61,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { } ref := blob.ParseRef(pr.Ref) - if !ref.FullyQualified() { + if !ref.Complete() { return oweb.Mistake("invalid", "name", "must be fully qualified") } @@ -106,7 +106,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { if len(requirements) == 0 { // Commit the manifest body := bytes.NewReader(pr.Manifest) - path := path.Join("manifests", ref.Path()) + path := path.Join("manifests", path.Join(ref.Parts()...)) _, err := mc.PutObject(r.Context(), "test", path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{}) if err != nil { return err From b1b8be33d99314799049bb97b6d13722be01bddf Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 1 Apr 2024 21:57:34 -0700 Subject: [PATCH 18/29] build: cleanup error names and other things --- api/api.go | 2 +- build/build.go | 53 +++++++++++++++++++++------------------------ build/build_test.go | 2 +- 3 files changed, 27 insertions(+), 30 deletions(-) diff --git a/api/api.go b/api/api.go index 1db6b933..b1955a18 100644 --- a/api/api.go +++ b/api/api.go @@ -55,7 +55,7 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { const registryURLTODO = "http://localhost:8888" - man, err := s.Build.Manifest(params.Name) + man, err := s.Build.ManifestData(params.Name) if err != nil { if errors.Is(err, build.ErrNotFound) { return errRefNotFound diff --git a/build/build.go b/build/build.go index f668bb73..f8e2fbcf 100644 --- a/build/build.go +++ b/build/build.go @@ -15,13 +15,10 @@ import ( // Errors var ( - ErrRefUnqualified = errors.New("unqualified ref") - ErrRefBuildPresent = errors.New("ref too long") + ErrIncompleteRef = errors.New("unqualified ref") ErrUnsupportedModelFormat = errors.New("unsupported model format") ErrMissingFileType = 
errors.New("missing 'general.file_type' key") - ErrNoSuchBlob = errors.New("no such blob") ErrNotFound = errors.New("not found") - ErrUnknownRef = errors.New("unknown ref") ) type mediaType string @@ -57,7 +54,7 @@ func Open(dir string) (*Server, error) { func (s *Server) Build(ref string, f model.File) error { br := blob.ParseRef(ref) if !br.Complete() { - return fmt.Errorf("%w: %q", ErrRefUnqualified, br.Full()) + return fmt.Errorf("%w: %q", ErrIncompleteRef, br.Full()) } // 1. Resolve FROM @@ -94,7 +91,7 @@ func (s *Server) Build(ref string, f model.File) error { Size: size, }) - data, err := json.Marshal(manifestJSON{Layers: layers}) + data, err := json.Marshal(Manifest{Layers: layers}) if err != nil { return err } @@ -105,23 +102,31 @@ func (s *Server) LayerFile(digest string) (string, error) { fileName := s.st.OutputFilename(blobstore.ParseID(digest)) _, err := os.Stat(fileName) if errors.Is(err, fs.ErrNotExist) { - return "", fmt.Errorf("%w: %q", ErrNoSuchBlob, digest) + return "", fmt.Errorf("%w: %q", ErrNotFound, digest) } return fileName, nil } -func (s *Server) Manifest(ref string) ([]byte, error) { - br, err := parseFullRef(ref) +func (s *Server) Manifest(ref string) (Manifest, error) { + br, err := parseCompleteRef(ref) + if err != nil { + return Manifest{}, err + } + return s.getManifest(br) +} + +func (s *Server) ManifestData(ref string) ([]byte, error) { + br, err := parseCompleteRef(ref) if err != nil { return nil, err } - data, _, err := s.getManifestData(br) + data, _, err := s.resolve(br) return data, err } // WeightFile returns the absolute path to the weights file for the given model ref. 
func (s *Server) WeightsFile(ref string) (string, error) { - br, err := parseFullRef(ref) + br, err := parseCompleteRef(ref) if err != nil { return "", err } @@ -153,7 +158,7 @@ func (s *Server) resolve(ref blob.Ref) (data []byte, path string, err error) { } data, err = os.ReadFile(path) if errors.Is(err, fs.ErrNotExist) { - return nil, "", fmt.Errorf("%w: %q", ErrUnknownRef, ref) + return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref) } if err != nil { // do not wrap the error here, as it is likely an I/O error @@ -186,7 +191,7 @@ func (s *Server) refFileName(ref blob.Ref) (string, error) { return filepath.Join(s.st.Dir(), "manifests", filepath.Join(ref.Parts()...)), nil } -type manifestJSON struct { +type Manifest struct { // Layers is the list of layers in the manifest. Layers []layerJSON `json:"layers"` } @@ -199,30 +204,22 @@ type layerJSON struct { Size int64 `json:"size"` } -func (s *Server) getManifest(ref blob.Ref) (manifestJSON, error) { - data, path, err := s.getManifestData(ref) +func (s *Server) getManifest(ref blob.Ref) (Manifest, error) { + data, path, err := s.resolve(ref) if err != nil { - return manifestJSON{}, err + return Manifest{}, err } - var m manifestJSON + var m Manifest if err := json.Unmarshal(data, &m); err != nil { - return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err} + return Manifest{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err} } return m, nil } -func (s *Server) getManifestData(ref blob.Ref) (data []byte, path string, err error) { - data, path, err = s.resolve(ref) - if errors.Is(err, ErrUnknownRef) { - return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref) - } - return data, path, err -} - -func parseFullRef(ref string) (blob.Ref, error) { +func parseCompleteRef(ref string) (blob.Ref, error) { br := blob.ParseRef(ref) if !br.Complete() { - return blob.Ref{}, fmt.Errorf("%w: %q", ErrRefUnqualified, ref) + return blob.Ref{}, fmt.Errorf("%w: %q", ErrIncompleteRef, ref) } return br, nil } diff --git 
a/build/build_test.go b/build/build_test.go index 2872f12c..eecd300f 100644 --- a/build/build_test.go +++ b/build/build_test.go @@ -22,7 +22,7 @@ func TestServerBuildErrors(t *testing.T) { t.Run("unqualified ref", func(t *testing.T) { err := s.Build("x", model.File{}) - if !errors.Is(err, ErrRefUnqualified) { + if !errors.Is(err, ErrIncompleteRef) { t.Fatalf("Build() err = %v; want unqualified ref", err) } }) From 2318ed2919427f081432e7848409532088d0718a Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 1 Apr 2024 21:59:38 -0700 Subject: [PATCH 19/29] build: remove unused manifest() --- build/build.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/build/build.go b/build/build.go index f8e2fbcf..bca31eac 100644 --- a/build/build.go +++ b/build/build.go @@ -107,14 +107,6 @@ func (s *Server) LayerFile(digest string) (string, error) { return fileName, nil } -func (s *Server) Manifest(ref string) (Manifest, error) { - br, err := parseCompleteRef(ref) - if err != nil { - return Manifest{}, err - } - return s.getManifest(br) -} - func (s *Server) ManifestData(ref string) ([]byte, error) { br, err := parseCompleteRef(ref) if err != nil { From f488652ba797cef090e1e669c3ea5132c1f9fec8 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 1 Apr 2024 22:12:43 -0700 Subject: [PATCH 20/29] build: make Build accept only refs without builds --- build/blob/ref.go | 4 ++++ build/build.go | 30 +++++++++++++++++++++--------- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/build/blob/ref.go b/build/blob/ref.go index b8ad5203..093701e1 100644 --- a/build/blob/ref.go +++ b/build/blob/ref.go @@ -125,6 +125,10 @@ func (r Ref) Complete() bool { return r.Valid() && !slices.Contains(r.Parts(), "") } +func (r Ref) CompleteWithoutBuild() bool { + return r.Valid() && !slices.Contains(r.Parts()[:tag], "") +} + // Less returns true if r is less concrete than o; false otherwise. 
func (r Ref) Less(o Ref) bool { rp := r.Parts() diff --git a/build/build.go b/build/build.go index bca31eac..9f80820d 100644 --- a/build/build.go +++ b/build/build.go @@ -16,6 +16,7 @@ import ( // Errors var ( ErrIncompleteRef = errors.New("unqualified ref") + ErrBuildPresentInRef = errors.New("build present in ref") ErrUnsupportedModelFormat = errors.New("unsupported model format") ErrMissingFileType = errors.New("missing 'general.file_type' key") ErrNotFound = errors.New("not found") @@ -53,8 +54,8 @@ func Open(dir string) (*Server, error) { func (s *Server) Build(ref string, f model.File) error { br := blob.ParseRef(ref) - if !br.Complete() { - return fmt.Errorf("%w: %q", ErrIncompleteRef, br.Full()) + if !br.CompleteWithoutBuild() { + return fmt.Errorf("%w: %q", ErrIncompleteRef, ref) } // 1. Resolve FROM @@ -91,11 +92,10 @@ func (s *Server) Build(ref string, f model.File) error { Size: size, }) - data, err := json.Marshal(Manifest{Layers: layers}) - if err != nil { - return err - } - return s.Set(br.WithBuild(info.FileType.String()), data) + return s.setManifestData( + br.WithBuild(info.FileType.String()), + Manifest{Layers: layers}, + ) } func (s *Server) LayerFile(digest string) (string, error) { @@ -161,9 +161,21 @@ func (s *Server) resolve(ref blob.Ref) (data []byte, path string, err error) { return data, path, nil } +func (s *Server) SetManifestData(ref string, m Manifest) error { + br, err := parseCompleteRef(ref) + if err != nil { + return err + } + return s.setManifestData(br, m) +} + // Set sets the data for the given ref. 
-func (s *Server) Set(ref blob.Ref, data []byte) error { - path, err := s.refFileName(ref) +func (s *Server) setManifestData(br blob.Ref, m Manifest) error { + data, err := json.Marshal(m) + if err != nil { + return err + } + path, err := s.refFileName(br) if err != nil { return err } From ce3125afd5e3fb96e93f301728806ef3c5579d07 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 1 Apr 2024 22:53:49 -0700 Subject: [PATCH 21/29] registry: add New and take a minio client as argument --- registry/server.go | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/registry/server.go b/registry/server.go index 1dbc6392..5c7ba691 100644 --- a/registry/server.go +++ b/registry/server.go @@ -27,7 +27,13 @@ func DefaultLibrary() string { return defaultLibrary } -type Server struct{} +type Server struct { + minioClient *minio.Client +} + +func New(mc *minio.Client) *Server { + return &Server{minioClient: mc} +} func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err := s.serveHTTP(w, r); err != nil { @@ -65,11 +71,6 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { return oweb.Mistake("invalid", "name", "must be fully qualified") } - mc, err := minio.New("localhost:9000", &minio.Options{ - Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), - Secure: false, - }) - m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest)) if err != nil { return err @@ -89,7 +90,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { if !pushed { const expires = 1 * time.Hour key := path.Join("blobs", l.Digest) - signedURL, err := mc.PresignedPutObject(r.Context(), "test", key, expires) + signedURL, err := s.mc().PresignedPutObject(r.Context(), "test", key, expires) if err != nil { return err } @@ -107,7 +108,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { // Commit the manifest body := 
bytes.NewReader(pr.Manifest) path := path.Join("manifests", path.Join(ref.Parts()...)) - _, err := mc.PutObject(r.Context(), "test", path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{}) + _, err := s.mc().PutObject(r.Context(), "test", path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{}) if err != nil { return err } @@ -122,18 +123,9 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error { } func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, err error) { - // TODO(bmizerany): hold client on *Server (hack for now) - mc, err := minio.New("localhost:9000", &minio.Options{ - Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), - Secure: false, - }) - if err != nil { - return false, err - } - // HEAD the object path := path.Join("blobs", digest) - _, err = mc.StatObject(ctx, "test", path, minio.StatObjectOptions{}) + _, err = s.mc().StatObject(ctx, "test", path, minio.StatObjectOptions{}) if err != nil { if isNoSuchKey(err) { err = nil @@ -147,3 +139,17 @@ func isNoSuchKey(err error) bool { var e minio.ErrorResponse return errors.As(err, &e) && e.Code == "NoSuchKey" } + +func (s *Server) mc() *minio.Client { + if s.minioClient != nil { + return s.minioClient + } + mc, err := minio.New("localhost:9000", &minio.Options{ + Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), + Secure: false, + }) + if err != nil { + panic(err) + } + return mc +} From 628f1feb3686a7525d582dc4851ca1bf7fc6d6fd Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 1 Apr 2024 23:16:18 -0700 Subject: [PATCH 22/29] build: back to taking manifests as []byte It's nicer to have the manifests be an opaque []byte, rather than a struct. This way users of the build package don't need to know about the internal structure of the manifests. The registry can interpret the manifests as it sees fit, while letting build keep its own Go type of manifest which is easier to work with in the build package.
--- build/build.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/build/build.go b/build/build.go index 9f80820d..d26718a7 100644 --- a/build/build.go +++ b/build/build.go @@ -92,9 +92,14 @@ func (s *Server) Build(ref string, f model.File) error { Size: size, }) + data, err := json.Marshal(manifestJSON{Layers: layers}) + if err != nil { + return err + } + return s.setManifestData( br.WithBuild(info.FileType.String()), - Manifest{Layers: layers}, + data, ) } @@ -161,20 +166,16 @@ func (s *Server) resolve(ref blob.Ref) (data []byte, path string, err error) { return data, path, nil } -func (s *Server) SetManifestData(ref string, m Manifest) error { +func (s *Server) SetManifestData(ref string, data []byte) error { br, err := parseCompleteRef(ref) if err != nil { return err } - return s.setManifestData(br, m) + return s.setManifestData(br, data) } // Set sets the data for the given ref. -func (s *Server) setManifestData(br blob.Ref, m Manifest) error { - data, err := json.Marshal(m) - if err != nil { - return err - } +func (s *Server) setManifestData(br blob.Ref, data []byte) error { path, err := s.refFileName(br) if err != nil { return err @@ -195,7 +196,7 @@ func (s *Server) refFileName(ref blob.Ref) (string, error) { return filepath.Join(s.st.Dir(), "manifests", filepath.Join(ref.Parts()...)), nil } -type Manifest struct { +type manifestJSON struct { // Layers is the list of layers in the manifest. 
Layers []layerJSON `json:"layers"` } @@ -208,14 +209,14 @@ type layerJSON struct { Size int64 `json:"size"` } -func (s *Server) getManifest(ref blob.Ref) (Manifest, error) { +func (s *Server) getManifest(ref blob.Ref) (manifestJSON, error) { data, path, err := s.resolve(ref) if err != nil { - return Manifest{}, err + return manifestJSON{}, err } - var m Manifest + var m manifestJSON if err := json.Unmarshal(data, &m); err != nil { - return Manifest{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err} + return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err} } return m, nil } From aff7970628c7297cec99064a74b03d7b71536787 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 1 Apr 2024 23:41:42 -0700 Subject: [PATCH 23/29] build: remove superfluous parseCompleteRef --- build/build.go | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/build/build.go b/build/build.go index d26718a7..913d1d12 100644 --- a/build/build.go +++ b/build/build.go @@ -113,21 +113,13 @@ func (s *Server) LayerFile(digest string) (string, error) { } func (s *Server) ManifestData(ref string) ([]byte, error) { - br, err := parseCompleteRef(ref) - if err != nil { - return nil, err - } - data, _, err := s.resolve(br) + data, _, err := s.resolve(blob.ParseRef(ref)) return data, err } // WeightFile returns the absolute path to the weights file for the given model ref. 
func (s *Server) WeightsFile(ref string) (string, error) { - br, err := parseCompleteRef(ref) - if err != nil { - return "", err - } - m, err := s.getManifest(br) + m, err := s.getManifest(blob.ParseRef(ref)) if err != nil { return "", err } @@ -167,11 +159,7 @@ func (s *Server) resolve(ref blob.Ref) (data []byte, path string, err error) { } func (s *Server) SetManifestData(ref string, data []byte) error { - br, err := parseCompleteRef(ref) - if err != nil { - return err - } - return s.setManifestData(br, data) + return s.setManifestData(blob.ParseRef(ref), data) } // Set sets the data for the given ref. @@ -220,11 +208,3 @@ func (s *Server) getManifest(ref blob.Ref) (manifestJSON, error) { } return m, nil } - -func parseCompleteRef(ref string) (blob.Ref, error) { - br := blob.ParseRef(ref) - if !br.Complete() { - return blob.Ref{}, fmt.Errorf("%w: %q", ErrIncompleteRef, ref) - } - return br, nil -} From 9959da05de8dcf10f9c26e978974299e3ae3fbf5 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Tue, 2 Apr 2024 11:38:10 -0700 Subject: [PATCH 24/29] build/blob: break out test refs for other tests/fuzzing --- build/blob/ref_test.go | 75 ++++++++++++------------------------------ 1 file changed, 21 insertions(+), 54 deletions(-) diff --git a/build/blob/ref_test.go b/build/blob/ref_test.go index 1b6161c6..03f6ca24 100644 --- a/build/blob/ref_test.go +++ b/build/blob/ref_test.go @@ -7,6 +7,22 @@ const ( refTooLong = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" ) +var testRefs = map[string]Ref{ + "mistral:latest": {name: "mistral", tag: "latest"}, + "mistral": {name: "mistral"}, + "mistral:30B": {name: "mistral", tag: "30B"}, + "mistral:7b": {name: "mistral", tag: "7b"}, + "mistral:7b+Q4_0": {name: "mistral", tag: "7b", build: "Q4_0"}, + "mistral+KQED": {name: "mistral", build: "KQED"}, + "mistral.x-3:7b+Q4_0": {name: "mistral.x-3", tag: "7b", build: "Q4_0"}, + "mistral:7b+q4_0": 
{name: "mistral", tag: "7b", build: "Q4_0"}, + "llama2:+": {name: "llama2"}, + + // invalid + "mistral:7b+Q4_0:latest": {}, + "mi tral": {}, +} + func TestRefParts(t *testing.T) { const wantNumParts = 5 var ref Ref @@ -16,60 +32,11 @@ func TestRefParts(t *testing.T) { } func TestParseRef(t *testing.T) { - cases := []struct { - in string - want Ref - }{ - {"mistral:latest", Ref{ - name: "mistral", - tag: "latest", - }}, - {"mistral", Ref{ - name: "mistral", - }}, - {"mistral:30B", Ref{ - name: "mistral", - tag: "30B", - }}, - {"mistral:7b", Ref{ - name: "mistral", - tag: "7b", - }}, - {"mistral:7b+Q4_0", Ref{ - name: "mistral", - tag: "7b", - build: "Q4_0", - }}, - {"mistral+KQED", Ref{ - name: "mistral", - build: "KQED", - }}, - {"mistral.x-3:7b+Q4_0", Ref{ - name: "mistral.x-3", - tag: "7b", - build: "Q4_0", - }}, - - // lowecase build - {"mistral:7b+q4_0", Ref{ - name: "mistral", - tag: "7b", - build: "Q4_0", - }}, - {"llama2:+", Ref{name: "llama2"}}, - - // Invalid - {"mistral:7b+Q4_0:latest", Ref{}}, - {"mi tral", Ref{}}, - - // too long - {refTooLong, Ref{}}, - } - for _, tt := range cases { - t.Run(tt.in, func(t *testing.T) { - got := ParseRef(tt.in) - if got != tt.want { - t.Errorf("ParseRef(%q) = %q; want %q", tt.in, got, tt.want) + for s, want := range testRefs { + t.Run(s, func(t *testing.T) { + got := ParseRef(s) + if got != want { + t.Errorf("ParseRef(%q) = %q; want %q", s, got, want) } }) } From eb75418be9014f5db16e0de2454cf8cef3fe39d4 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Tue, 2 Apr 2024 11:45:01 -0700 Subject: [PATCH 25/29] build/blob: test ParseRef round-trip --- build/blob/ref.go | 6 ++++++ build/blob/ref_test.go | 8 +++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/build/blob/ref.go b/build/blob/ref.go index 093701e1..e7c75841 100644 --- a/build/blob/ref.go +++ b/build/blob/ref.go @@ -204,6 +204,9 @@ func ParseRef(s string) Ref { switch state { case build: r.build = s[i+1 : j] + if r.build == "" { + return Ref{} 
+ } r.build = strings.ToUpper(r.build) state, j = tag, i default: @@ -213,6 +216,9 @@ func ParseRef(s string) Ref { switch state { case build, tag: r.tag = s[i+1 : j] + if r.tag == "" { + return Ref{} + } state, j = name, i default: return Ref{} diff --git a/build/blob/ref_test.go b/build/blob/ref_test.go index 03f6ca24..bf4333df 100644 --- a/build/blob/ref_test.go +++ b/build/blob/ref_test.go @@ -16,7 +16,7 @@ var testRefs = map[string]Ref{ "mistral+KQED": {name: "mistral", build: "KQED"}, "mistral.x-3:7b+Q4_0": {name: "mistral.x-3", tag: "7b", build: "Q4_0"}, "mistral:7b+q4_0": {name: "mistral", tag: "7b", build: "Q4_0"}, - "llama2:+": {name: "llama2"}, + "llama2": {name: "llama2"}, // invalid "mistral:7b+Q4_0:latest": {}, @@ -38,6 +38,11 @@ func TestParseRef(t *testing.T) { if got != want { t.Errorf("ParseRef(%q) = %q; want %q", s, got, want) } + + // test round-trip + if ParseRef(got.String()) != got { + t.Errorf("String() = %q; want %q", got.String(), s) + } }) } } @@ -56,6 +61,7 @@ func TestRefFull(t *testing.T) { {"example.com/x/mistral:latest+Q4_0", "example.com/x/mistral:latest+Q4_0"}, {"mistral:7b+x", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+X"}, + {"mistral:7b+q4_0", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+Q4_0"}, {"mistral:7b+Q4_0", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+Q4_0"}, {"mistral:latest", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:latest+!(MISSING BUILD)"}, {"mistral", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:!(MISSING TAG)+!(MISSING BUILD)"}, From 618eb5b909aa7a26a53223c5f43fc653981a60da Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Tue, 2 Apr 2024 13:40:23 -0700 Subject: [PATCH 26/29] registry: multipart push --- registry/apitype/apitype.go | 3 ++- registry/server.go | 9 +++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/registry/apitype/apitype.go b/registry/apitype/apitype.go index e33dfe34..19cda1ab 100644 --- a/registry/apitype/apitype.go +++ 
b/registry/apitype/apitype.go @@ -19,7 +19,8 @@ type PushRequest struct { type Requirement struct { Digest string `json:"digest"` - Size int64 `json:"size"` + Offset int64 `json:"offset"` + Size int64 `json:"Size"` URL string `json:"url"` } diff --git a/registry/server.go b/registry/server.go index 5c7ba691..0059cf07 100644 --- a/registry/server.go +++ b/registry/server.go @@ -77,26 +77,31 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { } // TODO(bmizerany): parallelize + const chunkSizeTODO = 50 * 1024 * 1024 var requirements []apitype.Requirement for _, l := range m.Layers { if l.Size == 0 { continue } + // TODO(bmizerany): "global" throttle of rate of transfer + pushed, err := s.statObject(r.Context(), l.Digest) if err != nil { return err } if !pushed { - const expires = 1 * time.Hour + const expires = 15 * time.Minute key := path.Join("blobs", l.Digest) signedURL, err := s.mc().PresignedPutObject(r.Context(), "test", key, expires) if err != nil { return err } + + size := min(l.Size, chunkSizeTODO) requirements = append(requirements, apitype.Requirement{ Digest: l.Digest, - Size: l.Size, + Size: size, // TODO(bmizerany): use signed+temp urls URL: signedURL.String(), From c95f97689b43bd6a43d41dcd110a24c2b6942211 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Tue, 2 Apr 2024 14:15:21 -0700 Subject: [PATCH 27/29] utils/upload: init --- utils/upload/upload.go | 27 +++++++++++++++++++++++++++ utils/upload/upload_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 utils/upload/upload.go create mode 100644 utils/upload/upload_test.go diff --git a/utils/upload/upload.go b/utils/upload/upload.go new file mode 100644 index 00000000..c7447b54 --- /dev/null +++ b/utils/upload/upload.go @@ -0,0 +1,27 @@ +package upload + +import ( + "iter" + + "golang.org/x/exp/constraints" +) + +type Chunk[I constraints.Integer] struct { + Offset I + Size I +} + +// Chunks yields a sequence of a part 
number and a Chunk. The Chunk is the offset +// and size of the chunk. The last chunk may be smaller than chunkSize if size is +// not a multiple of chunkSize. +// +// The first part number is 1 and increases monotonically. +func Chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, Chunk[I]] { + return func(yield func(int, Chunk[I]) bool) { + var n int + for off := I(0); off < size; off += chunkSize { + n++ + yield(n, Chunk[I]{off, min(chunkSize, size-off)}) + } + } +} diff --git a/utils/upload/upload_test.go b/utils/upload/upload_test.go new file mode 100644 index 00000000..44ad7f21 --- /dev/null +++ b/utils/upload/upload_test.go @@ -0,0 +1,37 @@ +package upload + +import ( + "testing" + + "kr.dev/diff" +) + +func TestChunks(t *testing.T) { + const size = 101 + const chunkSize = 10 + var got []Chunk[int] + var lastN int + for n, c := range Chunks(size, chunkSize) { + if n != lastN+1 { + t.Errorf("n = %d; want %d", n, lastN+1) + } + got = append(got, c) + lastN = n + } + + want := []Chunk[int]{ + {0, 10}, + {10, 10}, + {20, 10}, + {30, 10}, + {40, 10}, + {50, 10}, + {60, 10}, + {70, 10}, + {80, 10}, + {90, 10}, + {100, 1}, + } + + diff.Test(t, t.Errorf, got, want) +} From 94befe366a9429ea2c77914068dc7d776c1c4173 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Tue, 2 Apr 2024 14:28:06 -0700 Subject: [PATCH 28/29] ... 
--- registry/server.go | 58 +++++++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/registry/server.go b/registry/server.go index 0059cf07..6d99669a 100644 --- a/registry/server.go +++ b/registry/server.go @@ -7,19 +7,28 @@ import ( "context" "errors" "log" + "math/rand" "net/http" + "net/url" "os" "path" + "strconv" "time" "bllamo.com/build/blob" "bllamo.com/client/ollama" "bllamo.com/oweb" "bllamo.com/registry/apitype" + "bllamo.com/utils/upload" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" ) +// Defaults +const ( + DefaultUploadChunkSize = 50 * 1024 * 1024 +) + // TODO(bmizerany): move all env things to package envkobs? var defaultLibrary = cmp.Or(os.Getenv("OLLAMA_REGISTRY"), "registry.ollama.ai/library") @@ -28,7 +37,8 @@ func DefaultLibrary() string { } type Server struct { - minioClient *minio.Client + UploadChunkSize int64 // default is DefaultUploadChunkSize + minioClient *minio.Client } func New(mc *minio.Client) *Server { @@ -77,7 +87,6 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { } // TODO(bmizerany): parallelize - const chunkSizeTODO = 50 * 1024 * 1024 var requirements []apitype.Requirement for _, l := range m.Layers { if l.Size == 0 { @@ -91,21 +100,29 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { return err } if !pushed { - const expires = 15 * time.Minute - key := path.Join("blobs", l.Digest) - signedURL, err := s.mc().PresignedPutObject(r.Context(), "test", key, expires) - if err != nil { - return err + uploadID := generateUploadID() + for n, c := range upload.Chunks(l.Size, cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize)) { + const expires = 15 * time.Minute + + key := path.Join("blobs", l.Digest) + signedURL, err := s.mc().Presign(r.Context(), "PUT", "test", key, expires, url.Values{ + "UploadId": []string{uploadID}, + "PartNumber": []string{strconv.Itoa(n)}, + "ContentLength": 
[]string{strconv.FormatInt(c.Size, 10)}, + }) + if err != nil { + return err + } + + requirements = append(requirements, apitype.Requirement{ + Digest: l.Digest, + Offset: c.Offset, + Size: c.Size, + + // TODO(bmizerany): use signed+temp urls + URL: signedURL.String(), + }) } - - size := min(l.Size, chunkSizeTODO) - requirements = append(requirements, apitype.Requirement{ - Digest: l.Digest, - Size: size, - - // TODO(bmizerany): use signed+temp urls - URL: signedURL.String(), - }) } } @@ -158,3 +175,12 @@ func (s *Server) mc() *minio.Client { } return mc } + +func generateUploadID() string { + const hex = "0123456789abcdef" + b := make([]byte, 32) + for i := range b { + b[i] = hex[rand.Intn(len(hex))] + } + return string(b) +} From a10a11b9d371f36b7c3510da32a1d70b74e27bd1 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Wed, 3 Apr 2024 10:39:30 -0700 Subject: [PATCH 29/29] registry: initial work on multipart pushes --- api/api.go | 18 ++- registry/apitype/apitype.go | 21 ++- registry/client.go | 51 +++++-- registry/server.go | 93 +++++++++---- registry/server_test.go | 260 +++++++++++++++++++++++++----------- utils/backoff/backoff.go | 58 ++++++++ 6 files changed, 378 insertions(+), 123 deletions(-) create mode 100644 utils/backoff/backoff.go diff --git a/api/api.go b/api/api.go index b1955a18..2d7800e4 100644 --- a/api/api.go +++ b/api/api.go @@ -10,6 +10,7 @@ import ( "bllamo.com/client/ollama/apitype" "bllamo.com/oweb" "bllamo.com/registry" + regtype "bllamo.com/registry/apitype" ) // Common API Errors @@ -64,11 +65,12 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { } c := registry.Client{BaseURL: registryURLTODO} - requirements, err := c.Push(r.Context(), params.Name, man) + requirements, err := c.Push(r.Context(), params.Name, man, nil) if err != nil { return err } + var uploads []regtype.CompletePart for _, rq := range requirements { l, err := s.Build.LayerFile(rq.Digest) if err != nil { @@ -80,7 +82,15 @@ func (s *Server) 
handlePush(_ http.ResponseWriter, r *http.Request) error { return err } defer f.Close() - return registry.PushLayer(r.Context(), rq.URL, rq.Size, f) + etag, err := registry.PushLayer(r.Context(), rq.URL, rq.Offset, rq.Size, f) + if err != nil { + return err + } + uploads = append(uploads, regtype.CompletePart{ + URL: rq.URL, + ETag: etag, + }) + return nil }() if err != nil { return err @@ -88,7 +98,9 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { } // commit the manifest to the registry - requirements, err = c.Push(r.Context(), params.Name, man) + requirements, err = c.Push(r.Context(), params.Name, man, &registry.PushParams{ + Uploaded: uploads, + }) if err != nil { return err } diff --git a/registry/apitype/apitype.go b/registry/apitype/apitype.go index 19cda1ab..36f2a342 100644 --- a/registry/apitype/apitype.go +++ b/registry/apitype/apitype.go @@ -6,6 +6,11 @@ type Manifest struct { Layers []Layer `json:"layers"` } +type CompletePart struct { + URL string `json:"url"` // contains PartNumber and UploadID from server + ETag string `json:"etag"` +} + type Layer struct { Digest string `json:"digest"` MediaType string `json:"mediaType"` @@ -13,15 +18,25 @@ type Layer struct { } type PushRequest struct { - Ref string `json:"ref"` - Manifest json.RawMessage + Ref string `json:"ref"` + Manifest json.RawMessage `json:"manifest"` + + // Uploaded is a list of upload parts that the client uploaded in the previous + // push. + Uploaded []CompletePart `json:"part_uploads"` } type Requirement struct { Digest string `json:"digest"` Offset int64 `json:"offset"` Size int64 `json:"Size"` - URL string `json:"url"` + + // URL is the url to PUT the layer to. + // + // Clients must include it as the URL, along with the ETag in the + // response headers from the PUT request, in the next push request + // in the Uploaded field. 
+ URL string `json:"url"` } type PushResponse struct { diff --git a/registry/client.go b/registry/client.go index 82616380..747dde57 100644 --- a/registry/client.go +++ b/registry/client.go @@ -1,7 +1,10 @@ package registry import ( + "cmp" "context" + "encoding/xml" + "fmt" "io" "net/http" @@ -18,12 +21,18 @@ func (c *Client) oclient() *ollama.Client { return (*ollama.Client)(c) } +type PushParams struct { + Uploaded []apitype.CompletePart +} + // Push pushes a manifest to the server. -func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]apitype.Requirement, error) { +func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushParams) ([]apitype.Requirement, error) { + p = cmp.Or(p, &PushParams{}) // TODO(bmizerany): backoff v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{ Ref: ref, Manifest: manifest, + Uploaded: p.Uploaded, }) if err != nil { return nil, err @@ -31,26 +40,42 @@ func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]apity return v.Requirements, nil } -func PushLayer(ctx context.Context, dstURL string, size int64, file io.Reader) error { - req, err := http.NewRequest("PUT", dstURL, file) +func PushLayer(ctx context.Context, dstURL string, off, size int64, file io.ReaderAt) (etag string, err error) { + sr := io.NewSectionReader(file, off, size) + req, err := http.NewRequestWithContext(ctx, "PUT", dstURL, sr) if err != nil { - return err + return "", err } req.ContentLength = size res, err := http.DefaultClient.Do(req) if err != nil { - return err + return "", err } defer res.Body.Close() if res.StatusCode != 200 { - e := &ollama.Error{Status: res.StatusCode} - msg, err := io.ReadAll(res.Body) - if err != nil { - return err - } - // TODO(bmizerany): format error message - e.Message = string(msg) + return "", parseS3Error(res) } - return nil + return res.Header.Get("ETag"), nil +} + +type s3Error struct { + XMLName xml.Name `xml:"Error"` 
+ Code string `xml:"Code"` + Message string `xml:"Message"` + Resource string `xml:"Resource"` + RequestId string `xml:"RequestId"` +} + +func (e *s3Error) Error() string { + return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message) +} + +// parseS3Error parses an XML error response from S3. +func parseS3Error(res *http.Response) error { + var se *s3Error + if err := xml.NewDecoder(res.Body).Decode(&se); err != nil { + return err + } + return se } diff --git a/registry/server.go b/registry/server.go index 6d99669a..91659767 100644 --- a/registry/server.go +++ b/registry/server.go @@ -6,8 +6,8 @@ import ( "cmp" "context" "errors" + "fmt" "log" - "math/rand" "net/http" "net/url" "os" @@ -70,7 +70,13 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { } } +func (s *Server) uploadChunkSize() int64 { + return cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize) +} + func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { + const bucketTODO = "test" + pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body) if err != nil { return err @@ -78,7 +84,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { ref := blob.ParseRef(pr.Ref) if !ref.Complete() { - return oweb.Mistake("invalid", "name", "must be fully qualified") + return oweb.Mistake("invalid", "name", "must be complete") } m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest)) @@ -86,28 +92,80 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { return err } - // TODO(bmizerany): parallelize + mcc := &minio.Core{Client: s.mc()} + // TODO(bmizerany): complete uploads before stats for any with ETag + + type completeParts struct { + key string + parts []minio.CompletePart + } + + completePartsByUploadID := make(map[string]completeParts) + for _, pu := range pr.Uploaded { + // parse the URL + u, err := url.Parse(pu.URL) + if err != nil { + return err + } + q 
:= u.Query() + uploadID := q.Get("UploadId") + if uploadID == "" { + return oweb.Mistake("invalid", "url", "missing UploadId") + } + partNumber, err := strconv.Atoi(q.Get("PartNumber")) + if err != nil { + return oweb.Mistake("invalid", "url", "invalid or missing PartNumber") + } + etag := pu.ETag + if etag == "" { + return oweb.Mistake("invalid", "etag", "missing") + } + cp, ok := completePartsByUploadID[uploadID] + if !ok { + cp = completeParts{key: u.Path} + completePartsByUploadID[uploadID] = cp + } + cp.parts = append(cp.parts, minio.CompletePart{ + PartNumber: partNumber, + ETag: etag, + }) + fmt.Println("uploadID", uploadID, "partNumber", partNumber, "etag", etag) + completePartsByUploadID[uploadID] = cp + } + + for uploadID, cp := range completePartsByUploadID { + var zeroOpts minio.PutObjectOptions + _, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, cp.key, uploadID, cp.parts, zeroOpts) + if err != nil { + // log and continue; put backpressure on the client + log.Printf("error completing upload: %v", err) + } + } + var requirements []apitype.Requirement for _, l := range m.Layers { + // TODO(bmizerany): do in parallel if l.Size == 0 { continue } // TODO(bmizerany): "global" throttle of rate of transfer - pushed, err := s.statObject(r.Context(), l.Digest) if err != nil { return err } if !pushed { - uploadID := generateUploadID() - for n, c := range upload.Chunks(l.Size, cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize)) { - const expires = 15 * time.Minute + key := path.Join("blobs", l.Digest) + uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{}) + if err != nil { + return err + } + for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) { + const timeToStartUpload = 15 * time.Minute - key := path.Join("blobs", l.Digest) - signedURL, err := s.mc().Presign(r.Context(), "PUT", "test", key, expires, url.Values{ + signedURL, err := s.mc().Presign(r.Context(), "PUT", bucketTODO, key, 
timeToStartUpload, url.Values{ "UploadId": []string{uploadID}, - "PartNumber": []string{strconv.Itoa(n)}, + "PartNumber": []string{strconv.Itoa(partNumber)}, "ContentLength": []string{strconv.FormatInt(c.Size, 10)}, }) if err != nil { @@ -118,9 +176,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { Digest: l.Digest, Offset: c.Offset, Size: c.Size, - - // TODO(bmizerany): use signed+temp urls - URL: signedURL.String(), + URL: signedURL.String(), }) } } @@ -130,7 +186,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { // Commit the manifest body := bytes.NewReader(pr.Manifest) path := path.Join("manifests", path.Join(ref.Parts()...)) - _, err := s.mc().PutObject(r.Context(), "test", path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{}) + _, err := s.mc().PutObject(r.Context(), bucketTODO, path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{}) if err != nil { return err } @@ -175,12 +231,3 @@ func (s *Server) mc() *minio.Client { } return mc } - -func generateUploadID() string { - const hex = "0123456789abcdef" - b := make([]byte, 32) - for i := range b { - b[i] = hex[rand.Intn(len(hex))] - } - return string(b) -} diff --git a/registry/server_test.go b/registry/server_test.go index 466fd787..8cb1ecc1 100644 --- a/registry/server_test.go +++ b/registry/server_test.go @@ -1,118 +1,189 @@ package registry import ( + "bufio" "context" "encoding/json" "errors" + "fmt" "io" + "net" "net/http/httptest" + "os" "os/exec" "strings" "testing" "time" "bllamo.com/registry/apitype" - "github.com/kr/pretty" + "bllamo.com/utils/backoff" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "kr.dev/diff" ) -func TestPush(t *testing.T) { - startMinio(t) +const abc = "abcdefghijklmnopqrstuvwxyz" - s := &Server{} - hs := httptest.NewServer(s) - t.Cleanup(hs.Close) - c := &Client{BaseURL: hs.URL} +func testPush(t *testing.T, chunkSize int64) { + t.Run(fmt.Sprintf("chunkSize=%d", 
chunkSize), func(t *testing.T) { + mc := startMinio(t, false) - manifest := []byte(`{ - "layers": [ - {"digest": "sha256-1", "size": 1}, - {"digest": "sha256-2", "size": 2}, - {"digest": "sha256-3", "size": 3} - ] - }`) + manifest := []byte(`{ + "layers": [ + {"digest": "sha256-1", "size": 1}, + {"digest": "sha256-2", "size": 2}, + {"digest": "sha256-3", "size": 3} + ] + }`) - const ref = "registry.ollama.ai/x/y:latest+Z" + const ref = "registry.ollama.ai/x/y:latest+Z" - got, err := c.Push(context.Background(), ref, manifest) - if err != nil { - t.Fatal(err) - } + hs := httptest.NewServer(&Server{ + minioClient: mc, + UploadChunkSize: chunkSize, + }) + t.Cleanup(hs.Close) + c := &Client{BaseURL: hs.URL} - diff.Test(t, t.Errorf, got, []apitype.Requirement{ - {Digest: "sha256-1", Size: 1}, - {Digest: "sha256-2", Size: 2}, - {Digest: "sha256-3", Size: 3}, - }, diff.ZeroFields[apitype.Requirement]("URL")) - - for _, r := range got { - body := io.Reader(strings.NewReader(strings.Repeat("x", int(r.Size)))) - if err := PushLayer(context.Background(), r.URL, r.Size, body); err != nil { + requirements, err := c.Push(context.Background(), ref, manifest, nil) + if err != nil { t.Fatal(err) } - } - got, err = c.Push(context.Background(), ref, manifest) - if err != nil { - t.Fatal(err) - } + if len(requirements) < 3 { + t.Fatalf("expected at least 3 requirements; got %d", len(requirements)) + t.Logf("requirements: %v", requirements) + } - if len(got) != 0 { - t.Fatalf("unexpected requirements: % #v", pretty.Formatter(got)) - } + var uploaded []apitype.CompletePart + for i, r := range requirements { + t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size) - mc, err := minio.New("localhost:9000", &minio.Options{ - Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), - Secure: false, - }) - if err != nil { - t.Fatal(err) - } + body := strings.NewReader(abc) + etag, err := PushLayer(context.Background(), r.URL, r.Offset, r.Size, body) + if err != nil { + 
t.Fatal(err) + } + uploaded = append(uploaded, apitype.CompletePart{ + URL: r.URL, + ETag: etag, + }) + } - var paths []string - keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{ - Recursive: true, - }) - for k := range keys { - paths = append(paths, k.Key) - } + requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{ + Uploaded: uploaded, + }) + if err != nil { + t.Fatal(err) + } + if len(requirements) != 0 { + t.Fatalf("unexpected requirements: %v", requirements) + } - t.Logf("paths: %v", paths) + var paths []string + keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{ + Recursive: true, + }) + for k := range keys { + paths = append(paths, k.Key) + } - diff.Test(t, t.Errorf, paths, []string{ - "blobs/sha256-1", - "blobs/sha256-2", - "blobs/sha256-3", - "manifests/registry.ollama.ai/x/y/latest/Z", - }) + t.Logf("paths: %v", paths) - obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{}) - if err != nil { - t.Fatal(err) - } - defer obj.Close() + diff.Test(t, t.Errorf, paths, []string{ + "blobs/sha256-1", + "blobs/sha256-2", + "blobs/sha256-3", + "manifests/registry.ollama.ai/x/y/latest/Z", + }) - var gotM apitype.Manifest - if err := json.NewDecoder(obj).Decode(&gotM); err != nil { - t.Fatal(err) - } + obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{}) + if err != nil { + t.Fatal(err) + } + defer obj.Close() - diff.Test(t, t.Errorf, gotM, apitype.Manifest{ - Layers: []apitype.Layer{ - {Digest: "sha256-1", Size: 1}, - {Digest: "sha256-2", Size: 2}, - {Digest: "sha256-3", Size: 3}, - }, + var gotM apitype.Manifest + if err := json.NewDecoder(obj).Decode(&gotM); err != nil { + t.Fatal(err) + } + + diff.Test(t, t.Errorf, gotM, apitype.Manifest{ + Layers: []apitype.Layer{ + {Digest: "sha256-1", Size: 1}, + {Digest: "sha256-2", Size: 2}, + {Digest: 
"sha256-3", Size: 3}, + }, + }) + + // checksum the blobs + for i, l := range gotM.Layers { + obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{}) + if err != nil { + t.Fatal(err) + } + defer obj.Close() + + info, err := obj.Stat() + if err != nil { + t.Fatal(err) + } + t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size) + + data, err := io.ReadAll(obj) + if err != nil { + t.Fatal(err) + } + + got := string(data) + want := abc[:l.Size] + if got != want { + t.Errorf("[%d] got layer data = %q; want %q", i, got, want) + } + } }) } -func startMinio(t *testing.T) { +func TestPush(t *testing.T) { + testPush(t, 0) + testPush(t, 1) +} + +func availableAddr() string { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + panic(err) + } + defer l.Close() + return l.Addr().String() +} + +func startMinio(t *testing.T, debug bool) *minio.Client { t.Helper() dir := t.TempDir() - cmd := exec.Command("minio", "server", "--address", "localhost:9000", dir) + t.Logf(">> minio data dir: %s", dir) + addr := availableAddr() + cmd := exec.Command("minio", "server", "--address", addr, dir) + cmd.Env = os.Environ() + + if debug { + stdout, err := cmd.StdoutPipe() + if err != nil { + t.Fatal(err) + } + doneLogging := make(chan struct{}) + t.Cleanup(func() { + <-doneLogging + }) + go func() { + defer close(doneLogging) + sc := bufio.NewScanner(stdout) + for sc.Scan() { + t.Logf("minio: %s", sc.Text()) + } + }() + } // TODO(bmizerany): wait delay etc... 
if err := cmd.Start(); err != nil { @@ -131,7 +202,7 @@ func startMinio(t *testing.T) { } }) - mc, err := minio.New("localhost:9000", &minio.Options{ + mc, err := minio.New(addr, &minio.Options{ Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), Secure: false, }) @@ -139,17 +210,44 @@ func startMinio(t *testing.T) { t.Fatal(err) } - // wait for server to start - // TODO(bmizerany): use backoff - for { - _, err := mc.ListBuckets(context.Background()) - if err == nil { + ctx, cancel := context.WithCancel(context.Background()) + deadline, ok := t.Deadline() + if ok { + ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond)) + defer cancel() + } + + // wait for server to start with exponential backoff + for _, err := range backoff.Upto(ctx, 1*time.Second) { + if err != nil { + t.Fatal(err) + } + if mc.IsOnline() { break } - time.Sleep(100 * time.Millisecond) } if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil { t.Fatal(err) } + + return mc +} + +// contextForTest returns a context that is canceled when the test deadline, +// if any, is reached. The returned doneLogging function should be called +// after all Log/Error/Fatalf calls are done before the test returns. 
+func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) { + done := make(chan struct{}) + deadline, ok := t.Deadline() + if !ok { + return context.Background(), func() {} + } + + ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond)) + t.Cleanup(func() { + cancel() + <-done + }) + return ctx, func() { close(done) } } diff --git a/utils/backoff/backoff.go b/utils/backoff/backoff.go new file mode 100644 index 00000000..b77f8706 --- /dev/null +++ b/utils/backoff/backoff.go @@ -0,0 +1,58 @@ +package backoff + +import ( + "context" + "errors" + "iter" + "math/rand" + "time" +) + +// Errors +var ( + // ErrMaxAttempts is not used by backoff but is available for use by + // callers that want to signal that a maximum number of retries has + // been exceeded. This should eliminate the need for callers to invent + // their own error. + ErrMaxAttempts = errors.New("max retries exceeded") +) + +// Upto implements a backoff strategy that yields nil errors until the +// context is canceled or yield returns false. +// +// The backoff strategy is a simple quadratic (n^2) backoff with a maximum +// backoff of maxBackoff. The backoff is randomized between 0.5-1.5 times +// the current backoff, in order to prevent accidental "thundering herd" +// problems. +func Upto(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] { + var n int + return func(yield func(int, error) bool) { + for { + if ctx.Err() != nil { + yield(n, ctx.Err()) + return + } + + n++ + + // n^2 backoff timer is a little smoother than the + // common choice of 2^n. + d := time.Duration(n*n) * 10 * time.Millisecond + if d > maxBackoff { + d = maxBackoff + } + // Randomize the delay between 0.5-1.5 x msec, in order + // to prevent accidental "thundering herd" problems. 
+ d = time.Duration(float64(d) * (rand.Float64() + 0.5)) + t := time.NewTimer(d) + select { + case <-ctx.Done(): + t.Stop() + case <-t.C: + if !yield(n, nil) { + return + } + } + } + } +}