commit c49947dcf5600e147fa76a8289c08615b49ada36 Author: Blake Mizerany Date: Thu Mar 7 20:33:57 2024 -0800 init diff --git a/api/api.go b/api/api.go new file mode 100644 index 00000000..fd525937 --- /dev/null +++ b/api/api.go @@ -0,0 +1,106 @@ +package api + +import ( + "errors" + "fmt" + "net/http" + "os" + + "bllamo.com/build" + "bllamo.com/build/blob" + "bllamo.com/client/ollama/apitype" + "bllamo.com/oweb" + "bllamo.com/registry" +) + +// Common API Errors +var ( + errUnqualifiedRef = oweb.Mistake("invalid", "name", "must be fully qualified") + errRefNotFound = oweb.Mistake("not_found", "name", "no such model") +) + +type Server struct { + Build *build.Server +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + oweb.Serve(s.serveHTTP, w, r) +} + +func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { + switch r.URL.Path { + case "/v1/push": + return s.handlePush(w, r) + default: + return oweb.ErrNotFound + } +} + +func want(r *http.Request, method, path string) bool { + return r.Method == method && r.URL.Path == path +} + +func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { + if r.Method != "POST" { + return oweb.ErrMethodNotAllowed + } + + params, err := oweb.DecodeJSON[apitype.PushRequest](r.Body) + if err != nil { + return err + } + + if params.Name == "" { + return oweb.Missing("name") + } + + const registryURLTODO = "http://localhost:8888" + + ref := blob.ParseRef(params.Name) + if !ref.FullyQualified() { + return errUnqualifiedRef + } + + man, err := s.Build.Manifest(ref) + if err != nil { + if errors.Is(err, build.ErrNotFound) { + return errRefNotFound + } + return err + } + + c := registry.Client{BaseURL: registryURLTODO} + requirements, err := c.Push(r.Context(), params.Name, man) + if err != nil { + return err + } + + for _, rq := range requirements { + l, err := s.Build.LayerFile(rq.Digest) + if err != nil { + return err + } + err = func() error { + f, err := os.Open(l) + if err != nil { + return err + } + defer f.Close() + return registry.PushLayer(r.Context(), rq.URL, rq.Size, f) + }() + if err != nil { + return err + } + } + + // commit the manifest to the registry + requirements, err = c.Push(r.Context(), params.Name, man) + if err != nil { + return err + } + for _, r := range requirements { + err = errors.Join(err, fmt.Errorf("push failed for %q", r.Digest)) + } + return err + +} diff --git a/build/blob/ref.go b/build/blob/ref.go new file mode 100644 index 00000000..9a033fcb --- /dev/null +++ b/build/blob/ref.go @@ -0,0 +1,135 @@ +package blob + +import ( + "cmp" + "strings" +) + +// Ref is an opaque reference to a blob. +// +// It is comparable and can be used as a map key. +// +// Users or Ref must check Valid before using it. +type Ref struct { + name string + tag string + build string +} + +// WithBuild returns a copy of r with the provided build. If the provided +// build is empty, it returns the short, unqualified copy of r. +func (r Ref) WithBuild(build string) Ref { + if build == "" { + return Ref{r.name, r.tag, ""} + } + if !isValidPart(build) { + return Ref{} + } + return makeRef(r.name, r.tag, build) +} + +// String returns the fully qualified ref string. 
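+//
+// For example:
+//
+//	ParseRef("mistral:7b+q4_0").String() // "mistral:7b+Q4_0" (build is upcased)
+//
+// Empty parts are omitted along with their ':' or '+' separator.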
+func (r Ref) String() string { + var b strings.Builder + b.WriteString(r.name) + if r.tag != "" { + b.WriteString(":") + b.WriteString(r.tag) + } + if r.build != "" { + b.WriteString("+") + b.WriteString(r.build) + } + return b.String() +} + +// Full returns the fully qualified ref string, or a string indicating the +// build is missing, or an empty string if the ref is invalid. +func (r Ref) Full() string { + if !r.Valid() { + return "" + } + return makeRef(r.name, r.tag, cmp.Or(r.build, "!(MISSING BUILD)")).String() +} + +// Short returns the short ref string which does not include the build. +func (r Ref) Short() string { + return r.WithBuild("").String() +} + +func (r Ref) Valid() bool { + return r.name != "" +} + +func (r Ref) FullyQualified() bool { + return r.name != "" && r.tag != "" && r.build != "" +} + +func (r Ref) Name() string { return r.name } +func (r Ref) Tag() string { return r.tag } +func (r Ref) Build() string { return r.build } + +// ParseRef parses a ref string into a Ref. A ref string is a name, an +// optional tag, and an optional build, separated by colons and pluses. +// +// The name must be valid ascii [a-zA-Z0-9_]. +// The tag must be valid ascii [a-zA-Z0-9_]. +// The build must be valid ascii [a-zA-Z0-9_]. +// +// It returns then zero value if the ref is invalid. +// +// // Valid Examples: +// ParseRef("mistral:latest") returns ("mistral", "latest", "") +// ParseRef("mistral") returns ("mistral", "", "") +// ParseRef("mistral:30B") returns ("mistral", "30B", "") +// ParseRef("mistral:7b") returns ("mistral", "7b", "") +// ParseRef("mistral:7b+Q4_0") returns ("mistral", "7b", "Q4_0") +// ParseRef("mistral+KQED") returns ("mistral", "latest", "KQED") +// ParseRef(".x.:7b+Q4_0:latest") returns (".x.", "7b", "Q4_0") +// ParseRef("-grok-f.oo:7b+Q4_0") returns ("-grok-f.oo", "7b", "Q4_0") +// +// // Invalid Examples: +// ParseRef("m stral") returns ("", "", "") // zero +// ParseRef("... 129 chars ...") returns ("", "", "") // zero +func ParseRef(s string) Ref { + if len(s) > 128 { + return Ref{} + } + + nameAndTag, build, expectBuild := strings.Cut(s, "+") + name, tag, expectTag := strings.Cut(nameAndTag, ":") + if !isValidPart(name) { + return Ref{} + } + if expectTag && !isValidPart(tag) { + return Ref{} + } + if expectBuild && !isValidPart(build) { + return Ref{} + } + return makeRef(name, tag, build) +} + +// makeRef makes a ref, skipping validation. +func makeRef(name, tag, build string) Ref { + return Ref{name, cmp.Or(tag, "latest"), strings.ToUpper(build)} +} + +// isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-] +func isValidPart(s string) bool { + if len(s) == 0 { + return false + } + for _, c := range []byte(s) { + if c == '.' 
|| c == '-' { + return true + } + if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' { + continue + } else { + return false + + } + } + return true +} diff --git a/build/blob/ref_test.go b/build/blob/ref_test.go new file mode 100644 index 00000000..b49d39df --- /dev/null +++ b/build/blob/ref_test.go @@ -0,0 +1,69 @@ +package blob + +import "testing" + +// test refs +const ( + refTooLong = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" +) + +func TestParseRef(t *testing.T) { + cases := []struct { + in string + want Ref + }{ + {"mistral:latest", Ref{"mistral", "latest", ""}}, + {"mistral", Ref{"mistral", "latest", ""}}, + {"mistral:30B", Ref{"mistral", "30B", ""}}, + {"mistral:7b", Ref{"mistral", "7b", ""}}, + {"mistral:7b+Q4_0", Ref{"mistral", "7b", "Q4_0"}}, + {"mistral+KQED", Ref{"mistral", "latest", "KQED"}}, + {"mistral.x-3:7b+Q4_0", Ref{"mistral.x-3", "7b", "Q4_0"}}, + + // lowecase build + {"mistral:7b+q4_0", Ref{"mistral", "7b", "Q4_0"}}, + + // Invalid + {"mistral:7b+Q4_0:latest", Ref{"", "", ""}}, + {"mi tral", Ref{"", "", ""}}, + {"llama2:+", Ref{"", "", ""}}, + + // too long + {refTooLong, Ref{"", "", ""}}, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + got := ParseRef(tt.in) + if got != tt.want { + t.Errorf("ParseRef(%q) = %q; want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestRefFull(t *testing.T) { + cases := []struct { + in string + wantShort string + wantFull string + }{ + {"", "", ""}, + {"mistral:7b+x", "mistral:7b", "mistral:7b+X"}, + {"mistral:7b+Q4_0", "mistral:7b", "mistral:7b+Q4_0"}, + {"mistral:latest", "mistral:latest", "mistral:latest+!(MISSING BUILD)"}, + {"mistral", "mistral:latest", "mistral:latest+!(MISSING BUILD)"}, + {"mistral:30b", "mistral:30b", "mistral:30b+!(MISSING BUILD)"}, + } + + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + ref := ParseRef(tt.in) + if g := ref.Short(); g != tt.wantShort { + t.Errorf("Short(%q) = %q; want %q", tt.in, g, tt.wantShort) + } + if g := ref.Full(); g != tt.wantFull { + t.Errorf("Full(%q) = %q; want %q", tt.in, g, tt.wantFull) + } + }) + } +} diff --git a/build/build.go b/build/build.go new file mode 100644 index 00000000..f66138db --- /dev/null +++ b/build/build.go @@ -0,0 +1,165 @@ +package build + +import ( + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + + "bllamo.com/build/blob" + "bllamo.com/build/internal/blobstore" + "bllamo.com/model" +) + +// Errors +var ( + ErrInvalidRef = errors.New("invalid ref") + ErrUnsupportedModelFormat = errors.New("unsupported model format") + ErrMissingFileType = errors.New("missing 'general.file_type' key") + ErrNoSuchBlob = errors.New("no such blob") + ErrNotFound = errors.New("not found") +) + +type mediaType string + +// Known media types +const ( + mediaTypeModel mediaType = "application/vnd.ollama.image.model" +) + +type Server struct { + st *blobstore.Store +} + +// Open starts a new build server that uses dir as the base directory for all +// build artifacts. If dir is empty, DefaultDir is used. +// +// It returns an error if the provided or default dir cannot be initialized. 
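+//
+// A minimal usage sketch (the ref and file path are illustrative only):
+//
+//	srv, err := build.Open("") // "" means use DefaultDir()
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	err = srv.Build("mistral:7b", model.File{From: "./weights.gguf"})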
+func Open(dir string) (*Server, error) { + if dir == "" { + var err error + dir, err = DefaultDir() + if err != nil { + return nil, err + } + } + st, err := blobstore.Open(dir) + if err != nil { + return nil, err + } + return &Server{st: st}, nil +} + +func (s *Server) Build(ref string, f model.File) error { + br := blob.ParseRef(ref) + if !br.Valid() { + return invalidRef(ref) + } + + // 1. Resolve FROM + // a. If it's a local file (gguf), hash it and add it to the store. + // b. If it's a local dir (safetensor), convert to gguf and add to + // store. + // c. If it's a remote file (http), refuse. + // 2. Turn other pragmas into layers, and add them to the store. + // 3. Create a manifest from the layers. + // 4. Store the manifest in the manifest cache + // 5. Done. + + if f.From == "" { + return &model.Error{Pragma: "FROM", Message: "missing"} + } + + var layers []layerJSON + + id, info, size, err := s.importModel(f.From) + if err != nil { + return err + } + layers = append(layers, layerJSON{ + ID: id, + MediaType: mediaTypeModel, + Size: size, + }) + + id, size, err = blobstore.PutString(s.st, f.License) + if err != nil { + return err + } + layers = append(layers, layerJSON{ + ID: id, + MediaType: "text/plain", + Size: size, + }) + + data, err := json.Marshal(manifestJSON{Layers: layers}) + if err != nil { + return err + } + return s.st.Set(br.WithBuild(info.FileType.String()), data) +} + +func (s *Server) LayerFile(digest string) (string, error) { + fileName := s.st.OutputFilename(blobstore.ParseID(digest)) + _, err := os.Stat(fileName) + if errors.Is(err, fs.ErrNotExist) { + return "", fmt.Errorf("%w: %q", ErrNoSuchBlob, digest) + } + return fileName, nil +} + +func (s *Server) Manifest(ref blob.Ref) ([]byte, error) { + data, _, err := s.getManifestData(ref) + if errors.Is(err, fs.ErrNotExist) { + return nil, fmt.Errorf("%w: %q", ErrNotFound, ref) + } + return data, err +} + +// WeightFile returns the absolute path to the weights file for the given model ref. +func (s *Server) WeightsFile(ref blob.Ref) (string, error) { + m, err := s.getManifest(ref) + if err != nil { + return "", err + } + for _, l := range m.Layers { + if l.MediaType == mediaTypeModel { + return s.st.OutputFilename(l.ID), nil + } + } + return "", fmt.Errorf("missing weights layer for %q", ref) +} + +type manifestJSON struct { + // Layers is the list of layers in the manifest. + Layers []layerJSON `json:"layers"` +} + +// Layer is a layer in a model manifest. +type layerJSON struct { + // ID is the ID of the layer. 
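+	// It is marshaled under the "digest" JSON key, the field name
+	// registry manifests conventionally use for a layer's content hash.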
+ ID blobstore.ID `json:"digest"` + MediaType mediaType `json:"mediaType"` + Size int64 `json:"size"` +} + +func (s *Server) getManifest(ref blob.Ref) (manifestJSON, error) { + data, path, err := s.getManifestData(ref) + if err != nil { + return manifestJSON{}, err + } + var m manifestJSON + if err := json.Unmarshal(data, &m); err != nil { + return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err} + } + return m, nil +} + +func (s *Server) getManifestData(ref blob.Ref) (data []byte, path string, err error) { + return s.st.Resolve(ref) +} + +func invalidRef(ref string) error { + return fmt.Errorf("%w: %q", ErrInvalidRef, ref) +} diff --git a/build/build_test.go b/build/build_test.go new file mode 100644 index 00000000..c146717e --- /dev/null +++ b/build/build_test.go @@ -0,0 +1,150 @@ +package build + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "bllamo.com/build/blob" + "bllamo.com/encoding/gguf" + "bllamo.com/model" +) + +func TestServerBuildErrors(t *testing.T) { + dir := t.TempDir() + + s, err := Open(dir) + if err != nil { + t.Fatal(err) + } + + t.Run("FROM pragma missing", func(t *testing.T) { + err := s.Build("foo", model.File{}) + var e *model.Error + if !errors.As(err, &e) { + t.Fatalf("unexpected error: %v", err) + } + if e.Pragma != "FROM" { + t.Errorf("e.Pragma = %s; want FROM", e.Pragma) + } + if e.Message != "missing" { + t.Errorf("e.Message = %s; want missing", e.Message) + } + }) + + t.Run("FROM file not found", func(t *testing.T) { + err := s.Build("x", model.File{From: "bar"}) + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("Build() err = %v; want file not found", err) + } + }) + + t.Run("FROM gguf", func(t *testing.T) { + w := newWorkDir(t) + // Write a gguf file without general.file_type metadata. 
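+		// Per readFile in encoding/gguf, the header is: 4-byte magic
+		// "GGUF", a uint32 version, then two uint64 counts, tensors
+		// first and then metadata entries. Both counts are zero here,
+		// so their order does not matter for this test.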
+ w.write("gguf", ""+ + "GGUF"+ // magic + "\x03\x00\x00\x00"+ // version + "\x00\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues + "\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors + "", + ) + + err := s.Build("x", model.File{From: w.fileName("gguf")}) + if !errors.Is(err, ErrMissingFileType) { + t.Fatalf("Build() err = %#v; want missing file type", err) + } + }) + + t.Run("FROM obscure dir", func(t *testing.T) { + w := newWorkDir(t) + w.mkdirAll("unknown") + if err := s.Build("x", model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat { + t.Fatalf("Build() err = %#v; want unsupported model type", err) + } + }) + + t.Run("FROM unsupported model type", func(t *testing.T) { + w := newWorkDir(t) + from := w.write("unknown", "unknown content") + err := s.Build("x", model.File{From: from}) + if !errors.Is(err, ErrUnsupportedModelFormat) { + t.Fatalf("Build() err = %#v; want unsupported model type", err) + } + }) +} + +func TestBuildBasicGGUF(t *testing.T) { + w := newWorkDir(t) + w.write("gguf", ""+ + "GGUF"+ // magic + "\x03\x00\x00\x00"+ // version + "\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors + "\x01\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues + + // general.file_type key + "\x11\x00\x00\x00\x00\x00\x00\x00"+ // key length + "general.file_type"+ // key + "\x04\x00\x00\x00"+ // type (uint32) + "\x02\x00\x00\x00\x00\x00\x00\x00"+ // uint32 value + "", + ) + + dir := t.TempDir() + s, err := Open(dir) + if err != nil { + t.Fatal(err) + } + if err := s.Build("x", model.File{From: w.fileName("gguf")}); err != nil { + t.Fatal(err) + } + + filepath.Walk(dir, func(p string, info os.FileInfo, err error) error { + t.Logf("file: %s", p) + return nil + }) + + path, err := s.WeightsFile(blob.ParseRef("x+Q4_0")) + if err != nil { + t.Fatal(err) + } + + info, err := gguf.Stat(path) + if err != nil { + t.Fatal(err) + } + if info.FileType != gguf.TypeQ4_0 { + t.Errorf("info.FileType = %d; want 1", info.FileType) + } +} + +type work struct { + t testing.TB + dir string +} + +func newWorkDir(t *testing.T) work { + return work{t: t, dir: t.TempDir()} +} + +func (w work) write(name, content string) (path string) { + w.t.Helper() + path = w.fileName(name) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + w.t.Fatal(err) + } + return path +} + +func (w work) fileName(name string) string { + w.t.Helper() + return filepath.Join(w.dir, name) +} + +func (w work) mkdirAll(path string) { + w.t.Helper() + if err := os.MkdirAll(filepath.Join(w.dir, path), 0755); err != nil { + w.t.Fatal(err) + } +} diff --git a/build/convert.go b/build/convert.go new file mode 100644 index 00000000..939bb691 --- /dev/null +++ b/build/convert.go @@ -0,0 +1,12 @@ +package build + +func convertSafeTensorToGGUF(path string) (ggufPath string, err error) { + // TODO: decine on hueristic for converting safetensor to gguf and + // the errors that can be returned. For now, we just say + // "unsupported", however it may be intended to be a valid safe + // tensor but we hit an error in the conversion. + // + // I (bmizernay) think this will naturally evolve as we implement + // the conversion. 
+ return "", ErrUnsupportedModelFormat +} diff --git a/build/default.go b/build/default.go new file mode 100644 index 00000000..927301d8 --- /dev/null +++ b/build/default.go @@ -0,0 +1,28 @@ +package build + +import ( + "os" + "path/filepath" + "sync" +) + +var ( + defaultDir = sync.OnceValues(func() (string, error) { + dir := os.Getenv("OLLAMA_MODELS") + if dir == "" { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + dir = filepath.Join(home, ".ollama", "models") + } + return dir, nil + }) +) + +// DefaultDir returns the default directory for models. It returns the value +// of the OLLAMA_MODELS environment variable if set; otherwise it returns +// "$HOME/.ollama/models". +func DefaultDir() (string, error) { + return defaultDir() +} diff --git a/build/import.go b/build/import.go new file mode 100644 index 00000000..412989a6 --- /dev/null +++ b/build/import.go @@ -0,0 +1,59 @@ +package build + +import ( + "errors" + "fmt" + "os" + + "bllamo.com/build/internal/blobstore" + "bllamo.com/encoding/gguf" +) + +func importError(err error) (blobstore.ID, gguf.Info, int64, error) { + return blobstore.ID{}, gguf.Info{}, 0, err +} + +func (s *Server) importModel(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) { + info, err := os.Stat(path) + if err != nil { + return importError(err) + } + if info.IsDir() { + return s.importSafeTensor(path) + } else { + return s.importGGUF(path) + } +} + +func (s *Server) importGGUF(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) { + f, err := os.Open(path) + if err != nil { + return importError(err) + } + defer f.Close() + + info, err := gguf.StatReader(f) + if errors.Is(err, gguf.ErrBadMagic) { + return importError(ErrUnsupportedModelFormat) + } + if err != nil { + return importError(err) + } + + if info.FileType == 0 { + return importError(fmt.Errorf("%w: %q", ErrMissingFileType, path)) + } + id, size, err := s.st.Put(f) + if err != nil { + return importError(err) + } + return id, info, size, nil +} + +func (s *Server) importSafeTensor(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) { + path, err := convertSafeTensorToGGUF(path) + if err != nil { + return importError(err) + } + return s.importGGUF(path) +} diff --git a/build/internal/blobstore/blob.go b/build/internal/blobstore/blob.go new file mode 100644 index 00000000..e5981416 --- /dev/null +++ b/build/internal/blobstore/blob.go @@ -0,0 +1,378 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package blobstore implements a blob store. +package blobstore + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "time" + + "bllamo.com/build/blob" + "bllamo.com/types/structs" +) + +var ( + ErrInvalidID = errors.New("invalid ID") +) + +const HashSize = 32 + +// An ID is a blob output key, the hash of an output of a computation. 
+type ID struct { + a [HashSize]byte +} + +func (id ID) MarshalText() ([]byte, error) { + return []byte(id.String()), nil +} + +func (id *ID) UnmarshalText(text []byte) error { + *id = ParseID(string(text)) + return nil +} + +func ParseID(s string) ID { + const prefix = "sha256-" + h, ok := strings.CutPrefix(s, prefix) + if !ok { + return ID{} + } + + if len(h) != HashSize*2 { + return ID{} + } + + var b []byte + _, err := fmt.Sscanf(h, "%x", &b) + if err != nil { + return ID{} + } + + var id ID + copy(id.a[:], b) + return id +} + +func (id ID) String() string { + if !id.Valid() { + return "" + } + return fmt.Sprintf("sha256-%x", id.a[:]) +} + +func (id ID) Valid() bool { + return id != ID{} +} + +func (id ID) Match(h [HashSize]byte) bool { + return id.a == h +} + +// A Store is a blob store, backed by a file system directory tree. +type Store struct { + dir string + now func() time.Time +} + +// Open opens and returns the store in the given directory. +// +// It is safe for multiple processes on a single machine to use the +// same store directory in a local file system simultaneously. +// They will coordinate using operating system file locks and may +// duplicate effort but will not corrupt the store. +// +// However, it is NOT safe for multiple processes on different machines +// to share a store directory (for example, if the directory were stored +// in a network file system). File locking is notoriously unreliable in +// network file systems and may not suffice to protect the store. +func Open(dir string) (*Store, error) { + info, err := os.Stat(dir) + if err != nil { + return nil, err + } + if !info.IsDir() { + return nil, &fs.PathError{Op: "open", Path: dir, Err: fmt.Errorf("not a directory")} + } + + for _, sub := range []string{"blobs", "manifests"} { + if err := os.MkdirAll(filepath.Join(dir, sub), 0777); err != nil { + return nil, err + } + } + + c := &Store{ + dir: dir, + now: time.Now, + } + return c, nil +} + +// fileName returns the name of the blob file corresponding to the given id. +func (s *Store) fileName(id ID) string { + return filepath.Join(s.dir, "blobs", fmt.Sprintf("sha256-%x", id.a[:])) +} + +// An entryNotFoundError indicates that a store entry was not found, with an +// optional underlying reason. +type entryNotFoundError struct { + Err error +} + +func (e *entryNotFoundError) Error() string { + if e.Err == nil { + return "store entry not found" + } + return fmt.Sprintf("store entry not found: %v", e.Err) +} + +func (e *entryNotFoundError) Unwrap() error { + return e.Err +} + +type Entry struct { + _ structs.Incomparable + + ID ID + Size int64 + Time time.Time // when added to store +} + +// GetFile looks up the blob ID in the store and returns +// the name of the corresponding data file. +func GetFile(s *Store, id ID) (file string, entry Entry, err error) { + entry, err = s.Get(id) + if err != nil { + return "", Entry{}, err + } + file = s.OutputFilename(entry.ID) + info, err := os.Stat(file) + if err != nil { + return "", Entry{}, &entryNotFoundError{Err: err} + } + if info.Size() != entry.Size { + return "", Entry{}, &entryNotFoundError{Err: errors.New("file incomplete")} + } + return file, entry, nil +} + +// GetBytes looks up the blob ID in the store and returns +// the corresponding output bytes. +// GetBytes should only be used for data that can be expected to fit in memory. 
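+// The data read back from disk is re-hashed and verified against id
+// before it is returned.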
+func GetBytes(s *Store, id ID) ([]byte, Entry, error) {
+	entry, err := s.Get(id)
+	if err != nil {
+		return nil, entry, err
+	}
+	data, _ := os.ReadFile(s.OutputFilename(entry.ID))
+	if !entry.ID.Match(sha256.Sum256(data)) {
+		return nil, entry, &entryNotFoundError{Err: errors.New("bad checksum")}
+	}
+	return data, entry, nil
+}
+
+// OutputFilename returns the name of the blob file for the given ID.
+func (s *Store) OutputFilename(id ID) string {
+	file := s.fileName(id)
+	// TODO(bmizerany): touch as "used" for cache trimming. (See
+	// cache.go in cmd/go/internal/cache for the full reference
+	// implementation to go off of.)
+	return file
+}
+
+// Resolve returns the data for the given ref, if any.
+//
+// TODO: This should ideally return an ID, but the current on-disk layout
+// is that the actual manifest is stored in the "ref" instead of a pointer
+// to a content-addressed blob. I (bmizerany) think we should change the
+// on-disk layout to store the manifest in a content-addressed blob, and
+// then have the ref point to that blob. This would simplify the code,
+// allow us to have integrity checks on the manifest, and clean up this
+// interface.
+func (s *Store) Resolve(ref blob.Ref) (data []byte, path string, err error) {
+	path, err = s.refFileName(ref)
+	if err != nil {
+		return nil, "", err
+	}
+	data, err = os.ReadFile(path)
+	if err != nil {
+		return nil, "", &entryNotFoundError{Err: err}
+	}
+	return data, path, nil
+}
+
+// Set sets the data for the given ref.
+func (s *Store) Set(ref blob.Ref, data []byte) error {
+	path, err := s.refFileName(ref)
+	if err != nil {
+		return err
+	}
+	if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
+		return err
+	}
+	return os.WriteFile(path, data, 0666)
+}
+
+func (s *Store) refFileName(ref blob.Ref) (string, error) {
+	if !ref.FullyQualified() {
+		return "", fmt.Errorf("ref not fully qualified: %q", ref)
+	}
+	const cheatTODO = "registry.ollama.ai/library"
+	return filepath.Join(s.dir, "manifests", cheatTODO, ref.Name(), ref.Tag(), ref.Build()), nil
+}
+
+// Get looks up the blob ID in the store,
+// returning the corresponding output ID and file size, if any.
+// Note that finding an output ID does not guarantee that the
+// saved file for that output ID is still available.
+func (s *Store) Get(id ID) (Entry, error) {
+	file := s.fileName(id)
+	info, err := os.Stat(file)
+	if err != nil {
+		return Entry{}, &entryNotFoundError{Err: err}
+	}
+	return Entry{
+		ID:   id,
+		Size: info.Size(),
+		Time: info.ModTime(),
+	}, nil
+}
+
+func (s *Store) Close() error {
+	// TODO(bmizerany): return c.Trim()
+	return nil
+}
+
+// Put stores the data read from the given file into the store as ID.
+//
+// It may read file twice. The content of file must not change between the
+// two passes.
+func (s *Store) Put(file io.ReadSeeker) (ID, int64, error) {
+	return s.put(file)
+}
+
+func PutBytes(s *Store, data []byte) (ID, int64, error) {
+	return s.Put(bytes.NewReader(data))
+}
+
+func PutString(s *Store, data string) (ID, int64, error) {
+	return s.Put(strings.NewReader(data))
+}
+
+func (s *Store) put(file io.ReadSeeker) (ID, int64, error) {
+	// Compute output ID.
+	h := sha256.New()
+	if _, err := file.Seek(0, 0); err != nil {
+		return ID{}, 0, err
+	}
+	size, err := io.Copy(h, file)
+	if err != nil {
+		return ID{}, 0, err
+	}
+	var out ID
+	h.Sum(out.a[:0])
+
+	// Copy to blob file (if not already present).
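+	// copyFile re-hashes the data as it writes, so content that changed
+	// between the two reads fails with "file content changed underfoot"
+	// instead of being stored under a stale ID.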
+ if err := s.copyFile(file, out, size); err != nil { + return out, size, err + } + + // TODO: Add to manifest index. + return out, size, nil +} + +// copyFile copies file into the store, expecting it to have the given +// output ID and size, if that file is not present already. +func (s *Store) copyFile(file io.ReadSeeker, out ID, size int64) error { + name := s.fileName(out) + println("name", name) + info, err := os.Stat(name) + if err == nil && info.Size() == size { + // Check hash. + if f, err := os.Open(name); err == nil { + h := sha256.New() + io.Copy(h, f) + f.Close() + var out2 ID + h.Sum(out2.a[:0]) + if out == out2 { + return nil + } + } + // Hash did not match. Fall through and rewrite file. + } + + // Copy file to blobs directory. + mode := os.O_RDWR | os.O_CREATE + if err == nil && info.Size() > size { // shouldn't happen but fix in case + mode |= os.O_TRUNC + } + f, err := os.OpenFile(name, mode, 0666) + if err != nil { + return err + } + defer f.Close() + if size == 0 { + // File now exists with correct size. + // Only one possible zero-length file, so contents are OK too. + // Early return here makes sure there's a "last byte" for code below. + return nil + } + + // From here on, if any of the I/O writing the file fails, + // we make a best-effort attempt to truncate the file f + // before returning, to avoid leaving bad bytes in the file. + + // Copy file to f, but also into h to double-check hash. + if _, err := file.Seek(0, 0); err != nil { + f.Truncate(0) + return err + } + h := sha256.New() + w := io.MultiWriter(f, h) + if _, err := io.CopyN(w, file, size-1); err != nil { + f.Truncate(0) + return err + } + // Check last byte before writing it; writing it will make the size match + // what other processes expect to find and might cause them to start + // using the file. + buf := make([]byte, 1) + if _, err := file.Read(buf); err != nil { + f.Truncate(0) + return err + } + h.Write(buf) + sum := h.Sum(nil) + if !bytes.Equal(sum, out.a[:]) { + f.Truncate(0) + return fmt.Errorf("file content changed underfoot") + } + + // Commit manifest entry. + if _, err := f.Write(buf); err != nil { + f.Truncate(0) + return err + } + if err := f.Close(); err != nil { + // Data might not have been written, + // but file may look like it is the right size. + // To be extra careful, remove stored file. + os.Remove(name) + return err + } + os.Chtimes(name, s.now(), s.now()) // mainly for tests + + return nil +} diff --git a/build/internal/blobstore/blob_test.go b/build/internal/blobstore/blob_test.go new file mode 100644 index 00000000..f54684a2 --- /dev/null +++ b/build/internal/blobstore/blob_test.go @@ -0,0 +1,54 @@ +package blobstore + +import ( + "strings" + "testing" +) + +func TestParseID(t *testing.T) { + const valid = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + var invalid = strings.Repeat("\x00", HashSize*2) + + cases := []struct { + in string + want string + }{ + {"", invalid}, + {"sha256-", invalid}, + {"sha256-" + valid, valid}, + + {"" + valid, invalid}, // no prefix + {"sha123-" + valid, invalid}, // invalid prefix + {"sha256-" + valid[1:], invalid}, // too short + {"sha256-" + valid + "a", invalid}, // too long + {"sha256-!" 
+ valid[1:], invalid}, // invalid hex + } + + for _, tt := range cases { + t.Run("", func(t *testing.T) { + // sanity check + if len(tt.want) > HashSize*2 { + panic("invalid test") + } + + got := ParseID(tt.in) + + wantValid := tt.want != invalid + if wantValid { + if !got.Valid() { + t.Errorf("ParseID(%q).Valid() = false; want true", tt.in) + } + if got.String() != "sha256-"+tt.want { + t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "sha256-"+tt.want) + } + } else { + if got.Valid() { + t.Errorf("ParseID(%q).Valid() = true; want false", tt.in) + } + if got.String() != "" { + t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "") + } + } + }) + } +} diff --git a/build/internal/blobstore/store_test.go b/build/internal/blobstore/store_test.go new file mode 100644 index 00000000..ddcc05aa --- /dev/null +++ b/build/internal/blobstore/store_test.go @@ -0,0 +1,161 @@ +package blobstore + +import ( + "errors" + "iter" + "os" + "path/filepath" + "testing" + "time" + + "bllamo.com/build/blob" + "kr.dev/diff" +) + +const ( + blobNameHello = "sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824" +) + +func TestStoreBasicBlob(t *testing.T) { + dir := t.TempDir() + + checkDir(t, dir, nil) + + st, err := Open(dir) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + st.now = func() time.Time { return now } + + checkDir(t, dir, []string{ + "blobs/", + "manifests/", + }) + + id, size, err := PutBytes(st, []byte("hello")) + if err != nil { + t.Fatal(err) + } + + if id != ParseID(blobNameHello) { + t.Errorf("unexpected ID: %s", id) + } + if size != 5 { + t.Errorf("unexpected size: %d", size) + } + + checkDir(t, dir, []string{ + "blobs/", + "blobs/" + blobNameHello, + "manifests/", + }) + + got, err := st.Get(id) + if err != nil { + t.Fatal(err) + } + + diff.Test(t, t.Errorf, got, Entry{ + ID: id, + Size: 5, + Time: now, + }) + + file := st.OutputFilename(id) + wantFile := filepath.Join(dir, "blobs", blobNameHello) + if file != wantFile { + t.Errorf("unexpected file: %s", file) + } + + // Check tags + ref := blob.ParseRef("test+KQED") + + t.Logf("resolving %s", ref) + + data, _, err := st.Resolve(ref) + var e *entryNotFoundError + if !errors.As(err, &e) { + t.Fatal(err) + } + if data != nil { + t.Errorf("unexpected data: %q", data) + } + + if err := st.Set(ref, []byte("{}")); err != nil { + t.Fatal(err) + } + + data, _, err = st.Resolve(ref) + if err != nil { + t.Fatal(err) + } + + if g := string(data); g != "{}" { + t.Errorf("g = %q; want %q", g, "{}") + } + + checkDir(t, dir, []string{ + "blobs/", + "blobs/sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", + "manifests/", + "manifests/registry.ollama.ai/", + "manifests/registry.ollama.ai/library/", + "manifests/registry.ollama.ai/library/test/", + "manifests/registry.ollama.ai/library/test/latest/", + "manifests/registry.ollama.ai/library/test/latest/KQED", + }) +} + +// checkDir checks that the directory at dir contains the files in want. The +// files in want must be relative to dir. +// +// direcotories are suffixed with a slash (e.g. "foo/" instead of "foo"). +// +// want must be in lexicographic order. 
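+//
+// For example, the initial state checked in TestStoreBasicBlob is:
+//
+//	checkDir(t, dir, []string{"blobs/", "manifests/"})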
+func checkDir(t testing.TB, dir string, want []string) { + t.Helper() + + var matches []string + for path, err := range walkDir(dir) { + if err != nil { + t.Fatal(err) + } + t.Logf("found %s", path) + if path == "./" { + continue + } + path = filepath.ToSlash(path) + matches = append(matches, path) + } + + diff.Test(t, t.Errorf, matches, want) +} + +var errStop = errors.New("stop") + +func walkDir(dir string) iter.Seq2[string, error] { + return func(yield func(string, error) bool) { + err := filepath.WalkDir(dir, func(path string, info os.DirEntry, err error) error { + if err != nil { + return err + } + path, err = filepath.Rel(dir, path) + if err != nil { + return err + } + path = filepath.ToSlash(path) + if info.IsDir() { + path += "/" + } + if !yield(path, nil) { + return errStop + } + return nil + }) + if !errors.Is(err, errStop) && err != nil { + yield("", err) + } + } +} diff --git a/client/ollama/apitype/apitype.go b/client/ollama/apitype/apitype.go new file mode 100644 index 00000000..e2def7a2 --- /dev/null +++ b/client/ollama/apitype/apitype.go @@ -0,0 +1,31 @@ +package apitype + +import "time" + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type Model struct { + Ref string `json:"ref"` + Digest string `json:"digest"` + Size int64 `json:"size"` + ModifiedAt int64 `json:"modified"` +} + +func (m Model) Modifed() time.Time { + return time.Unix(0, m.ModifiedAt) +} + +type PushRequest struct { + Name string `json:"name"` // Ref is the official term, "name" is for backward compatibility with exiting clients. + Insecure bool `json:"insecure"` + Stream bool `json:"stream"` +} + +type PushStatus struct { + Status string `json:"status"` + Digest string `json:"digest"` + Total int64 `json:"total"` +} diff --git a/client/ollama/ollama.go b/client/ollama/ollama.go new file mode 100644 index 00000000..2c89c4cd --- /dev/null +++ b/client/ollama/ollama.go @@ -0,0 +1,70 @@ +package ollama + +import ( + "cmp" + "context" + "io/fs" + "iter" + "os" + + "bllamo.com/client/ollama/apitype" + "bllamo.com/oweb" + "bllamo.com/types/empty" +) + +// TODO(bmizerany): PROGRESS INDICATORS!!!! + +const DefaultBaseURL = "http://localhost:11434" + +var envBaseURL = cmp.Or(os.Getenv("OLLAMA_BASE_URL"), DefaultBaseURL) + +// Default returns a new client with the default base URL. +func Default() *Client { + return &Client{BaseURL: envBaseURL} +} + +// I_Acknowledge_This_API_Is_Under_Development is a flag that must be set to +// true for any instance of Client to work. +var I_Acknowledge_This_API_Is_Under_Development bool + +// Client is a client for the Ollama API. +type Client struct { + // BaseURL is the base URL of the Ollama API. + BaseURL string +} + +// Build requests the remote Ollama service to build a model. It uploads any +// source files the server needs. +func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source fs.FS) error { + panic("not implemented") +} + +// Push requests the remote Ollama service to push a model to the server. 
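+//
+// The ref must name a model that has already been built locally; the
+// server resolves it to a manifest and uploads any layers the registry
+// reports missing (see handlePush in the api package).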
+func (c *Client) Push(ctx context.Context, ref string) error { + _, err := oweb.Do[empty.Message](ctx, "POST", c.BaseURL+"/v1/push", apitype.PushRequest{Name: ref}) + return err +} + +func (c *Client) Pull(ctx context.Context, ref string) error { + panic("not implemented") +} + +func (c *Client) List(ctx context.Context) iter.Seq2[apitype.Model, error] { + panic("not implemented") +} + +func (c *Client) Show(ctx context.Context, ref string) (*apitype.Model, error) { + panic("not implemented") +} + +func (c *Client) Remove(ctx context.Context, ref string) error { + panic("not implemented") +} + +func (c *Client) Copy(ctx context.Context, dstRef, srcRef string) error { + panic("not implemented") +} + +func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message) error { + panic("not implemented") +} diff --git a/cmd/bllamo/bllamo.go b/cmd/bllamo/bllamo.go new file mode 100644 index 00000000..32dadfb7 --- /dev/null +++ b/cmd/bllamo/bllamo.go @@ -0,0 +1,100 @@ +// Bllamo is a (new) tool for managing Ollama models. +// +// Usage: +// +// bllamo [arguments] +// +// The commands are: +// +// build build a model from a Modelfile +// list list all models +// push push a model from an ollama registry +// pull pull a model from an ollama registry +// delete delete a model from an ollama registry +// help display help for a command +package main + +import ( + "cmp" + "context" + "flag" + "fmt" + "net/http" + "os" + + "bllamo.com/api" + "bllamo.com/build" + "bllamo.com/client/ollama" + "bllamo.com/registry" +) + +func main() { + flag.Parse() + args := flag.Args() + if len(args) < 1 { + fmt.Fprintln(os.Stderr, "bllamo: no command provided") + os.Exit(2) + } + if err := Main(flag.Args()...); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } +} + +var TODOUsage = fmt.Errorf("TODO: usage") + +var commands = map[string]func(ctx context.Context, args ...string) error{ + "build": cmdBuild, + "push": cmdPush, + "serve": cmdServe, + "registry": cmdRegistry, +} + +// Main is the entry point for the blammo command. +func Main(args ...string) error { + cmd := args[0] + args = args[1:] + if f, ok := commands[cmd]; ok { + ctx := context.TODO() + return f(ctx, args...) + } + return fmt.Errorf("blammo: unknown command %q", cmd) +} + +func cmdBuild(ctx context.Context, args ...string) error { + var v struct { + Modelfile string `flag:"f,the Modelfile to use"` + } + + fs := readFlags("build", args, &v) + if fs.NArg() != 1 { + return TODOUsage + } + + modelfile, err := os.ReadFile(cmp.Or(v.Modelfile, "Modelfile")) + if err != nil { + return err + } + return ollama.Default().Build(ctx, args[0], modelfile, os.DirFS(".")) +} + +func cmdRegistry(_ context.Context, _ ...string) error { + var s registry.Server + return http.ListenAndServe(":8888", &s) +} + +func cmdServe(ctx context.Context, args ...string) error { + bs, err := build.Open("") + if err != nil { + return err + } + return http.ListenAndServe(":11434", &api.Server{Build: bs}) +} + +func cmdPush(ctx context.Context, args ...string) error { + fs := readFlags("push", args, nil) + if fs.NArg() != 1 { + return TODOUsage + } + return ollama.Default().Push(ctx, fs.Arg(0)) +} diff --git a/cmd/bllamo/flags.go b/cmd/bllamo/flags.go new file mode 100644 index 00000000..a781c7ce --- /dev/null +++ b/cmd/bllamo/flags.go @@ -0,0 +1,59 @@ +package main + +import ( + "flag" + "fmt" + "reflect" + "strings" +) + +// parseArgs parses the provided args using a flag.FlagSet that is +// dynamically build using reflection for the provided type. 
The type fields +// that have a "flag" tag are used to build the flags. The flag tag should +// include either a ('-'). Example usage: +// +// func main() { +// var flags struct { +// Modelfile string `flag:"f,path to the Modelfile"` +// } +// +// fs := readFlags(os.Args[1:], &flags) +// fs.Parse(os.Args[1:]) +// } +func readFlags(name string, args []string, v any) *flag.FlagSet { + fs := flag.NewFlagSet(name, flag.ExitOnError) + defer fs.Parse(args) + if v == nil { + return fs + } + + for i := 0; i < reflect.ValueOf(v).NumField(); i++ { + f := reflect.ValueOf(v).Field(i) + if !f.CanSet() { + continue + } + + tag := f.Type().Field(i).Tag.Get("flag") + if tag == "" { + continue + } + var name, usage string + if i := strings.Index(tag, ","); i != -1 { + name = tag[:i] + usage = tag[i+1:] + } else { + name = tag + } + + // TODO(bmizerany): add more types as needed + switch f.Kind() { + case reflect.String: + fs.StringVar(f.Addr().Interface().(*string), name, "", usage) + case reflect.Bool: + fs.BoolVar(f.Addr().Interface().(*bool), name, false, usage) + default: + panic(fmt.Sprintf("unsupported type %v", f.Kind())) + } + } + return fs +} diff --git a/cmd/gguf/gguf.go b/cmd/gguf/gguf.go new file mode 100644 index 00000000..0ea8bf55 --- /dev/null +++ b/cmd/gguf/gguf.go @@ -0,0 +1,97 @@ +// Gguf is a tool for learning about GGUF files. +// +// Usage: +// +// gguf [flags] +package main + +import ( + "flag" + "fmt" + "io" + "log" + "os" + "text/tabwriter" + + "bllamo.com/encoding/gguf" +) + +func main() { + if err := Main(os.Stdout, os.Args[1:]...); err != nil { + log.Fatal(err) + } +} + +func Main(stdout io.Writer, args ...string) error { + fs := flag.NewFlagSet("gguf", flag.ExitOnError) + flagGPU := fs.Uint64("gpu", 0, "use N bytes of GPU memory (default is 0)") + + fs.Usage = func() { + io.WriteString(stdout, "Gguf is a tool for learning about GGUF files.\n") + io.WriteString(stdout, "\n") + io.WriteString(stdout, "Usage:\n") + io.WriteString(stdout, "\n") + io.WriteString(stdout, "\tgguf [flags] \n") + io.WriteString(stdout, "\n") + var numFlags int + fs.VisitAll(func(*flag.Flag) { numFlags++ }) + if numFlags > 0 { + io.WriteString(stdout, "Flags:\n") + fs.PrintDefaults() + } + } + fs.Parse(args) + + if fs.NArg() != 1 { + fs.Usage() + os.Exit(2) + } + + file := fs.Arg(0) + f, err := os.Open(file) + if err != nil { + log.Fatal(err) + } + defer f.Close() + + g, err := gguf.ReadFile(f) + if err != nil { + log.Fatal(err) + } + + tw := tabwriter.NewWriter(stdout, 0, 2, 2, ' ', 0) + defer tw.Flush() + + fmt.Fprintf(tw, "version:\t%d\n", g.Version()) + + for m, err := range g.Metadata { + if err != nil { + log.Fatal(err) + } + if len(m.Values) > 5 { + fmt.Fprintf(tw, "meta:\t%q: ... 
(%d values)\n", m.Key, len(m.Values)) + } else { + fmt.Fprintf(tw, "meta:\t%q: %v\n", m.Key, m.Values) + } + } + + var i int + var totalLayerBytes uint64 + var offGPU bool + for t, err := range g.Tensors { + if err != nil { + log.Fatal(err) + } + + totalLayerBytes += t.Size + if totalLayerBytes > *flagGPU { + offGPU = true + } + + const msg = "tensor (layer %000d):\t%q\t%s\tdims=%v\toffset=%d\tsize=%d\tonGPU=%v\n" + fmt.Fprintf(tw, msg, i, t.Name, t.Type, t.Dimensions, t.Offset, t.Size, !offGPU) + + i++ + } + return nil +} diff --git a/encoding/gguf/gguf.go b/encoding/gguf/gguf.go new file mode 100644 index 00000000..79e7c98a --- /dev/null +++ b/encoding/gguf/gguf.go @@ -0,0 +1,376 @@ +package gguf + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "os" + "strconv" + "strings" + + "bllamo.com/types/structs" +) + +// TODO(bmizerany): determine a more reasonable value for MaxDimensions. + +// MaxDimensions is the maximum number of dimensions a tensor can have. +const MaxDimensions uint32 = 1e6 + +// Errors +var ( + // ErrBadMagic is returned when the magic bytes at the start of the + // file. This is useful for detecting if the file is not a gguf + // file. + ErrBadMagic = errors.New("gguf: bad magic") + + ErrUnsupportedVersion = errors.New("gguf: unsupported version") + ErrMangled = errors.New("gguf: mangled data") +) + +type Type uint32 + +const ( + TypeF32 Type = 0 + TypeF16 Type = 1 + TypeQ4_0 Type = 2 + TypeQ4_1 Type = 3 + TypeQ5_0 Type = 6 + TypeQ5_1 Type = 7 + TypeQ8_0 Type = 8 + TypeQ8_1 Type = 9 + TypeQ2_K Type = 10 + TypeQ3_K Type = 11 + TypeQ4_K Type = 12 + TypeQ5_K Type = 13 + TypeQ6_K Type = 14 + TypeQ8_K Type = 15 + TypeI8 Type = 16 + TypeI16 Type = 17 + TypeI32 Type = 18 + TypeCount Type = 19 +) + +var typeNames = map[Type]string{ + TypeF32: "F32", + TypeF16: "F16", + TypeQ4_0: "Q4_0", + TypeQ4_1: "Q4_1", + TypeQ5_0: "Q5_0", + TypeQ5_1: "Q5_1", + TypeQ8_0: "Q8_0", + TypeQ8_1: "Q8_1", + TypeQ2_K: "Q2_K", + TypeQ3_K: "Q3_K", + TypeQ4_K: "Q4_K", + TypeQ5_K: "Q5_K", + TypeQ6_K: "Q6_K", + TypeQ8_K: "Q8_K", + TypeI8: "I8", + TypeI16: "I16", + TypeI32: "I32", + TypeCount: "COUNT", +} + +func (t Type) String() string { + if name := typeNames[t]; name != "" { + return name + } + return fmt.Sprintf("(!unknown_type %d!)", t) +} + +// ValueType is the type of a metadata value. 
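+//
+// The numeric values mirror the GGUF v3 metadata value-type enumeration
+// (0 = uint8 through 12 = float64); see the constants below.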
+type ValueType uint32 + +func (t ValueType) String() string { + if name := metaTypeNames[t]; name != "" { + return name + } + return fmt.Sprintf("(!unknown_value_type %d!)", t) +} + +const ( + ValueTypeUint8 ValueType = 0 + ValueTypeInt8 ValueType = 1 + ValueTypeUint16 ValueType = 2 + ValueTypeInt16 ValueType = 3 + ValueTypeUint32 ValueType = 4 + ValueTypeInt32 ValueType = 5 + ValueTypeFloat32 ValueType = 6 + ValueTypeBool ValueType = 7 + ValueTypeString ValueType = 8 + ValueTypeArray ValueType = 9 + ValueTypeUint64 ValueType = 10 + ValueTypeInt64 ValueType = 11 + ValueTypeFloat64 ValueType = 12 +) + +var metaTypeNames = map[ValueType]string{ + ValueTypeUint8: "uint8", + ValueTypeInt8: "int8", + ValueTypeUint16: "uint16", + ValueTypeInt16: "int16", + ValueTypeUint32: "uint32", + ValueTypeInt32: "int32", + ValueTypeFloat32: "float32", + ValueTypeBool: "bool", + ValueTypeString: "string", + ValueTypeArray: "array", + ValueTypeUint64: "uint64", + ValueTypeInt64: "int64", + ValueTypeFloat64: "float64", +} + +type TensorInfo struct { + Name string + Dimensions []uint64 + Type Type + Offset uint64 + Size uint64 +} + +type MetaValue struct { + Type ValueType + Value []byte +} + +func (v MetaValue) String() string { + var b strings.Builder + b.WriteString(v.Type.String()) + b.WriteString("(") + switch v.Type { + case ValueTypeArray: + b.WriteString("[...]") + case ValueTypeString: + b.WriteString(strconv.Quote(string(v.Value))) + case ValueTypeBool: + if len(v.Value) == 0 { + b.WriteString("(!invalid bool)") + } + switch v.Value[0] { + case 0: + b.WriteString("false") + case 1: + b.WriteString("true") + default: + b.WriteString("!invalid bool") + } + case ValueTypeUint8, ValueTypeInt8, ValueTypeUint16, ValueTypeInt16, ValueTypeUint32, ValueTypeInt32, ValueTypeUint64, ValueTypeInt64, ValueTypeFloat32, ValueTypeFloat64: + var buf [8]byte + if len(v.Value) < 8 { + copy(buf[:], v.Value) + } + fmt.Fprintf(&b, "%v", binary.LittleEndian.Uint64(buf[:])) + default: + fmt.Fprintf(&b, "%v", v.Value) + } + b.WriteString(")") + return b.String() +} + +type MetaEntry struct { + Key string + Type ValueType + Values []MetaValue +} + +func (e MetaEntry) String() string { + if len(e.Values) == 0 { + return "" + } + return string(e.Values[0].Value) +} + +func (e MetaEntry) Uint32() uint32 { + if len(e.Values) == 0 { + return 0 + } + return binary.LittleEndian.Uint32(e.Values[0].Value) +} + +func (e MetaEntry) FileType() Type { + if len(e.Values) == 0 { + return TypeCount + } + return Type(e.Uint32()) +} + +func (e MetaEntry) GoString() string { + var b strings.Builder + b.WriteString(e.Key) + b.WriteString(": ") + b.WriteString(e.Type.String()) + b.WriteString("(") + for i, v := range e.Values { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(v.String()) + } + b.WriteString(")") + return b.String() +} + +type Info struct { + _ structs.Incomparable // prevent comparison of Info values so we can change the implementation later + + Version int + FileType Type +} + +func Stat(path string) (Info, error) { + f, err := os.Open(path) + if err != nil { + return Info{}, err + } + defer f.Close() + return StatReader(f) +} + +// StatReader reads the header information from r and returns an Info +// struct with the version and file type. +// +// It returns an error if any. +// +// As a special case, it returns ErrBadMagic if the file does not start with +// the magic bytes. This can be used to detect if the file is not a GGUF +// file. 
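+//
+// A typical probe, where f is any io.ReadSeeker positioned anywhere (the
+// reader is rewound first):
+//
+//	info, err := gguf.StatReader(f)
+//	if errors.Is(err, gguf.ErrBadMagic) {
+//		// not a GGUF file
+//	}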
+func StatReader(r io.ReadSeeker) (Info, error) { + if _, err := r.Seek(0, 0); err != nil { + return Info{}, err + } + f, err := ReadFile(r) + if err != nil { + return Info{}, err + } + info := Info{Version: f.Version()} + for m, err := range f.Metadata { + if err != nil { + return Info{}, err + } + if m.Key == "general.file_type" { + if m.Type != ValueTypeUint32 { + return Info{}, fmt.Errorf("unexpected type for metadata key %q: %v, want %v", m.Key, m.Type, ValueTypeUint32) + } + info.FileType = m.FileType() + } + } + return info, nil +} + +type File struct { + version uint32 + numMetaValues uint64 + numTensors uint64 + + gr *ggufReader +} + +// ReadFile reads header information from r and returns a File, ready for +// iteration over Metadata and Tensors. +func ReadFile(r io.Reader) (*File, error) { + f, err := readFile(r) + if err != nil { + return nil, err + } + return f, nil +} + +func (f *File) Version() int { + return int(f.version) +} + +// Metadata iterates over the metadata in the file. It must be exhausted +// before calling Tensors. +// +// It is not resumable. +func (f *File) Metadata(yield func(MetaEntry, error) bool) { + var n int + for range f.numMetaValues { + meta, err := f.gr.readMetaEntry() + if err != nil { + err = fmt.Errorf("error reading metadata entry %d: %w", n, err) + yield(MetaEntry{}, err) + return + } + if !yield(meta, nil) { + return + } + n++ + } +} + +// Tensors iterates over the tensors in the file. It must only be called +// after exhausting the metadata iterator. +// +// It is not resumable. +func (f *File) Tensors(yield func(TensorInfo, error) bool) { + var last TensorInfo + for range f.numTensors { + info, err := f.gr.readTensorInfo() + + // If the last tensor had a valid offset, yield it. + // + // NOTE: No tensor should have an offset of 0 because the + // offset is the start of the tensor data which is always + // afer the magic bytes, version, numMetaValues, and + // numTensors, which MUST all be non-zero bytes as per the + // GGUF spec. + if last.Offset > 0 { + if !yield(last, err) { + return + } + } + if err != nil { + yield(TensorInfo{}, err) + return + } + + // Tensor data does not include size, so we need to + // calculate it based on the offset of the previous tensor + // offset to the current. 
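+		//
+		// Concretely: Size(t[i]) = Offset(t[i]) - Offset(t[i-1]), with
+		// Offset(t[-1]) taken to be 0, as exercised by the "size
+		// computed" case in gguf_test.go.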
+ offset0 := last.Offset + last = info + last.Size = info.Offset - offset0 + } + if last.Offset > 0 { + yield(last, nil) + } +} + +var magicBytes = []byte{0x47, 0x47, 0x55, 0x46} + +func readFile(r io.Reader) (*File, error) { + gr := &ggufReader{r: &reader{r: r}} + magic, err := gr.next(4) + if err != nil { + return nil, errors.Join(err, ErrBadMagic) + } + if !bytes.Equal(magic, magicBytes) { + return nil, ErrBadMagic + } + version, err := gr.readUint32() + if err != nil { + return nil, err + } + if version != 3 { + return nil, fmt.Errorf("%w: %d", ErrUnsupportedVersion, version) + } + numTensors, err := gr.readUint64() + if err != nil { + return nil, err + } + numMetaValues, err := gr.readUint64() + if err != nil { + return nil, err + } + info := &File{ + version: version, + + numMetaValues: numMetaValues, + numTensors: numTensors, + gr: gr, + } + return info, nil +} diff --git a/encoding/gguf/gguf_test.go b/encoding/gguf/gguf_test.go new file mode 100644 index 00000000..2f6aa4f0 --- /dev/null +++ b/encoding/gguf/gguf_test.go @@ -0,0 +1,345 @@ +package gguf + +import ( + "errors" + "io" + "strings" + "testing" + + "kr.dev/diff" +) + +func TestStat(t *testing.T) { + cases := []struct { + name string + data string + wantInfo Info + wantErr error + }{ + { + name: "empty", + wantErr: ErrBadMagic, + }, + { + name: "bad magic", + data: "\xBB\xAA\xDD\x00", + wantErr: ErrBadMagic, + }, + { + name: "bad version", + data: string(magicBytes) + + "\x02\x00\x00\x00" + // version + "", + wantErr: ErrUnsupportedVersion, + }, + { + name: "valid general.file_type", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + + // general.file_type key + "\x11\x00\x00\x00\x00\x00\x00\x00" + // key length + "general.file_type" + // key + "\x04\x00\x00\x00" + // type (uint32) + "\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value + "", + wantInfo: Info{ + Version: 3, + FileType: 1, + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + info, err := StatReader(strings.NewReader(tt.data)) + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Fatalf("err = %v; want %q", err, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + diff.Test(t, t.Errorf, info, tt.wantInfo) + }) + } +} + +func TestReadInfo(t *testing.T) { + cases := []struct { + name string + data string + + wantMeta []MetaEntry + wantTensor []TensorInfo + wantReadErr error + wantMetaErr error + wantTensorErr error + wantInfo Info + }{ + { + name: "empty", + wantReadErr: io.ErrUnexpectedEOF, + }, + { + name: "bad magic", + data: "\xBB\xAA\xDD\x00", + wantReadErr: ErrBadMagic, + }, + { + name: "bad version", + data: string(magicBytes) + + "\x02\x00\x00\x00" + // version + "", + wantReadErr: ErrUnsupportedVersion, + }, + { + name: "no metadata or tensors", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "", + wantReadErr: nil, + }, + { + name: "good metadata", + data: string(magicBytes) + // magic + "\x03\x00\x00\x00" + // version + "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors + "\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues + "\x01\x00\x00\x00\x00\x00\x00\x00" + // key length + "K" + // key + "\x08\x00\x00\x00" + // type (string) + "\x02\x00\x00\x00\x00\x00\x00\x00" + // string length + "VV" + // string 
value
+ "",
+ wantMeta: []MetaEntry{
+ {Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}},
+ },
+ },
+ {
+ name: "good metadata with multiple values",
+ data: string(magicBytes) + // magic
+ "\x03\x00\x00\x00" + // version
+ "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
+ "\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
+
+ // MetaEntry 1
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
+ "x" + // key
+ "\x08\x00\x00\x00" + // type (string)
+ "\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
+ "XX" + // string value
+
+ // MetaEntry 2
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
+ "y" + // key
+ "\x04\x00\x00\x00" + // type (uint32)
+ "\x99\x88\x77\x66" + // uint32 value
+ "",
+ wantMeta: []MetaEntry{
+ {Key: "x", Type: ValueTypeString, Values: []MetaValue{{
+ Type: ValueTypeString,
+ Value: []byte("XX"),
+ }}},
+ {Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{
+ Type: ValueTypeUint32,
+ Value: []byte{0x99, 0x88, 0x77, 0x66},
+ }}},
+ },
+ },
+ {
+ name: "negative string length in meta key",
+ data: string(magicBytes) + // magic
+ "\x03\x00\x00\x00" + // version
+ "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
+ "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length
+ "K" + // key
+ "\x08\x00\x00\x00" + // type (string)
+ "\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
+ "VV" + // string value
+ "",
+ wantMetaErr: ErrMangled,
+ },
+
+ // Tensor tests
+ {
+ name: "good tensor",
+ data: string(magicBytes) + // magic
+ "\x03\x00\x00\x00" + // version
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
+ "\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
+
+ // Tensor 1
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
+ "t" +
+
+ // dimensions
+ "\x01\x00\x00\x00" + // dimensions length
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
+
+ "\x03\x00\x00\x00" + // type (q4_1)
+ "\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
+ "",
+ wantTensor: []TensorInfo{
+ {
+ Name: "t",
+ Dimensions: []uint64{1},
+ Type: TypeQ4_1,
+ Offset: 256,
+ Size: 256,
+ },
+ },
+ },
+ {
+ name: "too many dimensions",
+ data: string(magicBytes) + // magic
+ "\x03\x00\x00\x00" + // version
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
+ "\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
+
+ // Tensor 1
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
+ "t" +
+
+ "\x00\x00\x00\x01" + // dimensions length
+ "",
+ wantTensorErr: ErrMangled,
+ },
+ {
+ name: "size computed",
+ data: string(magicBytes) + // magic
+ "\x03\x00\x00\x00" + // version
+ "\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors
+ "\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
+
+ // Tensor 1
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
+ "t" +
+ "\x01\x00\x00\x00" + // dimensions length
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
+ "\x03\x00\x00\x00" + // type (q4_1)
+ "\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
+
+ // Tensor 2
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
+ "t" +
+ "\x01\x00\x00\x00" + // dimensions length
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
+ "\x03\x00\x00\x00" + // type (q4_1)
+ "\x00\x03\x00\x00\x00\x00\x00\x00" + // offset
+ "",
+ wantTensor: []TensorInfo{
+ {
+ Name: "t",
+ Dimensions: []uint64{1},
+ Type: TypeQ4_1,
+ Offset: 256,
+ Size: 256,
+ },
+ {
+ Name: "t",
+ Dimensions: []uint64{1},
+ Type: TypeQ4_1,
+ Offset: 768,
+ Size: 512,
+ },
+ },
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ f, err := ReadFile(strings.NewReader(tt.data))
+ if err != nil {
+ if !errors.Is(err, tt.wantReadErr) {
+ t.Fatalf("unexpected ReadFile error: %v", err)
+ }
+ return
+ }
+
+ var got []MetaEntry
+ for meta, err := range f.Metadata {
+ if !errors.Is(err, tt.wantMetaErr) {
+ t.Fatalf("err = %v; want %v", err, tt.wantMetaErr)
+ }
+ if err != nil {
+ return
+ }
+ got = append(got, meta)
+ }
+ diff.Test(t, t.Errorf, got, tt.wantMeta)
+
+ var gotT []TensorInfo
+ for tinfo, err := range f.Tensors {
+ if !errors.Is(err, tt.wantTensorErr) {
+ t.Fatalf("err = %v; want %v", err, tt.wantTensorErr)
+ }
+ if err != nil {
+ return
+ }
+ gotT = append(gotT, tinfo)
+ }
+ diff.Test(t, t.Errorf, gotT, tt.wantTensor)
+ })
+ }
+}
+
+func FuzzReadInfo(f *testing.F) {
+ f.Add(string(magicBytes))
+ f.Add(string(magicBytes) +
+ "\x03\x00\x00\x00" + // version
+ "\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
+ "\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
+ "")
+ f.Add(string(magicBytes) +
+ "\x03\x00\x00\x00" + // version
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
+ "K" + // key
+ "\x08\x00\x00\x00" + // type (string)
+ "\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
+ "VV" + // string value
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
+ "t" +
+ "\x01\x00\x00\x00" + // dimensions length
+ "\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
+ "\x03\x00\x00\x00" + // type (q4_1)
+ "\x05\x00\x00\x00\x00\x00\x00\x00" + // offset
+ "")
+
+ f.Fuzz(func(t *testing.T, data string) {
+ gf, err := ReadFile(strings.NewReader(data))
+ if err != nil {
+ t.Logf("ReadFile error: %v", err)
+ t.Skip()
+ }
+ for _, err := range gf.Metadata {
+ if err != nil {
+ t.Logf("metadata error: %v", err)
+ t.Skip()
+ }
+ }
+ for tinfo, err := range gf.Tensors {
+ if err != nil {
+ t.Logf("tensor error: %v", err)
+ t.Skip()
+ }
+ if tinfo.Offset == 0 {
+ t.Logf("invalid tensor offset: %+v", tinfo)
+ t.Skip()
+ }
+ if tinfo.Size == 0 {
+ t.Logf("invalid tensor size: %+v", tinfo)
+ t.Skip()
+ }
+ }
+ })
+}
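The tests and seeds above spell out the 24-byte GGUF header by hand. For intuition, a minimal sketch (not part of this commit) of a helper that builds the same prefix, assuming the little-endian layout ReadFile expects — magic, then version (uint32), then numTensors and numMetaValues (uint64 each); ggufHeader is hypothetical:

func ggufHeader(version uint32, numTensors, numMetaValues uint64) []byte {
	b := make([]byte, 0, 24)
	b = append(b, "GGUF"...) // same bytes as magicBytes
	b = binary.LittleEndian.AppendUint32(b, version)
	b = binary.LittleEndian.AppendUint64(b, numTensors)
	b = binary.LittleEndian.AppendUint64(b, numMetaValues)
	return b
}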
supported") + default: + err = fmt.Errorf("unsupported metadata type: %d", typ) + } + if err != nil { + return MetaValue{}, err + } + return MetaValue{ + Type: typ, + Value: bytes.Clone(value), + }, nil +} + +func (r *ggufReader) readMetaValues(typ ValueType) iter.Seq2[MetaValue, error] { + return func(yield func(MetaValue, error) bool) { + if typ == ValueTypeArray { + atyp, err := r.readValueType() + if err != nil { + err = fmt.Errorf("invalid type: %w", err) + yield(MetaValue{}, err) + return + } + n, err := r.readUint64() + if err != nil { + err = fmt.Errorf("invalid length: %w", err) + yield(MetaValue{}, err) + return + } + for i := range n { + v, err := r.readMetaValue(atyp) + if err != nil { + err = fmt.Errorf("invalid entry (type=%s) %d: %w", atyp, i, err) + yield(MetaValue{}, err) + return + } + if !yield(v, nil) { + return + } + } + } else { + v, err := r.readMetaValue(typ) + if err != nil { + err = fmt.Errorf("error reading metadata value: %w", err) + yield(MetaValue{}, err) + return + } + yield(v, nil) + } + } +} + +func (r *ggufReader) readValueType() (ValueType, error) { + typ, err := r.readUint32() + return ValueType(typ), err +} + +func (r *ggufReader) readTensorInfo() (TensorInfo, error) { + name, err := r.readString() + if err != nil { + return TensorInfo{}, err + } + + numDimensions, err := r.readUint32() + if err != nil { + return TensorInfo{}, err + } + if numDimensions > MaxDimensions { + return TensorInfo{}, fmt.Errorf("%w: dimensions length (%d) exceeds %d", ErrMangled, numDimensions, MaxDimensions) + } + + dims := make([]uint64, numDimensions) + for i := range dims { + d, err := r.readUint64() + if err != nil { + return TensorInfo{}, err + } + dims[i] = d + } + typ, err := r.readUint32() + if err != nil { + return TensorInfo{}, err + } + offset, err := r.readUint64() + if err != nil { + return TensorInfo{}, err + } + + // TODO(bmizerany): check offset is multiple of ALIGNMENT + return TensorInfo{ + Name: string(name), + Dimensions: dims, + Type: Type(typ), + Offset: offset, + }, nil +} + +func (r *ggufReader) next(n int) ([]byte, error) { + if n < 0 { + return nil, errors.Join(fmt.Errorf("invalid read length: %d", n), ErrMangled) + } + w := r.r.window() + for len(w) < n { + if r.r.extend() == 0 { + return nil, io.ErrUnexpectedEOF + } + w = r.r.window() + } + r.r.release(n) + r.n += n + return w[:n], nil +} + +func (r *ggufReader) readString() ([]byte, error) { + n, err := r.readUint64() + if err != nil { + return nil, err + } + // TODO(bmizerany): limit max string length + return r.next(int(n)) +} + +func (r *ggufReader) readUint32() (uint32, error) { + b, err := r.next(4) + if err != nil { + return 0, err + } + n := binary.LittleEndian.Uint32(b) + return n, nil +} + +func (r *ggufReader) readUint64() (uint64, error) { + b, err := r.next(8) + if err != nil { + return 0, err + } + n := binary.LittleEndian.Uint64(b) + return n, nil +} diff --git a/encoding/gguf/reader.go b/encoding/gguf/reader.go new file mode 100644 index 00000000..7dadc469 --- /dev/null +++ b/encoding/gguf/reader.go @@ -0,0 +1,70 @@ +package gguf + +import "io" + +// A reader implements a sliding window over an io.Reader. +type reader struct { + data []byte + offset int + r io.Reader + err error +} + +// release discards n bytes from the front of the window. +func (b *reader) release(n int) { + b.offset += n +} + +// window returns the current window. +// The window is invalidated by calls to release or extend. 
+func (b *reader) window() []byte {
+ return b.data[b.offset:]
+}
+
+// tuning constants for reader.extend.
+const (
+ newBufferSize = 8 << 10
+ minReadSize = newBufferSize >> 2
+)
+
+// extend extends the window with data from the underlying reader.
+func (b *reader) extend() int {
+ if b.err != nil {
+ return 0
+ }
+
+ remaining := len(b.data) - b.offset
+ if remaining == 0 {
+ b.data = b.data[:0]
+ b.offset = 0
+ }
+ if cap(b.data)-len(b.data) >= minReadSize {
+ // nothing to do, enough space exists between len and cap.
+ } else if cap(b.data)-remaining >= minReadSize {
+ // buffer has enough space if we move the data to the front.
+ b.compact()
+ } else {
+ // otherwise, we must allocate/extend a new buffer
+ b.grow()
+ }
+ remaining += b.offset
+ n, err := b.r.Read(b.data[remaining:cap(b.data)])
+ // reduce length to the existing plus the data we read.
+ b.data = b.data[:remaining+n]
+ b.err = err
+ return n
+}
+
+// grow grows the buffer, moving the active data to the front.
+func (b *reader) grow() {
+ buf := make([]byte, max(cap(b.data)*2, newBufferSize))
+ copy(buf, b.data[b.offset:])
+ b.data = buf
+ b.offset = 0
+}
+
+// compact moves the active data to the front of the buffer.
+func (b *reader) compact() {
+ copy(b.data, b.data[b.offset:])
+ b.offset = 0
+}
diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491
new file mode 100644
index 00000000..e9bc63cd
--- /dev/null
+++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491
@@ -0,0 +1,2 @@
+go test fuzz v1
+string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x02\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6
new file mode 100644
index 00000000..161b2fbb
--- /dev/null
+++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6
@@ -0,0 +1,2 @@
+go test fuzz v1
+string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc b/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc
new file mode 100644
index 00000000..e33f4f37
--- /dev/null
+++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc
@@ -0,0 +1,2 @@
+go test fuzz v1
+string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xfd\xff\xff\xff\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753
new file mode 100644
index 00000000..42e7b8fd
--- /dev/null
+++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753
@@ -0,0 +1,2 @@
+go test fuzz v1
+string("GGUF\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00K\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00VV\x01\x00\x00\x00\x00\\x00\\x00\\x00\\x00") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2 new file mode 100644 index 00000000..05177332 --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4 new file mode 100644 index 00000000..50588528 --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2 b/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2 new file mode 100644 index 00000000..cbc68bd2 --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x0000000000000000000000000\xe5") diff --git a/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d b/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d new file mode 100644 index 00000000..d6cd186f --- /dev/null +++ b/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d @@ -0,0 +1,2 @@ +go test fuzz v1 +string("GGUF\x03\x00\x00\x0000000000\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x000\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x0000\x01\x00\x00\x00\x00\x00\x00\x000\x01\x00\x001\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\a") diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..54c56120 --- /dev/null +++ b/go.mod @@ -0,0 +1,30 @@ +module bllamo.com + +go 1.22.1 + +require kr.dev/diff v0.3.0 + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/compress v1.17.6 // indirect + github.com/klauspost/cpuid/v2 v2.2.6 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/minio/md5-simd v1.1.2 // indirect + github.com/minio/minio-go/v7 v7.0.69 // indirect + github.com/minio/sha256-simd v1.0.1 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect + github.com/rs/xid v1.5.0 // indirect + golang.org/x/crypto v0.19.0 // indirect + golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e // indirect + golang.org/x/net v0.21.0 // indirect + golang.org/x/sync v0.6.0 // indirect + golang.org/x/sys v0.17.0 // indirect + 
+ golang.org/x/text v0.14.0 // indirect
+ gopkg.in/ini.v1 v1.67.0 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 00000000..40dd784c
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,63 @@
+github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
+github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
+github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
+github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
+github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
+github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
+github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
+github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
+github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
+github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
+github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
+github.com/minio/minio-go/v7 v7.0.69 h1:l8AnsQFyY1xiwa/DaQskY4NXSLA2yrGsW5iD9nRPVS0=
+github.com/minio/minio-go/v7 v7.0.69/go.mod h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ=
+github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
+github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
+github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
+github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
+github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
+github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
+github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
+github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
+github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
+github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
+github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
+golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
+golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e h1:iWVPgObh6F4UDtjBLK51zsy5UHTPLQwCmsNjCsbKhQ0=
+golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
+golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
+golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
+golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
+golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
+golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
+golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
+gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
+gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
+gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
+kr.dev/diff v0.3.0 h1:o/T8/tkAq9IuRIuFqCupyKPC5iSY3WXpVZ2p6ZK3Emw=
+kr.dev/diff v0.3.0/go.mod h1:XiTaLOg2/PD0cmXY7WQXUR8RAF3RwWpqIQEj910J2NY=
diff --git a/model/file.go b/model/file.go
new file mode 100644
index 00000000..2f112784
--- /dev/null
+++ b/model/file.go
@@ -0,0 +1,126 @@
+package model
+
+import (
+ "bufio"
+ "io"
+ "iter"
+ "strings"
+)
+
+type Param struct {
+ Key string
+ Value string
+}
+
+type Message struct {
+ Role string
+ Content string
+}
+
+type File struct {
+ // From is a required pragma that specifies the source of the model,
+ // either on disk, or by reference (see blob.ParseRef).
+ From string
+
+ // Optional
+ Params []Param
+ Template string
+ System string
+ Adapter string
+ Messages []Message
+
+ License string
+}
+
+type Error struct {
+ Pragma string
+ Message string
+}
+
+func (e *Error) Error() string {
+ return e.Pragma + ": " + e.Message
+}
+
+type Pragma struct {
+ // The pragma name
+ Name string
+
+ // Args contains the user-defined arguments for the pragma. If no
+ // arguments were provided, it is nil.
+ Args []string
+}
+
+func (p Pragma) Arg(i int) string {
+ if i >= len(p.Args) {
+ return ""
+ }
+ return p.Args[i]
+}
+
+func Pragmas(r io.Reader) iter.Seq2[Pragma, error] {
+ return func(yield func(Pragma, error) bool) {
+ sc := bufio.NewScanner(r)
+ for sc.Scan() {
+ line := sc.Text()
+
+ // TODO(bmizerany): set a max num fields/args to
+ // prevent mem bloat
+ args := strings.Fields(line)
+ if len(args) == 0 {
+ continue
+ }
+
+ p := Pragma{
+ Name: strings.ToUpper(args[0]),
+ }
+ if p.Name == "MESSAGE" {
+ // handle special case where message content
+ // is space separated on the _rest_ of the
+ // line like: `MESSAGE user Is Ontario in
+ // Canada?`
+ panic("TODO")
+ }
+ if len(args) > 1 {
+ p.Args = args[1:]
+ }
+ if !yield(p, nil) {
+ return
+ }
+ }
+ if sc.Err() != nil {
+ yield(Pragma{}, sc.Err())
+ }
+ }
+}
+
+func Decode(r io.Reader) (File, error) {
+ var f File
+ for p, err := range Pragmas(r) {
+ if err != nil {
+ return File{}, err
+ }
+ switch p.Name {
+ case "FROM":
+ f.From = p.Arg(0)
+ case "PARAMETER":
+ f.Params = append(f.Params, Param{
+ Key: strings.ToLower(p.Arg(0)),
+ Value: p.Arg(1),
+ })
+ case "TEMPLATE":
+ f.Template = p.Arg(0)
+ case "SYSTEM":
+ f.System = p.Arg(0)
+ case "ADAPTER":
+ f.Adapter = p.Arg(0)
+ case "MESSAGE":
+ f.Messages = append(f.Messages, Message{
+ Role: p.Arg(0),
+ Content: p.Arg(1),
+ })
+ case "LICENSE":
+ f.License = p.Arg(0)
+ }
+ }
+ return f, nil
+}
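A sketch of Decode in use, not part of this commit; the model file text is illustrative, and MESSAGE is left out because its handling above is still a panic("TODO"):

f, err := Decode(strings.NewReader("FROM mistral:7b+Q4_0\nPARAMETER temperature 0.8\n"))
if err != nil {
	// handle the parse or scanner error
}
// f.From == "mistral:7b+Q4_0"
// f.Params == []Param{{Key: "temperature", Value: "0.8"}}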
is not a %s", se.Field, se.Value, se.Type), + } + } + return v, err +} + +func DecodeJSON[T any](r io.Reader) (*T, error) { + var v *T + if err := json.NewDecoder(r).Decode(&v); err != nil { + var zero T + return &zero, err + } + return v, nil +} + +func EncodeJSON(w io.Writer, v any) error { + return json.NewEncoder(w).Encode(v) +} + +func Do[Res any](ctx context.Context, method, urlStr string, in any) (*Res, error) { + var body bytes.Buffer + if err := EncodeJSON(&body, in); err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, method, urlStr, &body) + if err != nil { + return nil, err + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode/100 != 2 { + var b bytes.Buffer + body := io.TeeReader(res.Body, &b) + e, err := DecodeJSON[Error](body) + if err != nil { + return nil, err + } + e.RawBody = b.Bytes() + return nil, e + } + + return DecodeJSON[Res](res.Body) +} diff --git a/registry/apitypes.go b/registry/apitypes.go new file mode 100644 index 00000000..dc3e22d9 --- /dev/null +++ b/registry/apitypes.go @@ -0,0 +1,27 @@ +package registry + +type Manifest struct { + Layers []Layer `json:"layers"` +} + +type Layer struct { + Digest string `json:"digest"` + MediaType string `json:"mediaType"` + Size int64 `json:"size"` +} + +type PushRequest struct { + Manifest Manifest `json:"manifest"` +} + +type Requirement struct { + Digest string `json:"digest"` + Size int64 `json:"size"` + URL string `json:"url"` +} + +type PushResponse struct { + // Requirements is a list of digests that the client needs to push before + // repushing the manifest. + Requirements []Requirement `json:"requirements,omitempty"` +} diff --git a/registry/client.go b/registry/client.go new file mode 100644 index 00000000..2336ab19 --- /dev/null +++ b/registry/client.go @@ -0,0 +1,50 @@ +package registry + +import ( + "context" + "encoding/json" + "io" + "net/http" + + "bllamo.com/oweb" +) + +type Client struct { + BaseURL string +} + +// Push pushes a manifest to the server. 
diff --git a/registry/apitypes.go b/registry/apitypes.go
new file mode 100644
index 00000000..dc3e22d9
--- /dev/null
+++ b/registry/apitypes.go
@@ -0,0 +1,27 @@
+package registry
+
+type Manifest struct {
+ Layers []Layer `json:"layers"`
+}
+
+type Layer struct {
+ Digest string `json:"digest"`
+ MediaType string `json:"mediaType"`
+ Size int64 `json:"size"`
+}
+
+type PushRequest struct {
+ Manifest Manifest `json:"manifest"`
+}
+
+type Requirement struct {
+ Digest string `json:"digest"`
+ Size int64 `json:"size"`
+ URL string `json:"url"`
+}
+
+type PushResponse struct {
+ // Requirements is a list of digests that the client needs to push before
+ // repushing the manifest.
+ Requirements []Requirement `json:"requirements,omitempty"`
+}
diff --git a/registry/client.go b/registry/client.go
new file mode 100644
index 00000000..2336ab19
--- /dev/null
+++ b/registry/client.go
@@ -0,0 +1,50 @@
+package registry
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+
+ "bllamo.com/oweb"
+)
+
+type Client struct {
+ BaseURL string
+}
+
+// Push pushes a manifest to the server.
+func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]Requirement, error) {
+ // TODO(bmizerany): backoff
+ v, err := oweb.Do[PushResponse](ctx, "POST", c.BaseURL+"/v1/push/"+ref, struct {
+ Manifest json.RawMessage `json:"manifest"`
+ }{manifest})
+ if err != nil {
+ return nil, err
+ }
+ return v.Requirements, nil
+}
+
+func PushLayer(ctx context.Context, dstURL string, size int64, file io.Reader) error {
+ req, err := http.NewRequestWithContext(ctx, "PUT", dstURL, file)
+ if err != nil {
+ return err
+ }
+ req.ContentLength = size
+
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return err
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ e := &oweb.Error{Status: res.StatusCode}
+ msg, err := io.ReadAll(res.Body)
+ if err != nil {
+ return err
+ }
+ // TODO(bmizerany): format error message
+ e.Message = string(msg)
+ return e
+ }
+ return nil
+}
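Push and PushLayer together imply a two-round protocol, exercised end to end by TestPush below: push the manifest, upload each returned Requirement to its URL, then push again to commit. A sketch under those assumptions; openLayer is a hypothetical helper for finding the local layer blob:

func pushModel(ctx context.Context, c *Client, name string, manifest []byte) error {
	reqs, err := c.Push(ctx, name, manifest)
	if err != nil {
		return err
	}
	for _, rq := range reqs {
		body, err := openLayer(rq.Digest) // hypothetical: open local layer blob
		if err != nil {
			return err
		}
		err = PushLayer(ctx, rq.URL, rq.Size, body)
		body.Close()
		if err != nil {
			return err
		}
	}
	// the second push commits the manifest; leftover requirements mean
	// one or more layers still did not land
	reqs, err = c.Push(ctx, name, manifest)
	if err != nil {
		return err
	}
	if len(reqs) > 0 {
		return fmt.Errorf("%d layer(s) failed to upload", len(reqs))
	}
	return nil
}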
diff --git a/registry/server.go b/registry/server.go
new file mode 100644
index 00000000..2da28c35
--- /dev/null
+++ b/registry/server.go
@@ -0,0 +1,117 @@
+// Package registry implements an Ollama registry client and server.
+package registry
+
+import (
+ "cmp"
+ "context"
+ "errors"
+ "log"
+ "net/http"
+ "strings"
+ "time"
+
+ "bllamo.com/oweb"
+ "github.com/minio/minio-go/v7"
+ "github.com/minio/minio-go/v7/pkg/credentials"
+)
+
+type Server struct{}
+
+func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if err := s.serveHTTP(w, r); err != nil {
+ log.Printf("error: %v", err)
+ var e *oweb.Error
+ if !errors.As(err, &e) {
+ e = oweb.ErrInternal
+ }
+ w.WriteHeader(cmp.Or(e.Status, 400))
+ if err := oweb.EncodeJSON(w, e); err != nil {
+ log.Printf("error encoding error: %v", err)
+ }
+ }
+}
+
+func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
+ switch {
+ case strings.HasPrefix(r.URL.Path, "/v1/push/"):
+ return s.handlePush(w, r)
+ case strings.HasPrefix(r.URL.Path, "/v1/pull/"):
+ return s.handlePull(w, r)
+ }
+ return oweb.ErrNotFound
+}
+
+func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
+ pr, err := oweb.DecodeUserJSON[PushRequest](r.Body)
+ if err != nil {
+ return err
+ }
+
+ mc, err := minio.New("localhost:9000", &minio.Options{
+ Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
+ Secure: false,
+ })
+ if err != nil {
+ return err
+ }
+
+ // TODO(bmizerany): parallelize
+ var requirements []Requirement
+ for _, l := range pr.Manifest.Layers {
+ if l.Size == 0 {
+ continue
+ }
+
+ pushed, err := s.statObject(r.Context(), l.Digest)
+ if err != nil {
+ return err
+ }
+ if !pushed {
+ const expires = 1 * time.Hour
+ signedURL, err := mc.PresignedPutObject(r.Context(), "test", l.Digest, expires)
+ if err != nil {
+ return err
+ }
+ requirements = append(requirements, Requirement{
+ Digest: l.Digest,
+ Size: l.Size,
+
+ // TODO(bmizerany): use signed+temp urls
+ URL: signedURL.String(),
+ })
+ }
+ }
+
+ // TODO(bmizerany): commit to db
+ // ref, _ := strings.CutPrefix(r.URL.Path, "/v1/push/")
+
+ return oweb.EncodeJSON(w, &PushResponse{Requirements: requirements})
+}
+
+func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
+ // lookup manifest
+ panic("TODO")
+}
+
+func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, err error) {
+ // TODO(bmizerany): hold client on *Server (hack for now)
+ mc, err := minio.New("localhost:9000", &minio.Options{
+ Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
+ Secure: false,
+ })
+ if err != nil {
+ return false, err
+ }
+
+ // HEAD the object
+ _, err = mc.StatObject(ctx, "test", digest, minio.StatObjectOptions{})
+ if err != nil {
+ if isNoSuchKey(err) {
+ err = nil
+ }
+ return false, err
+ }
+ return true, nil
+}
+
+func isNoSuchKey(err error) bool {
+ var e minio.ErrorResponse
+ return errors.As(err, &e) && e.Code == "NoSuchKey"
+}
diff --git a/registry/server_test.go b/registry/server_test.go
new file mode 100644
index 00000000..1457a0d8
--- /dev/null
+++ b/registry/server_test.go
@@ -0,0 +1,99 @@
+package registry
+
+import (
+ "context"
+ "net/http/httptest"
+ "os/exec"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/kr/pretty"
+ "github.com/minio/minio-go/v7"
+ "github.com/minio/minio-go/v7/pkg/credentials"
+ "kr.dev/diff"
+)
+
+func TestPush(t *testing.T) {
+ startMinio(t)
+
+ s := &Server{}
+ hs := httptest.NewServer(s)
+ t.Cleanup(hs.Close)
+ c := &Client{BaseURL: hs.URL}
+
+ manifest := []byte(`{
+ "layers": [
+ {"digest": "sha256-1", "size": 1},
+ {"digest": "sha256-2", "size": 2},
+ {"digest": "sha256-3", "size": 3}
+ ]
+ }`)
+
+ got, err := c.Push(context.Background(), "x+y", manifest)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ diff.Test(t, t.Errorf, got, []Requirement{
+ {Digest: "sha256-1", Size: 1},
+ {Digest: "sha256-2", Size: 2},
+ {Digest: "sha256-3", Size: 3},
+ }, diff.ZeroFields[Requirement]("URL"))
+
+ for _, r := range got {
+ body := strings.NewReader(strings.Repeat("x", int(r.Size)))
+ if err := PushLayer(context.Background(), r.URL, r.Size, body); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ got, err = c.Push(context.Background(), "x+y", manifest)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(got) != 0 {
+ t.Fatalf("unexpected requirements: % #v", pretty.Formatter(got))
+ }
+}
+
+func startMinio(t *testing.T) {
+ t.Helper()
+
+ dir := t.TempDir()
+ cmd := exec.Command("minio", "server", "--address", "localhost:9000", dir)
+
+ // TODO(bmizerany): wait delay etc...
+ if err := cmd.Start(); err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ cmd.Process.Kill()
+ if err := cmd.Wait(); err != nil {
+ t.Log(err)
+ }
+ })
+
+ mc, err := minio.New("localhost:9000", &minio.Options{
+ Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
+ Secure: false,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // wait for server to start
+ // TODO(bmizerany): use backoff
+ for {
+ _, err := mc.ListBuckets(context.Background())
+ if err == nil {
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ }
+
+ if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/types/empty/message.go b/types/empty/message.go
new file mode 100644
index 00000000..ab0f1022
--- /dev/null
+++ b/types/empty/message.go
@@ -0,0 +1,4 @@
+package empty
+
+// Message is a placeholder type used when encoding json messages.
+type Message struct{}
diff --git a/types/structs/structs.go b/types/structs/structs.go
new file mode 100644
index 00000000..52929ebf
--- /dev/null
+++ b/types/structs/structs.go
@@ -0,0 +1,15 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package structs contains the Incomparable type.
+package structs
+
+// Incomparable is a zero-width incomparable type. If added as the
+// first field in a struct, it marks that struct as not comparable
+// (can't do == or be a map key) and usually doesn't add any width to
+// the struct (unless the struct has only small fields).
+//
+// By making a struct incomparable, you can prevent misuse (prevent
+// people from using ==), but also you can shrink generated binaries,
+// as the compiler can omit equality funcs from the binary.
+type Incomparable [0]func()
diff --git a/types/they/want.go b/types/they/want.go
new file mode 100644
index 00000000..4a601cc6
--- /dev/null
+++ b/types/they/want.go
@@ -0,0 +1,12 @@
+package they
+
+import (
+ "net/http"
+ "strings"
+)
+
+// Want returns true if the request method is method and the request path
+// starts with pathPrefix.
+func Want(r *http.Request, method string, pathPrefix string) bool {
+ return r.Method == method && strings.HasPrefix(r.URL.Path, pathPrefix)
+}
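A sketch of Want wired into a handler, in the style of the registry server's serveHTTP above; the routes and Server type are illustrative, not part of this commit:

func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
	switch {
	case they.Want(r, "POST", "/v1/push/"):
		return s.handlePush(w, r)
	case they.Want(r, "GET", "/v1/pull/"):
		return s.handlePull(w, r)
	default:
		return oweb.ErrNotFound
	}
}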