Add 'x/' from commit 'a10a11b9d371f36b7c3510da32a1d70b74e27bd1'

git-subtree-dir: x
git-subtree-mainline: 7d05a6ee8f44b314fa697a427439e5fa4d78c3d7
git-subtree-split: a10a11b9d371f36b7c3510da32a1d70b74e27bd1
Blake Mizerany 2024-04-03 10:40:23 -07:00
commit adc23d5f96
42 changed files with 3996 additions and 0 deletions

x/api/api.go Normal file

@@ -0,0 +1,116 @@
package api
import (
"errors"
"fmt"
"net/http"
"os"
"bllamo.com/build"
"bllamo.com/client/ollama/apitype"
"bllamo.com/oweb"
"bllamo.com/registry"
regtype "bllamo.com/registry/apitype"
)
// Common API Errors
var (
errUnqualifiedRef = oweb.Mistake("invalid", "name", "must be fully qualified")
errRefNotFound = oweb.Mistake("not_found", "name", "no such model")
)
type Server struct {
Build *build.Server
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
oweb.Serve(s.serveHTTP, w, r)
}
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
switch r.URL.Path {
case "/v1/push":
return s.handlePush(w, r)
default:
return oweb.ErrNotFound
}
}
func want(r *http.Request, method, path string) bool {
return r.Method == method && r.URL.Path == path
}
func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "POST" {
return oweb.ErrMethodNotAllowed
}
params, err := oweb.DecodeJSON[apitype.PushRequest](r.Body)
if err != nil {
return err
}
if params.Name == "" {
return oweb.Missing("name")
}
const registryURLTODO = "http://localhost:8888"
man, err := s.Build.ManifestData(params.Name)
if err != nil {
if errors.Is(err, build.ErrNotFound) {
return errRefNotFound
}
return err
}
c := registry.Client{BaseURL: registryURLTODO}
requirements, err := c.Push(r.Context(), params.Name, man, nil)
if err != nil {
return err
}
var uploads []regtype.CompletePart
for _, rq := range requirements {
l, err := s.Build.LayerFile(rq.Digest)
if err != nil {
return err
}
err = func() error {
f, err := os.Open(l)
if err != nil {
return err
}
defer f.Close()
etag, err := registry.PushLayer(r.Context(), rq.URL, rq.Offset, rq.Size, f)
if err != nil {
return err
}
uploads = append(uploads, regtype.CompletePart{
URL: rq.URL,
ETag: etag,
})
return nil
}()
if err != nil {
return err
}
}
// commit the manifest to the registry
requirements, err = c.Push(r.Context(), params.Name, man, &registry.PushParams{
Uploaded: uploads,
})
if err != nil {
return err
}
for _, r := range requirements {
err = errors.Join(err, fmt.Errorf("push failed for %q", r.Digest))
}
return err
}
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
return oweb.ErrNotFound
}

x/build/blob/ref.go Normal file

@@ -0,0 +1,295 @@
package blob
import (
"cmp"
"fmt"
"slices"
"strings"
)
// Levels of concreteness
const (
domain = iota
namespace
name
tag
build
)
// Ref is an opaque reference to a blob.
//
// It is comparable and can be used as a map key.
//
// Users of Ref must check Valid before using it.
type Ref struct {
domain string
namespace string
name string
tag string
build string
}
// WithDomain returns a copy of r with the provided domain. If the provided
// domain is empty, it returns the short, unqualified copy of r.
func (r Ref) WithDomain(s string) Ref {
return with(r, domain, s)
}
// WithNamespace returns a copy of r with the provided namespace. If the
// provided namespace is empty, it returns the short, unqualified copy of r.
func (r Ref) WithNamespace(s string) Ref {
return with(r, namespace, s)
}
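// WithTag returns a copy of r with the provided tag. If the provided
// tag is empty, it returns the short, unqualified copy of r.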
func (r Ref) WithTag(s string) Ref {
return with(r, tag, s)
}
// WithBuild returns a copy of r with the provided build. If the provided
// build is empty, it returns the short, unqualified copy of r.
func (r Ref) WithBuild(s string) Ref {
return with(r, build, s)
}
func with(r Ref, part int, value string) Ref {
if value != "" && !isValidPart(value) {
return Ref{}
}
switch part {
case domain:
r.domain = value
case namespace:
r.namespace = value
case name:
r.name = value
case tag:
r.tag = value
case build:
r.build = value
default:
panic(fmt.Sprintf("invalid completeness: %d", part))
}
return r
}
// Full returns a string representation of the ref with all parts
// present. If a part is missing, it is replaced with a loud
// placeholder.
func (r Ref) Full() string {
r.domain = cmp.Or(r.domain, "!(MISSING DOMAIN)")
r.namespace = cmp.Or(r.namespace, "!(MISSING NAMESPACE)")
r.name = cmp.Or(r.name, "!(MISSING NAME)")
r.tag = cmp.Or(r.tag, "!(MISSING TAG)")
r.build = cmp.Or(r.build, "!(MISSING BUILD)")
return r.String()
}
func (r Ref) NameAndTag() string {
r.domain = ""
r.namespace = ""
r.build = ""
return r.String()
}
func (r Ref) NameTagAndBuild() string {
r.domain = ""
r.namespace = ""
return r.String()
}
// String returns the fully qualified ref string.
func (r Ref) String() string {
var b strings.Builder
if r.domain != "" {
b.WriteString(r.domain)
b.WriteString("/")
}
if r.namespace != "" {
b.WriteString(r.namespace)
b.WriteString("/")
}
b.WriteString(r.name)
if r.tag != "" {
b.WriteString(":")
b.WriteString(r.tag)
}
if r.build != "" {
b.WriteString("+")
b.WriteString(r.build)
}
return b.String()
}
// Complete returns true if the ref is valid and has no empty parts.
func (r Ref) Complete() bool {
return r.Valid() && !slices.Contains(r.Parts(), "")
}
// CompleteWithoutBuild returns true if the ref is valid and all parts
// except the build are present.
func (r Ref) CompleteWithoutBuild() bool {
return r.Valid() && !slices.Contains(r.Parts()[:build], "")
}
// Less returns true if r sorts before o, comparing parts in order of
// concreteness.
func (r Ref) Less(o Ref) bool {
rp := r.Parts()
op := o.Parts()
for i := range rp {
if rp[i] != op[i] {
return rp[i] < op[i]
}
}
return false
}
// Parts returns the parts of the ref in order of concreteness.
//
// The length of the returned slice is always 5.
func (r Ref) Parts() []string {
return []string{
domain: r.domain,
namespace: r.namespace,
name: r.name,
tag: r.tag,
build: r.build,
}
}
func (r Ref) Domain() string { return r.domain }
func (r Ref) Namespace() string { return r.namespace }
func (r Ref) Name() string { return r.name }
func (r Ref) Tag() string { return r.tag }
func (r Ref) Build() string { return r.build }
// ParseRef parses a ref string into a Ref. A ref string is a name, an
// optional tag, and an optional build, separated by colons and pluses.
//
// The name must be valid ascii [a-zA-Z0-9_.-].
// The tag must be valid ascii [a-zA-Z0-9_.-].
// The build must be valid ascii [a-zA-Z0-9_.-].
//
// It returns the zero value if the ref is invalid.
//
// // Valid Examples:
// ParseRef("mistral:latest") returns ("mistral", "latest", "")
// ParseRef("mistral") returns ("mistral", "", "")
// ParseRef("mistral:30B") returns ("mistral", "30B", "")
// ParseRef("mistral:7b") returns ("mistral", "7b", "")
// ParseRef("mistral:7b+Q4_0") returns ("mistral", "7b", "Q4_0")
// ParseRef("mistral+KQED") returns ("mistral", "latest", "KQED")
// ParseRef(".x.:7b+Q4_0:latest") returns (".x.", "7b", "Q4_0")
// ParseRef("-grok-f.oo:7b+Q4_0") returns ("-grok-f.oo", "7b", "Q4_0")
//
// // Invalid Examples:
// ParseRef("m stral") returns ("", "", "") // zero
// ParseRef("... 129 chars ...") returns ("", "", "") // zero
func ParseRef(s string) Ref {
if len(s) > 128 {
return Ref{}
}
if strings.HasPrefix(s, "http://") {
s = s[len("http://"):]
}
if strings.HasPrefix(s, "https://") {
s = s[len("https://"):]
}
var r Ref
state, j := build, len(s)
for i := len(s) - 1; i >= 0; i-- {
c := s[i]
switch c {
case '+':
switch state {
case build:
r.build = s[i+1 : j]
if r.build == "" {
return Ref{}
}
r.build = strings.ToUpper(r.build)
state, j = tag, i
default:
return Ref{}
}
case ':':
switch state {
case build, tag:
r.tag = s[i+1 : j]
if r.tag == "" {
return Ref{}
}
state, j = name, i
default:
return Ref{}
}
case '/':
switch state {
case name, tag, build:
r.name = s[i+1 : j]
state, j = namespace, i
case namespace:
r.namespace = s[i+1 : j]
state, j = domain, i
default:
return Ref{}
}
}
}
// handle the first part based on final state
switch state {
case domain:
r.domain = s[:j]
case namespace:
r.namespace = s[:j]
default:
r.name = s[:j]
}
if !r.Valid() {
return Ref{}
}
return r
}
func (r Ref) Valid() bool {
// Name is required
if !isValidPart(r.name) {
return false
}
// Optional parts must be valid if present
if r.domain != "" && !isValidPart(r.domain) {
return false
}
if r.namespace != "" && !isValidPart(r.namespace) {
return false
}
if r.tag != "" && !isValidPart(r.tag) {
return false
}
if r.build != "" && !isValidPart(r.build) {
return false
}
return true
}
// isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-]
func isValidPart(s string) bool {
if len(s) == 0 {
return false
}
for _, c := range []byte(s) {
if c == '.' || c == '-' {
continue
}
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
continue
}
return false
}
return true
}

x/build/blob/ref_test.go Normal file

@@ -0,0 +1,80 @@
package blob
import "testing"
// test refs
const (
refTooLong = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
)
var testRefs = map[string]Ref{
"mistral:latest": {name: "mistral", tag: "latest"},
"mistral": {name: "mistral"},
"mistral:30B": {name: "mistral", tag: "30B"},
"mistral:7b": {name: "mistral", tag: "7b"},
"mistral:7b+Q4_0": {name: "mistral", tag: "7b", build: "Q4_0"},
"mistral+KQED": {name: "mistral", build: "KQED"},
"mistral.x-3:7b+Q4_0": {name: "mistral.x-3", tag: "7b", build: "Q4_0"},
"mistral:7b+q4_0": {name: "mistral", tag: "7b", build: "Q4_0"},
"llama2": {name: "llama2"},
// invalid
"mistral:7b+Q4_0:latest": {},
"mi tral": {},
}
func TestRefParts(t *testing.T) {
const wantNumParts = 5
var ref Ref
if len(ref.Parts()) != wantNumParts {
t.Errorf("Parts() = %d; want %d", len(ref.Parts()), wantNumParts)
}
}
func TestParseRef(t *testing.T) {
for s, want := range testRefs {
t.Run(s, func(t *testing.T) {
got := ParseRef(s)
if got != want {
t.Errorf("ParseRef(%q) = %q; want %q", s, got, want)
}
// test round-trip
if ParseRef(got.String()) != got {
t.Errorf("String() = %q; want %q", got.String(), s)
}
})
}
}
func TestRefFull(t *testing.T) {
const empty = "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/!(MISSING NAME):!(MISSING TAG)+!(MISSING BUILD)"
cases := []struct {
in string
wantFull string
}{
{"", empty},
{"example.com/mistral:7b+x", "!(MISSING DOMAIN)/example.com/mistral:7b+X"},
{"example.com/mistral:7b+Q4_0", "!(MISSING DOMAIN)/example.com/mistral:7b+Q4_0"},
{"example.com/x/mistral:latest", "example.com/x/mistral:latest+!(MISSING BUILD)"},
{"example.com/x/mistral:latest+Q4_0", "example.com/x/mistral:latest+Q4_0"},
{"mistral:7b+x", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+X"},
{"mistral:7b+q4_0", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+Q4_0"},
{"mistral:7b+Q4_0", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+Q4_0"},
{"mistral:latest", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:latest+!(MISSING BUILD)"},
{"mistral", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:!(MISSING TAG)+!(MISSING BUILD)"},
{"mistral:30b", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:30b+!(MISSING BUILD)"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
ref := ParseRef(tt.in)
t.Logf("ParseRef(%q) = %#v", tt.in, ref)
if g := ref.Full(); g != tt.wantFull {
t.Errorf("Full(%q) = %q; want %q", tt.in, g, tt.wantFull)
}
})
}
}

x/build/build.go Normal file

@@ -0,0 +1,210 @@
package build
import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"bllamo.com/build/blob"
"bllamo.com/build/internal/blobstore"
"bllamo.com/model"
)
// Errors
var (
ErrIncompleteRef = errors.New("unqualified ref")
ErrBuildPresentInRef = errors.New("build present in ref")
ErrUnsupportedModelFormat = errors.New("unsupported model format")
ErrMissingFileType = errors.New("missing 'general.file_type' key")
ErrNotFound = errors.New("not found")
)
type mediaType string
// Known media types
const (
mediaTypeModel mediaType = "application/vnd.ollama.image.model"
)
type Server struct {
st *blobstore.Store
}
// Open starts a new build server that uses dir as the base directory for all
// build artifacts. If dir is empty, DefaultDir is used.
//
// It returns an error if the provided or default dir cannot be initialized.
func Open(dir string) (*Server, error) {
if dir == "" {
var err error
dir, err = DefaultDir()
if err != nil {
return nil, err
}
}
st, err := blobstore.Open(dir)
if err != nil {
return nil, err
}
return &Server{st: st}, nil
}
func (s *Server) Build(ref string, f model.File) error {
br := blob.ParseRef(ref)
if !br.CompleteWithoutBuild() {
return fmt.Errorf("%w: %q", ErrIncompleteRef, ref)
}
// 1. Resolve FROM
// a. If it's a local file (gguf), hash it and add it to the store.
// b. If it's a remote file (http), refuse.
// 2. Turn other pragmas into layers, and add them to the store.
// 3. Create a manifest from the layers.
// 4. Store the manifest in the manifest cache
// 5. Done.
if f.From == "" {
return &model.Error{Pragma: "FROM", Message: "missing"}
}
var layers []layerJSON
id, info, size, err := s.importModel(f.From)
if err != nil {
return err
}
layers = append(layers, layerJSON{
ID: id,
MediaType: mediaTypeModel,
Size: size,
})
id, size, err = blobstore.PutString(s.st, f.License)
if err != nil {
return err
}
layers = append(layers, layerJSON{
ID: id,
MediaType: "text/plain",
Size: size,
})
data, err := json.Marshal(manifestJSON{Layers: layers})
if err != nil {
return err
}
return s.setManifestData(
br.WithBuild(info.FileType.String()),
data,
)
}
func (s *Server) LayerFile(digest string) (string, error) {
fileName := s.st.OutputFilename(blobstore.ParseID(digest))
_, err := os.Stat(fileName)
if errors.Is(err, fs.ErrNotExist) {
return "", fmt.Errorf("%w: %q", ErrNotFound, digest)
}
if err != nil {
return "", err
}
return fileName, nil
}
func (s *Server) ManifestData(ref string) ([]byte, error) {
data, _, err := s.resolve(blob.ParseRef(ref))
return data, err
}
// WeightsFile returns the absolute path to the weights file for the given model ref.
func (s *Server) WeightsFile(ref string) (string, error) {
m, err := s.getManifest(blob.ParseRef(ref))
if err != nil {
return "", err
}
for _, l := range m.Layers {
if l.MediaType == mediaTypeModel {
return s.st.OutputFilename(l.ID), nil
}
}
return "", fmt.Errorf("missing weights layer for %q", ref)
}
// resolve returns the data for the given ref, if any.
//
// TODO: This should ideally return an ID, but the current on
// disk layout is that the actual manifest is stored in the "ref" instead of
// a pointer to a content-addressed blob. I (bmizerany) think we should
// change the on-disk layout to store the manifest in a content-addressed
// blob, and then have the ref point to that blob. This would simplify the
// code, allow us to have integrity checks on the manifest, and clean up
// this interface.
func (s *Server) resolve(ref blob.Ref) (data []byte, path string, err error) {
path, err = s.refFileName(ref)
if err != nil {
return nil, "", err
}
data, err = os.ReadFile(path)
if errors.Is(err, fs.ErrNotExist) {
return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref)
}
if err != nil {
// do not wrap the error here, as it is likely an I/O error
// and we want to preserve the abstraction since we may not
// be on disk later.
return nil, "", fmt.Errorf("manifest read error: %v", err)
}
return data, path, nil
}
func (s *Server) SetManifestData(ref string, data []byte) error {
return s.setManifestData(blob.ParseRef(ref), data)
}
// setManifestData sets the manifest data for the given ref.
func (s *Server) setManifestData(br blob.Ref, data []byte) error {
path, err := s.refFileName(br)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
return err
}
if err := os.WriteFile(path, data, 0666); err != nil {
return err
}
return nil
}
func (s *Server) refFileName(ref blob.Ref) (string, error) {
if !ref.Complete() {
return "", fmt.Errorf("ref not fully qualified: %q", ref)
}
return filepath.Join(s.st.Dir(), "manifests", filepath.Join(ref.Parts()...)), nil
}
type manifestJSON struct {
// Layers is the list of layers in the manifest.
Layers []layerJSON `json:"layers"`
}
// Layer is a layer in a model manifest.
type layerJSON struct {
// ID is the ID of the layer.
ID blobstore.ID `json:"digest"`
MediaType mediaType `json:"mediaType"`
Size int64 `json:"size"`
}
func (s *Server) getManifest(ref blob.Ref) (manifestJSON, error) {
data, path, err := s.resolve(ref)
if err != nil {
return manifestJSON{}, err
}
var m manifestJSON
if err := json.Unmarshal(data, &m); err != nil {
return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err}
}
return m, nil
}

x/build/build_test.go Normal file

@@ -0,0 +1,163 @@
package build
import (
"errors"
"os"
"path/filepath"
"testing"
"bllamo.com/encoding/gguf"
"bllamo.com/model"
)
const qualifiedRef = "x/y/z:latest+Q4_0"
func TestServerBuildErrors(t *testing.T) {
dir := t.TempDir()
s, err := Open(dir)
if err != nil {
t.Fatal(err)
}
t.Run("unqualified ref", func(t *testing.T) {
err := s.Build("x", model.File{})
if !errors.Is(err, ErrIncompleteRef) {
t.Fatalf("Build() err = %v; want unqualified ref", err)
}
})
t.Run("FROM pragma missing", func(t *testing.T) {
err := s.Build(qualifiedRef, model.File{})
var e *model.Error
if !errors.As(err, &e) {
t.Fatalf("unexpected error: %v", err)
}
if e.Pragma != "FROM" {
t.Errorf("e.Pragma = %s; want FROM", e.Pragma)
}
if e.Message != "missing" {
t.Errorf("e.Message = %s; want missing", e.Message)
}
})
t.Run("FROM file not found", func(t *testing.T) {
err := s.Build(qualifiedRef, model.File{From: "bar"})
if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("Build() err = %v; want file not found", err)
}
})
t.Run("FROM gguf", func(t *testing.T) {
w := newWorkDir(t)
// Write a gguf file without general.file_type metadata.
w.write("gguf", ""+
"GGUF"+ // magic
"\x03\x00\x00\x00"+ // version
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
"",
)
err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")})
if !errors.Is(err, ErrMissingFileType) {
t.Fatalf("Build() err = %#v; want missing file type", err)
}
})
t.Run("FROM obscure dir", func(t *testing.T) {
w := newWorkDir(t)
w.mkdirAll("unknown")
if err := s.Build(qualifiedRef, model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat {
t.Fatalf("Build() err = %#v; want unsupported model type", err)
}
})
t.Run("FROM unsupported model type", func(t *testing.T) {
w := newWorkDir(t)
from := w.write("unknown", "unknown content")
err := s.Build(qualifiedRef, model.File{From: from})
if !errors.Is(err, ErrUnsupportedModelFormat) {
t.Fatalf("Build() err = %#v; want unsupported model type", err)
}
})
}
func TestBuildBasicGGUF(t *testing.T) {
w := newWorkDir(t)
w.write("gguf", ""+
"GGUF"+ // magic
"\x03\x00\x00\x00"+ // version
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
// general.file_type key
"\x11\x00\x00\x00\x00\x00\x00\x00"+ // key length
"general.file_type"+ // key
"\x04\x00\x00\x00"+ // type (uint32)
"\x02\x00\x00\x00\x00\x00\x00\x00"+ // uint32 value
"",
)
dir := t.TempDir()
s, err := Open(dir)
if err != nil {
t.Fatal(err)
}
if err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")}); err != nil {
t.Fatal(err)
}
filepath.Walk(dir, func(p string, info os.FileInfo, err error) error {
t.Logf("file: %s", p)
return nil
})
_, err = s.WeightsFile("unknown/y/z:latest+Q4_0")
if !errors.Is(err, ErrNotFound) {
t.Fatalf("WeightsFile() err = %v; want not found", err)
}
path, err := s.WeightsFile("x/y/z:latest+Q4_0")
if err != nil {
t.Fatal(err)
}
info, err := gguf.Stat(path)
if err != nil {
t.Fatal(err)
}
if info.FileType != gguf.TypeQ4_0 {
t.Errorf("info.FileType = %d; want 1", info.FileType)
}
}
type work struct {
t testing.TB
dir string
}
func newWorkDir(t *testing.T) work {
return work{t: t, dir: t.TempDir()}
}
func (w work) write(name, content string) (path string) {
w.t.Helper()
path = w.fileName(name)
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
w.t.Fatal(err)
}
return path
}
func (w work) fileName(name string) string {
w.t.Helper()
return filepath.Join(w.dir, name)
}
func (w work) mkdirAll(path string) {
w.t.Helper()
if err := os.MkdirAll(filepath.Join(w.dir, path), 0755); err != nil {
w.t.Fatal(err)
}
}

x/build/convert.go Normal file

@@ -0,0 +1,12 @@
package build
func convertSafeTensorToGGUF(path string) (ggufPath string, err error) {
// TODO: decide on a heuristic for converting safetensors to gguf and
// the errors that can be returned. For now, we just say
// "unsupported"; however, the input may be intended to be a valid
// safetensors file where we hit an error in the conversion.
//
// I (bmizerany) think this will naturally evolve as we implement
// the conversion.
return "", ErrUnsupportedModelFormat
}

x/build/default.go Normal file

@@ -0,0 +1,28 @@
package build
import (
"os"
"path/filepath"
"sync"
)
var (
defaultDir = sync.OnceValues(func() (string, error) {
dir := os.Getenv("OLLAMA_MODELS")
if dir == "" {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
dir = filepath.Join(home, ".ollama", "models")
}
return dir, nil
})
)
// DefaultDir returns the default directory for models. It returns the value
// of the OLLAMA_MODELS environment variable if set; otherwise it returns
// "$HOME/.ollama/models".
func DefaultDir() (string, error) {
return defaultDir()
}

x/build/import.go Normal file

@@ -0,0 +1,59 @@
package build
import (
"errors"
"fmt"
"os"
"bllamo.com/build/internal/blobstore"
"bllamo.com/encoding/gguf"
)
func importError(err error) (blobstore.ID, gguf.Info, int64, error) {
return blobstore.ID{}, gguf.Info{}, 0, err
}
func (s *Server) importModel(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
info, err := os.Stat(path)
if err != nil {
return importError(err)
}
if info.IsDir() {
return s.importSafeTensor(path)
} else {
return s.importGGUF(path)
}
}
func (s *Server) importGGUF(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
f, err := os.Open(path)
if err != nil {
return importError(err)
}
defer f.Close()
info, err := gguf.StatReader(f)
if errors.Is(err, gguf.ErrBadMagic) {
return importError(ErrUnsupportedModelFormat)
}
if err != nil {
return importError(err)
}
if info.FileType == 0 {
return importError(fmt.Errorf("%w: %q", ErrMissingFileType, path))
}
id, size, err := s.st.Put(f)
if err != nil {
return importError(err)
}
return id, info, size, nil
}
func (s *Server) importSafeTensor(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
path, err := convertSafeTensorToGGUF(path)
if err != nil {
return importError(err)
}
return s.importGGUF(path)
}


@@ -0,0 +1,329 @@
// Package blobstore implements a blob store.
package blobstore
import (
"bytes"
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
"bllamo.com/types/structs"
)
var (
ErrInvalidID = errors.New("invalid ID")
)
const HashSize = 32
// An ID is a blob output key, the hash of an output of a computation.
type ID struct {
a [HashSize]byte
}
func (id ID) MarshalText() ([]byte, error) {
return []byte(id.String()), nil
}
func (id *ID) UnmarshalText(text []byte) error {
*id = ParseID(string(text))
if !id.Valid() {
return ErrInvalidID
}
return nil
}
func ParseID(s string) ID {
const prefix = "sha256-"
h, ok := strings.CutPrefix(s, prefix)
if !ok {
return ID{}
}
if len(h) != HashSize*2 {
return ID{}
}
var b []byte
_, err := fmt.Sscanf(h, "%x", &b)
if err != nil {
return ID{}
}
var id ID
copy(id.a[:], b)
return id
}
func (id ID) String() string {
if !id.Valid() {
return ""
}
return fmt.Sprintf("sha256-%x", id.a[:])
}
func (id ID) Valid() bool {
return id != ID{}
}
func (id ID) Match(h [HashSize]byte) bool {
return id.a == h
}
// A Store is a blob store, backed by a file system directory tree.
type Store struct {
dir string
now func() time.Time
}
// Open opens and returns the store in the given directory.
//
// It is safe for multiple processes on a single machine to use the
// same store directory in a local file system simultaneously.
// They will coordinate using operating system file locks and may
// duplicate effort but will not corrupt the store.
//
// However, it is NOT safe for multiple processes on different machines
// to share a store directory (for example, if the directory were stored
// in a network file system). File locking is notoriously unreliable in
// network file systems and may not suffice to protect the store.
func Open(dir string) (*Store, error) {
info, err := os.Stat(dir)
if err != nil {
return nil, err
}
if !info.IsDir() {
return nil, &fs.PathError{Op: "open", Path: dir, Err: fmt.Errorf("not a directory")}
}
if err := os.MkdirAll(filepath.Join(dir, "blobs"), 0777); err != nil {
return nil, err
}
c := &Store{
dir: dir,
now: time.Now,
}
return c, nil
}
func (s *Store) Dir() string {
return s.dir
}
// fileName returns the name of the blob file corresponding to the given id.
func (s *Store) fileName(id ID) string {
return filepath.Join(s.dir, "blobs", fmt.Sprintf("sha256-%x", id.a[:]))
}
// An entryNotFoundError indicates that a store entry was not found, with an
// optional underlying reason.
type entryNotFoundError struct {
Err error
}
func (e *entryNotFoundError) Error() string {
if e.Err == nil {
return "store entry not found"
}
return fmt.Sprintf("store entry not found: %v", e.Err)
}
func (e *entryNotFoundError) Unwrap() error {
return e.Err
}
type Entry struct {
_ structs.Incomparable
ID ID
Size int64
Time time.Time // when added to store
}
// GetFile looks up the blob ID in the store and returns
// the name of the corresponding data file.
func GetFile(s *Store, id ID) (file string, entry Entry, err error) {
entry, err = s.Get(id)
if err != nil {
return "", Entry{}, err
}
file = s.OutputFilename(entry.ID)
info, err := os.Stat(file)
if err != nil {
return "", Entry{}, &entryNotFoundError{Err: err}
}
if info.Size() != entry.Size {
return "", Entry{}, &entryNotFoundError{Err: errors.New("file incomplete")}
}
return file, entry, nil
}
// GetBytes looks up the blob ID in the store and returns
// the corresponding output bytes.
// GetBytes should only be used for data that can be expected to fit in memory.
func GetBytes(s *Store, id ID) ([]byte, Entry, error) {
entry, err := s.Get(id)
if err != nil {
return nil, entry, err
}
data, _ := os.ReadFile(s.OutputFilename(entry.ID))
if !entry.ID.Match(sha256.Sum256(data)) {
return nil, entry, &entryNotFoundError{Err: errors.New("bad checksum")}
}
return data, entry, nil
}
// OutputFilename returns the name of the blob file for the given ID.
func (s *Store) OutputFilename(id ID) string {
file := s.fileName(id)
// TODO(bmizerany): touch as "used" for cache trimming (see cache.go in
// cmd/go/internal/cache for the full reference implementation to go off of).
return file
}
// Get looks up the blob ID in the store,
// returning the corresponding output ID and file size, if any.
// Note that finding an output ID does not guarantee that the
// saved file for that output ID is still available.
func (s *Store) Get(id ID) (Entry, error) {
file := s.fileName(id)
info, err := os.Stat(file)
if err != nil {
return Entry{}, &entryNotFoundError{Err: err}
}
return Entry{
ID: id,
Size: info.Size(),
Time: info.ModTime(),
}, nil
}
func (s *Store) Close() error {
// TODO(bmizerany): return c.Trim()
return nil
}
// Put stores the data read from the given file into the store as ID.
//
// It may read the file twice. The content of file must not change between
// the two passes.
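//
// A minimal usage sketch (illustrative; the file name is hypothetical):
//
//	f, err := os.Open("layer.bin")
//	if err != nil {
//		return err
//	}
//	defer f.Close()
//	id, size, err := s.Put(f)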
func (s *Store) Put(file io.ReadSeeker) (ID, int64, error) {
return s.put(file)
}
func PutBytes(s *Store, data []byte) (ID, int64, error) {
return s.Put(bytes.NewReader(data))
}
func PutString(s *Store, data string) (ID, int64, error) {
return s.Put(strings.NewReader(data))
}
func (s *Store) put(file io.ReadSeeker) (ID, int64, error) {
// Compute output ID.
h := sha256.New()
if _, err := file.Seek(0, 0); err != nil {
return ID{}, 0, err
}
size, err := io.Copy(h, file)
if err != nil {
return ID{}, 0, err
}
var out ID
h.Sum(out.a[:0])
// Copy to blob file (if not already present).
if err := s.copyFile(file, out, size); err != nil {
return out, size, err
}
// TODO: Add to manifest index.
return out, size, nil
}
// copyFile copies file into the store, expecting it to have the given
// output ID and size, if that file is not present already.
func (s *Store) copyFile(file io.ReadSeeker, out ID, size int64) error {
name := s.fileName(out)
println("name", name)
info, err := os.Stat(name)
if err == nil && info.Size() == size {
// Check hash.
if f, err := os.Open(name); err == nil {
h := sha256.New()
io.Copy(h, f)
f.Close()
var out2 ID
h.Sum(out2.a[:0])
if out == out2 {
return nil
}
}
// Hash did not match. Fall through and rewrite file.
}
// Copy file to blobs directory.
mode := os.O_RDWR | os.O_CREATE
if err == nil && info.Size() > size { // shouldn't happen but fix in case
mode |= os.O_TRUNC
}
f, err := os.OpenFile(name, mode, 0666)
if err != nil {
return err
}
defer f.Close()
if size == 0 {
// File now exists with correct size.
// Only one possible zero-length file, so contents are OK too.
// Early return here makes sure there's a "last byte" for code below.
return nil
}
// From here on, if any of the I/O writing the file fails,
// we make a best-effort attempt to truncate the file f
// before returning, to avoid leaving bad bytes in the file.
// Copy file to f, but also into h to double-check hash.
if _, err := file.Seek(0, 0); err != nil {
f.Truncate(0)
return err
}
h := sha256.New()
w := io.MultiWriter(f, h)
if _, err := io.CopyN(w, file, size-1); err != nil {
f.Truncate(0)
return err
}
// Check last byte before writing it; writing it will make the size match
// what other processes expect to find and might cause them to start
// using the file.
buf := make([]byte, 1)
if _, err := file.Read(buf); err != nil {
f.Truncate(0)
return err
}
h.Write(buf)
sum := h.Sum(nil)
if !bytes.Equal(sum, out.a[:]) {
f.Truncate(0)
return fmt.Errorf("file content changed underfoot")
}
// Commit manifest entry.
if _, err := f.Write(buf); err != nil {
f.Truncate(0)
return err
}
if err := f.Close(); err != nil {
// Data might not have been written,
// but file may look like it is the right size.
// To be extra careful, remove stored file.
os.Remove(name)
return err
}
os.Chtimes(name, s.now(), s.now()) // mainly for tests
return nil
}


@@ -0,0 +1,54 @@
package blobstore
import (
"strings"
"testing"
)
func TestParseID(t *testing.T) {
const valid = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
var invalid = strings.Repeat("\x00", HashSize*2)
cases := []struct {
in string
want string
}{
{"", invalid},
{"sha256-", invalid},
{"sha256-" + valid, valid},
{"" + valid, invalid}, // no prefix
{"sha123-" + valid, invalid}, // invalid prefix
{"sha256-" + valid[1:], invalid}, // too short
{"sha256-" + valid + "a", invalid}, // too long
{"sha256-!" + valid[1:], invalid}, // invalid hex
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
// sanity check
if len(tt.want) > HashSize*2 {
panic("invalid test")
}
got := ParseID(tt.in)
wantValid := tt.want != invalid
if wantValid {
if !got.Valid() {
t.Errorf("ParseID(%q).Valid() = false; want true", tt.in)
}
if got.String() != "sha256-"+tt.want {
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "sha256-"+tt.want)
}
} else {
if got.Valid() {
t.Errorf("ParseID(%q).Valid() = true; want false", tt.in)
}
if got.String() != "" {
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "")
}
}
})
}
}


@@ -0,0 +1,128 @@
package blobstore
import (
"errors"
"iter"
"os"
"path/filepath"
"testing"
"time"
"bllamo.com/build/blob"
"kr.dev/diff"
)
const (
blobNameHello = "sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
)
func TestStoreBasicBlob(t *testing.T) {
dir := t.TempDir()
checkDir(t, dir, nil)
st, err := Open(dir)
if err != nil {
t.Fatal(err)
}
now := time.Now()
st.now = func() time.Time { return now }
checkDir(t, dir, []string{
"blobs/",
})
id, size, err := PutBytes(st, []byte("hello"))
if err != nil {
t.Fatal(err)
}
if id != ParseID(blobNameHello) {
t.Errorf("unexpected ID: %s", id)
}
if size != 5 {
t.Errorf("unexpected size: %d", size)
}
checkDir(t, dir, []string{
"blobs/",
"blobs/" + blobNameHello,
})
got, err := st.Get(id)
if err != nil {
t.Fatal(err)
}
diff.Test(t, t.Errorf, got, Entry{
ID: id,
Size: 5,
Time: now,
})
file := st.OutputFilename(id)
wantFile := filepath.Join(dir, "blobs", blobNameHello)
if file != wantFile {
t.Errorf("unexpected file: %s", file)
}
// Check tags
ref := blob.ParseRef("registry.ollama.ai/library/test:latest+KQED")
t.Logf("RESOLVING: %q", ref.Parts())
}
// checkDir checks that the directory at dir contains the files in want. The
// files in want must be relative to dir.
//
// Directories are suffixed with a slash (e.g. "foo/" instead of "foo").
//
// want must be in lexicographic order.
func checkDir(t testing.TB, dir string, want []string) {
t.Helper()
var matches []string
for path, err := range walkDir(dir) {
t.Helper()
if err != nil {
t.Fatal(err)
}
t.Logf("found %s", path)
if path == "./" {
continue
}
path = filepath.ToSlash(path)
matches = append(matches, path)
}
diff.Test(t, t.Errorf, matches, want)
}
var errStop = errors.New("stop")
func walkDir(dir string) iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
err := filepath.WalkDir(dir, func(path string, info os.DirEntry, err error) error {
if err != nil {
return err
}
path, err = filepath.Rel(dir, path)
if err != nil {
return err
}
path = filepath.ToSlash(path)
if info.IsDir() {
path += "/"
}
if !yield(path, nil) {
return errStop
}
return nil
})
if !errors.Is(err, errStop) && err != nil {
yield("", err)
}
}
}


@@ -0,0 +1,31 @@
package apitype
import "time"
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type Model struct {
Ref string `json:"ref"`
Digest string `json:"digest"`
Size int64 `json:"size"`
ModifiedAt int64 `json:"modified"`
}
func (m Model) Modified() time.Time {
return time.Unix(0, m.ModifiedAt)
}
type PushRequest struct {
Name string `json:"name"` // Ref is the official term, "name" is for backward compatibility with existing clients.
Insecure bool `json:"insecure"`
Stream bool `json:"stream"`
}
type PushStatus struct {
Status string `json:"status"`
Digest string `json:"digest"`
Total int64 `json:"total"`
}

x/client/ollama/ollama.go Normal file

@@ -0,0 +1,162 @@
package ollama
import (
"bytes"
"cmp"
"context"
"encoding/json"
"fmt"
"io"
"io/fs"
"iter"
"net/http"
"os"
"strings"
"bllamo.com/client/ollama/apitype"
"bllamo.com/types/empty"
)
// TODO(bmizerany): PROGRESS INDICATORS!!!!
const DefaultBaseURL = "http://localhost:11434"
var envBaseURL = cmp.Or(os.Getenv("OLLAMA_BASE_URL"), DefaultBaseURL)
// Default returns a new client with the default base URL.
func Default() *Client {
return &Client{BaseURL: envBaseURL}
}
// I_Acknowledge_This_API_Is_Under_Development is a flag that must be set to
// true for any instance of Client to work.
var I_Acknowledge_This_API_Is_Under_Development bool
// Client is a client for the Ollama API.
type Client struct {
// BaseURL is the base URL of the Ollama API.
BaseURL string
HTTPClient *http.Client // The HTTP client to use. If nil, http.DefaultClient is used.
}
// Build requests the remote Ollama service to build a model. It uploads any
// source files the server needs.
func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source fs.FS) error {
panic("not implemented")
}
// Push requests the remote Ollama service to push a model to the server.
func (c *Client) Push(ctx context.Context, ref string) error {
_, err := Do[empty.Message](ctx, c, "POST", "/v1/push", apitype.PushRequest{Name: ref})
return err
}
func (c *Client) Pull(ctx context.Context, ref string) error {
panic("not implemented")
}
func (c *Client) List(ctx context.Context) iter.Seq2[apitype.Model, error] {
panic("not implemented")
}
func (c *Client) Show(ctx context.Context, ref string) (*apitype.Model, error) {
panic("not implemented")
}
func (c *Client) Remove(ctx context.Context, ref string) error {
panic("not implemented")
}
func (c *Client) Copy(ctx context.Context, dstRef, srcRef string) error {
panic("not implemented")
}
func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message) error {
panic("not implemented")
}
type Error struct {
// Status is the HTTP status code returned by the server.
Status int `json:"status"`
// Code specifies a machine readable code indicating the class of
// error this error is. See http://docs.ollama.com/errors for a full
// list of error codes.
Code string `json:"code"`
// Message is a human-readable message that describes the error. It
// may change across versions of the API, so it should not be used for
// programmatic decisions.
Message string `json:"message,omitempty"`
// Field is the field in the request that caused the error, if any.
Field string `json:"field,omitempty"`
}
func (e *Error) Error() string {
var b strings.Builder
b.WriteString("ollama: ")
b.WriteString(e.Code)
if e.Message != "" {
b.WriteString(": ")
b.WriteString(e.Message)
}
return b.String()
}
// Do encodes in and sends it in a request to the Ollama server and decodes
// the response into Res, or an error response (non-2xx) into an *Error, or
// any error encountered decoding the response.
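//
// For example, Client.Push above is implemented as (illustrative):
//
//	_, err := Do[empty.Message](ctx, c, "POST", "/v1/push", apitype.PushRequest{Name: ref})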
func Do[Res any](ctx context.Context, c *Client, method, path string, in any) (*Res, error) {
var body bytes.Buffer
// TODO(bmizerany): pool and reuse this buffer AND the encoder
if err := encodeJSON(&body, in); err != nil {
return nil, err
}
urlStr := c.BaseURL + path
req, err := http.NewRequestWithContext(ctx, method, urlStr, &body)
if err != nil {
return nil, err
}
hc := cmp.Or(c.HTTPClient, http.DefaultClient)
res, err := hc.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
var buf bytes.Buffer
body := io.TeeReader(res.Body, &buf)
e, err := decodeJSON[Error](body)
if err != nil {
err := fmt.Errorf("ollama: invalid error response from server (status %d): %q", res.StatusCode, buf.String())
return nil, err
}
return nil, e
}
return decodeJSON[Res](res.Body)
}
// decodeJSON decodes JSON from r into a new value of type T.
//
// NOTE: This (and encodeJSON) are copied and pasted from oweb.go. Please
// do not try to consolidate them; we want to keep ollama/client free from
// dependencies which are moving targets and not pulling enough weight to
// justify their inclusion.
func decodeJSON[T any](r io.Reader) (*T, error) {
var v T
if err := json.NewDecoder(r).Decode(&v); err != nil {
return nil, err
}
return &v, nil
}
// NOTE: see NOTE above decodeJSON
func encodeJSON(w io.Writer, v any) error {
// TODO(bmizerany): pool and reuse encoder
return json.NewEncoder(w).Encode(v)
}

x/cmd/bllamo/bllamo.go Normal file

@@ -0,0 +1,100 @@
// Bllamo is a (new) tool for managing Ollama models.
//
// Usage:
//
// bllamo <command> [arguments]
//
// The commands are:
//
// build build a model from a Modelfile
// list list all models
// push push a model to an ollama registry
// pull pull a model from an ollama registry
// delete delete a model from an ollama registry
// help display help for a command
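//
// Example invocations (illustrative; the ref is hypothetical):
//
//	bllamo build -f Modelfile x/y/z:latest
//	bllamo push x/y/z:latest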
package main
import (
"cmp"
"context"
"flag"
"fmt"
"net/http"
"os"
"bllamo.com/api"
"bllamo.com/build"
"bllamo.com/client/ollama"
"bllamo.com/registry"
)
func main() {
flag.Parse()
args := flag.Args()
if len(args) < 1 {
fmt.Fprintln(os.Stderr, "bllamo: no command provided")
os.Exit(2)
}
if err := Main(flag.Args()...); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
}
var TODOUsage = fmt.Errorf("TODO: usage")
var commands = map[string]func(ctx context.Context, args ...string) error{
"build": cmdBuild,
"push": cmdPush,
"serve": cmdServe,
"registry": cmdRegistry,
}
// Main is the entry point for the bllamo command.
func Main(args ...string) error {
cmd := args[0]
args = args[1:]
if f, ok := commands[cmd]; ok {
ctx := context.TODO()
return f(ctx, args...)
}
return fmt.Errorf("blammo: unknown command %q", cmd)
}
func cmdBuild(ctx context.Context, args ...string) error {
var v struct {
Modelfile string `flag:"f,the Modelfile to use"`
}
fs := readFlags("build", args, &v)
if fs.NArg() != 1 {
return TODOUsage
}
modelfile, err := os.ReadFile(cmp.Or(v.Modelfile, "Modelfile"))
if err != nil {
return err
}
return ollama.Default().Build(ctx, fs.Arg(0), modelfile, os.DirFS("."))
}
func cmdRegistry(_ context.Context, _ ...string) error {
var s registry.Server
return http.ListenAndServe(":8888", &s)
}
func cmdServe(ctx context.Context, args ...string) error {
bs, err := build.Open("")
if err != nil {
return err
}
return http.ListenAndServe(":11434", &api.Server{Build: bs})
}
func cmdPush(ctx context.Context, args ...string) error {
fs := readFlags("push", args, nil)
if fs.NArg() != 1 {
return TODOUsage
}
return ollama.Default().Push(ctx, fs.Arg(0))
}

x/cmd/bllamo/flags.go Normal file

@@ -0,0 +1,59 @@
package main
import (
"flag"
"fmt"
"reflect"
"strings"
)
// readFlags parses the provided args using a flag.FlagSet that is
// dynamically built using reflection on the provided struct pointer. The
// struct fields that have a "flag" tag are used to build the flags. The
// tag holds the flag name, optionally followed by a comma and the usage
// text. Example usage:
//
// func main() {
// var flags struct {
// Modelfile string `flag:"f,path to the Modelfile"`
// }
//
// fs := readFlags("build", os.Args[1:], &flags)
// }
func readFlags(name string, args []string, v any) *flag.FlagSet {
fs := flag.NewFlagSet(name, flag.ExitOnError)
defer fs.Parse(args) // runs after all flags are registered below
if v == nil {
return fs
}
// v must be a pointer to a struct; Elem gets the struct value.
rv := reflect.ValueOf(v).Elem()
for i := 0; i < rv.NumField(); i++ {
f := rv.Field(i)
if !f.CanSet() {
continue
}
tag := rv.Type().Field(i).Tag.Get("flag")
if tag == "" {
continue
}
var name, usage string
if i := strings.Index(tag, ","); i != -1 {
name = tag[:i]
usage = tag[i+1:]
} else {
name = tag
}
// TODO(bmizerany): add more types as needed
switch f.Kind() {
case reflect.String:
fs.StringVar(f.Addr().Interface().(*string), name, "", usage)
case reflect.Bool:
fs.BoolVar(f.Addr().Interface().(*bool), name, false, usage)
default:
panic(fmt.Sprintf("unsupported type %v", f.Kind()))
}
}
return fs
}

x/cmd/gguf/gguf.go Normal file

@@ -0,0 +1,97 @@
// Gguf is a tool for learning about GGUF files.
//
// Usage:
//
// gguf [flags] <file>
package main
import (
"flag"
"fmt"
"io"
"log"
"os"
"text/tabwriter"
"bllamo.com/encoding/gguf"
)
func main() {
if err := Main(os.Stdout, os.Args[1:]...); err != nil {
log.Fatal(err)
}
}
func Main(stdout io.Writer, args ...string) error {
fs := flag.NewFlagSet("gguf", flag.ExitOnError)
flagGPU := fs.Uint64("gpu", 0, "use N bytes of GPU memory (default is 0)")
fs.Usage = func() {
io.WriteString(stdout, "Gguf is a tool for learning about GGUF files.\n")
io.WriteString(stdout, "\n")
io.WriteString(stdout, "Usage:\n")
io.WriteString(stdout, "\n")
io.WriteString(stdout, "\tgguf [flags] <file>\n")
io.WriteString(stdout, "\n")
var numFlags int
fs.VisitAll(func(*flag.Flag) { numFlags++ })
if numFlags > 0 {
io.WriteString(stdout, "Flags:\n")
fs.PrintDefaults()
}
}
fs.Parse(args)
if fs.NArg() != 1 {
fs.Usage()
os.Exit(2)
}
file := fs.Arg(0)
f, err := os.Open(file)
if err != nil {
log.Fatal(err)
}
defer f.Close()
g, err := gguf.ReadFile(f)
if err != nil {
log.Fatal(err)
}
tw := tabwriter.NewWriter(stdout, 0, 2, 2, ' ', 0)
defer tw.Flush()
fmt.Fprintf(tw, "version:\t%d\n", g.Version())
for m, err := range g.Metadata {
if err != nil {
log.Fatal(err)
}
if len(m.Values) > 5 {
fmt.Fprintf(tw, "meta:\t%q: ... (%d values)\n", m.Key, len(m.Values))
} else {
fmt.Fprintf(tw, "meta:\t%q: %v\n", m.Key, m.Values)
}
}
var i int
var totalLayerBytes uint64
var offGPU bool
for t, err := range g.Tensors {
if err != nil {
log.Fatal(err)
}
totalLayerBytes += t.Size
if totalLayerBytes > *flagGPU {
offGPU = true
}
const msg = "tensor (layer %000d):\t%q\t%s\tdims=%v\toffset=%d\tsize=%d\tonGPU=%v\n"
fmt.Fprintf(tw, msg, i, t.Name, t.Type, t.Dimensions, t.Offset, t.Size, !offGPU)
i++
}
return nil
}

x/encoding/gguf/gguf.go Normal file

@@ -0,0 +1,376 @@
package gguf
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"strconv"
"strings"
"bllamo.com/types/structs"
)
// TODO(bmizerany): determine a more reasonable value for MaxDimensions.
// MaxDimensions is the maximum number of dimensions a tensor can have.
const MaxDimensions uint32 = 1e6
// Errors
var (
// ErrBadMagic is returned when the magic bytes at the start of the
// file are missing or wrong. This is useful for detecting if the
// file is not a gguf file.
ErrBadMagic = errors.New("gguf: bad magic")
ErrUnsupportedVersion = errors.New("gguf: unsupported version")
ErrMangled = errors.New("gguf: mangled data")
)
type Type uint32
const (
TypeF32 Type = 0
TypeF16 Type = 1
TypeQ4_0 Type = 2
TypeQ4_1 Type = 3
TypeQ5_0 Type = 6
TypeQ5_1 Type = 7
TypeQ8_0 Type = 8
TypeQ8_1 Type = 9
TypeQ2_K Type = 10
TypeQ3_K Type = 11
TypeQ4_K Type = 12
TypeQ5_K Type = 13
TypeQ6_K Type = 14
TypeQ8_K Type = 15
TypeI8 Type = 16
TypeI16 Type = 17
TypeI32 Type = 18
TypeCount Type = 19
)
var typeNames = map[Type]string{
TypeF32: "F32",
TypeF16: "F16",
TypeQ4_0: "Q4_0",
TypeQ4_1: "Q4_1",
TypeQ5_0: "Q5_0",
TypeQ5_1: "Q5_1",
TypeQ8_0: "Q8_0",
TypeQ8_1: "Q8_1",
TypeQ2_K: "Q2_K",
TypeQ3_K: "Q3_K",
TypeQ4_K: "Q4_K",
TypeQ5_K: "Q5_K",
TypeQ6_K: "Q6_K",
TypeQ8_K: "Q8_K",
TypeI8: "I8",
TypeI16: "I16",
TypeI32: "I32",
TypeCount: "COUNT",
}
func (t Type) String() string {
if name := typeNames[t]; name != "" {
return name
}
return fmt.Sprintf("(!unknown_type %d!)", t)
}
// ValueType is the type of a metadata value.
type ValueType uint32
func (t ValueType) String() string {
if name := metaTypeNames[t]; name != "" {
return name
}
return fmt.Sprintf("(!unknown_value_type %d!)", t)
}
const (
ValueTypeUint8 ValueType = 0
ValueTypeInt8 ValueType = 1
ValueTypeUint16 ValueType = 2
ValueTypeInt16 ValueType = 3
ValueTypeUint32 ValueType = 4
ValueTypeInt32 ValueType = 5
ValueTypeFloat32 ValueType = 6
ValueTypeBool ValueType = 7
ValueTypeString ValueType = 8
ValueTypeArray ValueType = 9
ValueTypeUint64 ValueType = 10
ValueTypeInt64 ValueType = 11
ValueTypeFloat64 ValueType = 12
)
var metaTypeNames = map[ValueType]string{
ValueTypeUint8: "uint8",
ValueTypeInt8: "int8",
ValueTypeUint16: "uint16",
ValueTypeInt16: "int16",
ValueTypeUint32: "uint32",
ValueTypeInt32: "int32",
ValueTypeFloat32: "float32",
ValueTypeBool: "bool",
ValueTypeString: "string",
ValueTypeArray: "array",
ValueTypeUint64: "uint64",
ValueTypeInt64: "int64",
ValueTypeFloat64: "float64",
}
type TensorInfo struct {
Name string
Dimensions []uint64
Type Type
Offset uint64
Size uint64
}
type MetaValue struct {
Type ValueType
Value []byte
}
func (v MetaValue) String() string {
var b strings.Builder
b.WriteString(v.Type.String())
b.WriteString("(")
switch v.Type {
case ValueTypeArray:
b.WriteString("[...]")
case ValueTypeString:
b.WriteString(strconv.Quote(string(v.Value)))
case ValueTypeBool:
if len(v.Value) == 0 {
b.WriteString("(!invalid bool)")
break
}
switch v.Value[0] {
case 0:
b.WriteString("false")
case 1:
b.WriteString("true")
default:
b.WriteString("!invalid bool")
}
case ValueTypeUint8, ValueTypeInt8, ValueTypeUint16, ValueTypeInt16, ValueTypeUint32, ValueTypeInt32, ValueTypeUint64, ValueTypeInt64, ValueTypeFloat32, ValueTypeFloat64:
var buf [8]byte
copy(buf[:], v.Value) // copy stops at min(len(buf), len(v.Value))
fmt.Fprintf(&b, "%v", binary.LittleEndian.Uint64(buf[:]))
default:
fmt.Fprintf(&b, "%v", v.Value)
}
b.WriteString(")")
return b.String()
}
type MetaEntry struct {
Key string
Type ValueType
Values []MetaValue
}
func (e MetaEntry) String() string {
if len(e.Values) == 0 {
return ""
}
return string(e.Values[0].Value)
}
func (e MetaEntry) Uint32() uint32 {
if len(e.Values) == 0 {
return 0
}
return binary.LittleEndian.Uint32(e.Values[0].Value)
}
func (e MetaEntry) FileType() Type {
if len(e.Values) == 0 {
return TypeCount
}
return Type(e.Uint32())
}
func (e MetaEntry) GoString() string {
var b strings.Builder
b.WriteString(e.Key)
b.WriteString(": ")
b.WriteString(e.Type.String())
b.WriteString("(")
for i, v := range e.Values {
if i > 0 {
b.WriteString(", ")
}
b.WriteString(v.String())
}
b.WriteString(")")
return b.String()
}
type Info struct {
_ structs.Incomparable // prevent comparison of Info values so we can change the implementation later
Version int
FileType Type
}
func Stat(path string) (Info, error) {
f, err := os.Open(path)
if err != nil {
return Info{}, err
}
defer f.Close()
return StatReader(f)
}
// StatReader reads the header information from r and returns an Info
// struct with the version and file type.
//
// It returns an error if any.
//
// As a special case, it returns ErrBadMagic if the file does not start with
// the magic bytes. This can be used to detect if the file is not a GGUF
// file.
func StatReader(r io.ReadSeeker) (Info, error) {
if _, err := r.Seek(0, 0); err != nil {
return Info{}, err
}
f, err := ReadFile(r)
if err != nil {
return Info{}, err
}
info := Info{Version: f.Version()}
for m, err := range f.Metadata {
if err != nil {
return Info{}, err
}
if m.Key == "general.file_type" {
if m.Type != ValueTypeUint32 {
return Info{}, fmt.Errorf("unexpected type for metadata key %q: %v, want %v", m.Key, m.Type, ValueTypeUint32)
}
info.FileType = m.FileType()
}
}
return info, nil
}
type File struct {
version uint32
numMetaValues uint64
numTensors uint64
gr *ggufReader
}
// ReadFile reads header information from r and returns a File, ready for
// iteration over Metadata and Tensors.
func ReadFile(r io.Reader) (*File, error) {
f, err := readFile(r)
if err != nil {
return nil, err
}
return f, nil
}
func (f *File) Version() int {
return int(f.version)
}
// Metadata iterates over the metadata in the file. It must be exhausted
// before calling Tensors.
//
// It is not resumable.
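//
// A typical loop (illustrative; this is how StatReader consumes it):
//
//	for m, err := range f.Metadata {
//		if err != nil {
//			return err
//		}
//		// use m
//	}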
func (f *File) Metadata(yield func(MetaEntry, error) bool) {
var n int
for range f.numMetaValues {
meta, err := f.gr.readMetaEntry()
if err != nil {
err = fmt.Errorf("error reading metadata entry %d: %w", n, err)
yield(MetaEntry{}, err)
return
}
if !yield(meta, nil) {
return
}
n++
}
}
// Tensors iterates over the tensors in the file. It must only be called
// after exhausting the metadata iterator.
//
// It is not resumable.
func (f *File) Tensors(yield func(TensorInfo, error) bool) {
var last TensorInfo
for range f.numTensors {
info, err := f.gr.readTensorInfo()
// If the last tensor had a valid offset, yield it.
//
// NOTE: No tensor should have an offset of 0 because the
// offset is the start of the tensor data, which is always
// after the magic bytes, version, numMetaValues, and
// numTensors, which MUST all be non-zero bytes as per the
// GGUF spec.
if last.Offset > 0 {
if !yield(last, err) {
return
}
}
if err != nil {
yield(TensorInfo{}, err)
return
}
// Tensor info does not include the size, so we compute it
// from the previous tensor's offset to the current one.
offset0 := last.Offset
last = info
last.Size = info.Offset - offset0
}
if last.Offset > 0 {
yield(last, nil)
}
}
var magicBytes = []byte{0x47, 0x47, 0x55, 0x46}
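// readFile decodes the fixed GGUF v3 header, which (as read below) is
// laid out as:
//
//	magic         [4]byte // "GGUF"
//	version       uint32  // must be 3
//	numTensors    uint64
//	numMetaValues uint64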
func readFile(r io.Reader) (*File, error) {
gr := &ggufReader{r: &reader{r: r}}
magic, err := gr.next(4)
if err != nil {
return nil, errors.Join(err, ErrBadMagic)
}
if !bytes.Equal(magic, magicBytes) {
return nil, ErrBadMagic
}
version, err := gr.readUint32()
if err != nil {
return nil, err
}
if version != 3 {
return nil, fmt.Errorf("%w: %d", ErrUnsupportedVersion, version)
}
numTensors, err := gr.readUint64()
if err != nil {
return nil, err
}
numMetaValues, err := gr.readUint64()
if err != nil {
return nil, err
}
info := &File{
version: version,
numMetaValues: numMetaValues,
numTensors: numTensors,
gr: gr,
}
return info, nil
}


@@ -0,0 +1,345 @@
package gguf
import (
"errors"
"io"
"strings"
"testing"
"kr.dev/diff"
)
func TestStat(t *testing.T) {
cases := []struct {
name string
data string
wantInfo Info
wantErr error
}{
{
name: "empty",
wantErr: ErrBadMagic,
},
{
name: "bad magic",
data: "\xBB\xAA\xDD\x00",
wantErr: ErrBadMagic,
},
{
name: "bad version",
data: string(magicBytes) +
"\x02\x00\x00\x00" + // version
"",
wantErr: ErrUnsupportedVersion,
},
{
name: "valid general.file_type",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// general.file_type key
"\x11\x00\x00\x00\x00\x00\x00\x00" + // key length
"general.file_type" + // key
"\x04\x00\x00\x00" + // type (uint32)
"\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value
"",
wantInfo: Info{
Version: 3,
FileType: 1,
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
info, err := StatReader(strings.NewReader(tt.data))
if tt.wantErr != nil {
if !errors.Is(err, tt.wantErr) {
t.Fatalf("err = %v; want %q", err, tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
diff.Test(t, t.Errorf, info, tt.wantInfo)
})
}
}
func TestReadInfo(t *testing.T) {
cases := []struct {
name string
data string
wantMeta []MetaEntry
wantTensor []TensorInfo
wantReadErr error
wantMetaErr error
wantTensorErr error
wantInfo Info
}{
{
name: "empty",
wantReadErr: io.ErrUnexpectedEOF,
},
{
name: "bad magic",
data: "\xBB\xAA\xDD\x00",
wantReadErr: ErrBadMagic,
},
{
name: "bad version",
data: string(magicBytes) +
"\x02\x00\x00\x00" + // version
"",
wantReadErr: ErrUnsupportedVersion,
},
{
name: "no metadata or tensors",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"",
wantReadErr: nil,
},
{
name: "good metadata",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
"K" + // key
"\x08\x00\x00\x00" + // type (string)
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
"VV" + // string value
"",
wantMeta: []MetaEntry{
{Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}},
},
},
{
name: "good metadata with multiple values",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// MetaEntry 1
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
"x" + // key
"\x08\x00\x00\x00" + // type (string)
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
"XX" + // string value
// MetaEntry 2
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
"y" + // key
"\x04\x00\x00\x00" + // type (uint32)
"\x99\x88\x77\x66" + // uint32 value
"",
wantMeta: []MetaEntry{
{Key: "x", Type: ValueTypeString, Values: []MetaValue{{
Type: ValueTypeString,
Value: []byte("XX"),
}}},
{Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{
Type: ValueTypeUint32,
Value: []byte{0x99, 0x88, 0x77, 0x66},
}}},
},
},
{
name: "negative string length in meta key",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length
"K" + // key
"\x08\x00\x00\x00" + // type (string)
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
"VV" + // string value
"",
wantMetaErr: ErrMangled,
},
// Tensor tests
{
name: "good tensor",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// Tensor 1
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
"t" +
// dimensions
"\x01\x00\x00\x00" + // dimensions length
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
"\x03\x00\x00\x00" + // type (i8)
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
"",
wantTensor: []TensorInfo{
{
Name: "t",
Dimensions: []uint64{1},
Type: TypeQ4_1,
Offset: 256,
Size: 256,
},
},
},
{
name: "too many dimensions",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// Tensor 1
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
"t" +
"\x00\x00\x00\x01" + // dimensions length
"",
wantTensorErr: ErrMangled,
},
{
name: "size computed",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// Tensor 1
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
"t" +
"\x01\x00\x00\x00" + // dimensions length
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
"\x03\x00\x00\x00" + // type (i8)
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
// Tensor 2
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
"t" +
"\x01\x00\x00\x00" + // dimensions length
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
"\x03\x00\x00\x00" + // type (i8)
"\x00\x03\x00\x00\x00\x00\x00\x00" + // offset
"",
wantTensor: []TensorInfo{
{
Name: "t",
Dimensions: []uint64{1},
Type: TypeQ4_1,
Offset: 256,
Size: 256,
},
{
Name: "t",
Dimensions: []uint64{1},
Type: TypeQ4_1,
Offset: 768,
Size: 512,
},
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
f, err := ReadFile(strings.NewReader(tt.data))
if err != nil {
if !errors.Is(err, tt.wantReadErr) {
t.Fatalf("unexpected ReadFile error: %v", err)
}
return
}
var got []MetaEntry
for meta, err := range f.Metadata {
if !errors.Is(err, tt.wantMetaErr) {
t.Fatalf("err = %v; want %v", err, ErrMangled)
}
if err != nil {
return
}
got = append(got, meta)
}
diff.Test(t, t.Errorf, got, tt.wantMeta)
var gotT []TensorInfo
for tinfo, err := range f.Tensors {
if !errors.Is(err, tt.wantTensorErr) {
t.Fatalf("err = %v; want %v", err, tt.wantTensorErr)
}
if err != nil {
return
}
gotT = append(gotT, tinfo)
}
diff.Test(t, t.Errorf, gotT, tt.wantTensor)
})
}
}
func FuzzReadInfo(f *testing.F) {
f.Add(string(magicBytes))
f.Add(string(magicBytes) +
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"")
f.Add(string(magicBytes) +
"\x03\x00\x00\x00" + // version
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
"K" + // key
"\x08\x00\x00\x00" + // type (string)
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
"VV" + // string value
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
"t" +
"\x01\x00\x00\x00" + // dimensions length
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
"\x03\x00\x00\x00" + // type (i8)
"\x05\x00\x00\x00\x00\x00\x00\x00" + // offset
"")
f.Fuzz(func(t *testing.T, data string) {
gf, err := ReadFile(strings.NewReader(data))
if err != nil {
t.Logf("ReadFile error: %v", err)
t.Skip()
}
for _, err := range gf.Metadata {
if err != nil {
t.Logf("metadata error: %v", err)
t.Skip()
}
}
for tinfo, err := range gf.Tensors {
if err != nil {
t.Logf("tensor error: %v", err)
t.Skip()
}
if tinfo.Offset <= 0 {
t.Logf("invalid tensor offset: %+v", tinfo)
t.Skip()
}
if tinfo.Size <= 0 {
t.Logf("invalid tensor size: %+v", tinfo)
t.Skip()
}
}
})
}
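// exampleReadFile is a hedged sketch (not exercised by the suite) of the
// intended consumption pattern for ReadFile: open the stream, then range
// over the Metadata and Tensors sequences, checking the error component at
// each step. The payload reuses the empty-header case from the table above.
func exampleReadFile(t *testing.T) {
data := string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x00\x00\x00\x00\x00\x00\x00\x00" // numMetaValues
f, err := ReadFile(strings.NewReader(data))
if err != nil {
t.Fatal(err)
}
for meta, err := range f.Metadata {
if err != nil {
t.Fatal(err)
}
t.Logf("meta key: %s", meta.Key)
}
for tinfo, err := range f.Tensors {
if err != nil {
t.Fatal(err)
}
t.Logf("tensor: %s offset=%d size=%d", tinfo.Name, tinfo.Offset, tinfo.Size)
}
}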

195
x/encoding/gguf/ggufio.go Normal file
View File

@ -0,0 +1,195 @@
package gguf
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"iter"
)
type ggufReader struct {
r *reader
n int
}
func (r *ggufReader) readMetaEntry() (MetaEntry, error) {
key, err := r.readString()
if err != nil {
return MetaEntry{}, err
}
typ, err := r.readValueType()
if err != nil {
return MetaEntry{}, err
}
var values []MetaValue
for v, err := range r.readMetaValues(typ) {
if err != nil {
err = fmt.Errorf("(key=%q type=%s): %w", key, typ, err)
return MetaEntry{}, err
}
values = append(values, v)
}
return MetaEntry{
Key: string(key),
Type: typ,
Values: values,
}, nil
}
func (r *ggufReader) readMetaValue(typ ValueType) (MetaValue, error) {
var value []byte
var err error
switch typ {
case ValueTypeUint8, ValueTypeInt8:
value, err = r.next(1)
case ValueTypeUint16, ValueTypeInt16:
value, err = r.next(2)
case ValueTypeUint32, ValueTypeInt32, ValueTypeFloat32:
value, err = r.next(4)
case ValueTypeUint64, ValueTypeInt64, ValueTypeFloat64:
value, err = r.next(8)
case ValueTypeBool:
value, err = r.next(1)
case ValueTypeString:
value, err = r.readString()
case ValueTypeArray:
err = fmt.Errorf("nested arrays are not supported")
default:
err = fmt.Errorf("unsupported metadata type: %d", typ)
}
if err != nil {
return MetaValue{}, err
}
return MetaValue{
Type: typ,
Value: bytes.Clone(value),
}, nil
}
func (r *ggufReader) readMetaValues(typ ValueType) iter.Seq2[MetaValue, error] {
return func(yield func(MetaValue, error) bool) {
if typ == ValueTypeArray {
atyp, err := r.readValueType()
if err != nil {
err = fmt.Errorf("invalid type: %w", err)
yield(MetaValue{}, err)
return
}
n, err := r.readUint64()
if err != nil {
err = fmt.Errorf("invalid length: %w", err)
yield(MetaValue{}, err)
return
}
for i := range n {
v, err := r.readMetaValue(atyp)
if err != nil {
err = fmt.Errorf("invalid entry (type=%s) %d: %w", atyp, i, err)
yield(MetaValue{}, err)
return
}
if !yield(v, nil) {
return
}
}
} else {
v, err := r.readMetaValue(typ)
if err != nil {
err = fmt.Errorf("error reading metadata value: %w", err)
yield(MetaValue{}, err)
return
}
yield(v, nil)
}
}
}
func (r *ggufReader) readValueType() (ValueType, error) {
typ, err := r.readUint32()
return ValueType(typ), err
}
func (r *ggufReader) readTensorInfo() (TensorInfo, error) {
name, err := r.readString()
if err != nil {
return TensorInfo{}, err
}
numDimensions, err := r.readUint32()
if err != nil {
return TensorInfo{}, err
}
if numDimensions > MaxDimensions {
return TensorInfo{}, fmt.Errorf("%w: dimensions length (%d) exceeds %d", ErrMangled, numDimensions, MaxDimensions)
}
dims := make([]uint64, numDimensions)
for i := range dims {
d, err := r.readUint64()
if err != nil {
return TensorInfo{}, err
}
dims[i] = d
}
typ, err := r.readUint32()
if err != nil {
return TensorInfo{}, err
}
offset, err := r.readUint64()
if err != nil {
return TensorInfo{}, err
}
// TODO(bmizerany): check offset is multiple of ALIGNMENT
return TensorInfo{
Name: string(name),
Dimensions: dims,
Type: Type(typ),
Offset: offset,
}, nil
}
func (r *ggufReader) next(n int) ([]byte, error) {
if n < 0 {
return nil, errors.Join(fmt.Errorf("invalid read length: %d", n), ErrMangled)
}
w := r.r.window()
for len(w) < n {
if r.r.extend() == 0 {
return nil, io.ErrUnexpectedEOF
}
w = r.r.window()
}
r.r.release(n)
r.n += n
return w[:n], nil
}
func (r *ggufReader) readString() ([]byte, error) {
n, err := r.readUint64()
if err != nil {
return nil, err
}
// TODO(bmizerany): limit max string length
return r.next(int(n))
}
func (r *ggufReader) readUint32() (uint32, error) {
b, err := r.next(4)
if err != nil {
return 0, err
}
n := binary.LittleEndian.Uint32(b)
return n, nil
}
func (r *ggufReader) readUint64() (uint64, error) {
b, err := r.next(8)
if err != nil {
return 0, err
}
n := binary.LittleEndian.Uint64(b)
return n, nil
}
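// For reference, the wire format consumed above, reconstructed from the
// package tests (GGUF v3, all integers little endian). This is a summary,
// not a normative spec:
//
// magic [4]byte "GGUF"
// version uint32
// numTensors uint64
// numMetaValues uint64
// meta entry: keyLen uint64, key []byte, valueType uint32, value
// (strings are uint64 length-prefixed; arrays prefix an element
// type uint32 and a count uint64 before the values)
// tensor info: nameLen uint64, name []byte, numDimensions uint32,
// dims []uint64, type uint32, offset uint64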

70
x/encoding/gguf/reader.go Normal file
View File

@ -0,0 +1,70 @@
package gguf
import "io"
// A reader implements a sliding window over an io.Reader.
type reader struct {
data []byte
offset int
r io.Reader
err error
}
// release discards n bytes from the front of the window.
func (b *reader) release(n int) {
b.offset += n
}
// window returns the current window.
// The window is invalidated by calls to release or extend.
func (b *reader) window() []byte {
return b.data[b.offset:]
}
// tuning constants for reader.extend.
const (
newBufferSize = 8 << 10
minReadSize = newBufferSize >> 2
)
// extend extends the window with data from the underlying reader.
func (b *reader) extend() int {
if b.err != nil {
return 0
}
remaining := len(b.data) - b.offset
if remaining == 0 {
b.data = b.data[:0]
b.offset = 0
}
if cap(b.data)-len(b.data) >= minReadSize {
// nothing to do, enough space exists between len and cap.
} else if cap(b.data)-remaining >= minReadSize {
// buffer has enough space if we move the data to the front.
b.compact()
} else {
// otherwise, we must allocate/extend a new buffer
b.grow()
}
remaining += b.offset
n, err := b.r.Read(b.data[remaining:cap(b.data)])
// reduce length to the existing plus the data we read.
b.data = b.data[:remaining+n]
b.err = err
return n
}
// grow grows the buffer, moving the active data to the front.
func (b *reader) grow() {
buf := make([]byte, max(cap(b.data)*2, newBufferSize))
copy(buf, b.data[b.offset:])
b.data = buf
b.offset = 0
}
// compact moves the active data to the front of the buffer.
func (b *reader) compact() {
copy(b.data, b.data[b.offset:])
b.offset = 0
}
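// The intended calling pattern, mirroring ggufReader.next, is a sketch like
// the following: extend the window until it holds n bytes, take the prefix,
// then release it so the space can be recycled. Illustrative only.
//
// w := b.window()
// for len(w) < n {
// if b.extend() == 0 {
// return nil, io.ErrUnexpectedEOF
// }
// w = b.window()
// }
// b.release(n)
// return w[:n], nil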

View File

@ -0,0 +1,2 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x02\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")

View File

@ -0,0 +1,2 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")

View File

@ -0,0 +1,2 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xfd\xff\xff\xff\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")

View File

@ -0,0 +1,2 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00K\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00VV\x01\x00\x00\x00\x00\\x00\\x00\\x00\\x00")

View File

@ -0,0 +1,2 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")

View File

@ -0,0 +1,2 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")

View File

@ -0,0 +1,2 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x0000000000000000000000000\xe5")

View File

@ -0,0 +1,2 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x0000000000\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x000\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x0000\x01\x00\x00\x00\x00\x00\x00\x000\x01\x00\x001\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\a")

30
x/go.mod Normal file
View File

@ -0,0 +1,30 @@
module bllamo.com
go 1.22.1
require kr.dev/diff v0.3.0
require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.6 // indirect
github.com/klauspost/cpuid/v2 v2.2.6 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/minio-go/v7 v7.0.69 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect
github.com/rogpeppe/go-internal v1.9.0 // indirect
github.com/rs/xid v1.5.0 // indirect
golang.org/x/crypto v0.19.0 // indirect
golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)

63
x/go.sum Normal file
View File

@ -0,0 +1,63 @@
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.69 h1:l8AnsQFyY1xiwa/DaQskY4NXSLA2yrGsW5iD9nRPVS0=
github.com/minio/minio-go/v7 v7.0.69/go.mod h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ=
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e h1:iWVPgObh6F4UDtjBLK51zsy5UHTPLQwCmsNjCsbKhQ0=
golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
kr.dev/diff v0.3.0 h1:o/T8/tkAq9IuRIuFqCupyKPC5iSY3WXpVZ2p6ZK3Emw=
kr.dev/diff v0.3.0/go.mod h1:XiTaLOg2/PD0cmXY7WQXUR8RAF3RwWpqIQEj910J2NY=

126
x/model/file.go Normal file
View File

@ -0,0 +1,126 @@
package model
import (
"bufio"
"io"
"iter"
"strings"
)
type Param struct {
Key string
Value string
}
type Message struct {
Role string
Content string
}
type File struct {
// From is a required pragma that specifies the source of the model,
// either on disk, or by reference (see blob.ParseRef).
From string
// Optional
Params []Param
Template string
System string
Adapter string
Messages []Message
License string
}
type Error struct {
Pragma string
Message string
}
func (e *Error) Error() string {
return e.Pragma + ": " + e.Message
}
type Pragma struct {
// The pragma name
Name string
// Args contains the user-defined arguments for the pragma. If no
// arguments were provided, it is nil.
Args []string
}
func (p Pragma) Arg(i int) string {
if i >= len(p.Args) {
return ""
}
return p.Args[i]
}
func Pragmas(r io.Reader) iter.Seq2[Pragma, error] {
return func(yield func(Pragma, error) bool) {
sc := bufio.NewScanner(r)
for sc.Scan() {
line := sc.Text()
// TODO(bmizerany): set a max num fields/args to
// prevent mem bloat
args := strings.Fields(line)
if len(args) == 0 {
continue
}
p := Pragma{
Name: strings.ToUpper(args[0]),
}
if p.Name == "MESSAGE" {
// handle special case where message content
// is space separated on the _rest_ of the
// line like: `MESSAGE user Is Ontario in
// Canada?`
panic("TODO")
}
if len(args) > 1 {
p.Args = args[1:]
}
if !yield(p, nil) {
return
}
}
if sc.Err() != nil {
yield(Pragma{}, sc.Err())
}
}
}
func Decode(r io.Reader) (File, error) {
var f File
for p, err := range Pragmas(r) {
if err != nil {
return File{}, err
}
switch p.Name {
case "FROM":
f.From = p.Arg(0)
case "PARAMETER":
f.Params = append(f.Params, Param{
Key: strings.ToLower(p.Arg(0)),
Value: p.Arg(1),
})
case "TEMPLATE":
f.Template = p.Arg(0)
case "SYSTEM":
f.System = p.Arg(0)
case "ADAPTER":
f.Adapter = p.Arg(0)
case "MESSAGE":
f.Messages = append(f.Messages, Message{
Role: p.Arg(0),
Content: p.Arg(1),
})
case "LICENSE":
f.License = p.Arg(0)
}
}
return f, nil
}
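// decodeExample is a minimal sketch of feeding Decode a Modelfile-style
// input; the model name and parameter value are illustrative, not defaults
// of this package.
func decodeExample() (File, error) {
const src = "FROM mistral:latest\n" +
"PARAMETER temperature 0.7\n" +
"LICENSE MIT\n"
return Decode(strings.NewReader(src))
}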

86
x/oweb/oweb.go Normal file
View File

@ -0,0 +1,86 @@
package oweb
import (
"cmp"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"bllamo.com/client/ollama"
)
func Missing(field string) error {
return &ollama.Error{
Status: 400,
Code: "missing",
Field: field,
Message: fmt.Sprintf("%s is required", field),
}
}
func Mistake(code, field, message string) error {
return &ollama.Error{
Status: 400,
Code: code,
Field: field,
Message: fmt.Sprintf("%s: %s", field, message),
}
}
// Convenience errors
var (
ErrNotFound = &ollama.Error{Status: 404, Code: "not_found"}
ErrInternal = &ollama.Error{Status: 500, Code: "internal_error"}
ErrMethodNotAllowed = &ollama.Error{Status: 405, Code: "method_not_allowed"}
)
type HandlerFunc func(w http.ResponseWriter, r *http.Request) error
func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) {
if err := h(w, r); err != nil {
// TODO: take a slog.Logger
log.Printf("error: %v", err)
var oe *ollama.Error
if !errors.As(err, &oe) {
oe = ErrInternal
}
oe.Status = cmp.Or(oe.Status, 400)
w.WriteHeader(oe.Status)
if err := EncodeJSON(w, oe); err != nil {
log.Printf("error encoding error: %v", err)
}
}
}
func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) {
v, err := DecodeJSON[T](r)
if err == nil {
return v, nil
}
var msg string
var e *json.SyntaxError
if errors.As(err, &e) {
msg = e.Error()
}
var se *json.UnmarshalTypeError
if errors.As(err, &se) {
msg = fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type)
}
return nil, Mistake("invalid_json", field, msg)
}
func DecodeJSON[T any](r io.Reader) (*T, error) {
var v *T
if err := json.NewDecoder(r).Decode(&v); err != nil {
var zero T
return &zero, err
}
return v, nil
}
func EncodeJSON(w io.Writer, v any) error {
return json.NewEncoder(w).Encode(v)
}
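// handleExample is a sketch of the handler shape Serve expects: return an
// error and let Serve translate it into an *ollama.Error response. The
// request type and field name are illustrative, not part of this package.
func handleExample(w http.ResponseWriter, r *http.Request) error {
if r.Method != "POST" {
return ErrMethodNotAllowed
}
type exampleRequest struct {
Name string `json:"name"`
}
req, err := DecodeUserJSON[exampleRequest]("", r.Body)
if err != nil {
return err
}
if req.Name == "" {
return Missing("name")
}
return EncodeJSON(w, req)
}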

View File

@ -0,0 +1,46 @@
package apitype
import "encoding/json"
type Manifest struct {
Layers []Layer `json:"layers"`
}
type CompletePart struct {
URL string `json:"url"` // contains PartNumber and UploadID from server
ETag string `json:"etag"`
}
type Layer struct {
Digest string `json:"digest"`
MediaType string `json:"mediaType"`
Size int64 `json:"size"`
}
type PushRequest struct {
Ref string `json:"ref"`
Manifest json.RawMessage `json:"manifest"`
// Uploaded is the list of parts the client uploaded in response to the
// previous push.
Uploaded []CompletePart `json:"part_uploads"`
}
type Requirement struct {
Digest string `json:"digest"`
Offset int64 `json:"offset"`
Size int64 `json:"size"`
// URL is the url to PUT the layer to.
//
// Clients must include it as the URL, along with the ETag in the
// response headers from the PUT request, in the next push request
// in the Uploaded field.
URL string `json:"url"`
}
type PushResponse struct {
// Requirements is a list of digests that the client needs to push before
// repushing the manifest.
Requirements []Requirement `json:"requirements,omitempty"`
}
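// Putting the types together, the push protocol they sketch is (illustrative
// summary, not a normative spec):
//
// 1. The client POSTs a PushRequest with the manifest and no Uploaded parts.
// 2. The server replies with a PushResponse listing one Requirement per
// missing layer chunk, each carrying a presigned URL.
// 3. The client PUTs each chunk to its URL and records the returned ETag as
// a CompletePart.
// 4. The client pushes again with Uploaded populated; an empty Requirements
// list means the manifest was committed.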

81
x/registry/client.go Normal file
View File

@ -0,0 +1,81 @@
package registry
import (
"cmp"
"context"
"encoding/xml"
"fmt"
"io"
"net/http"
"bllamo.com/client/ollama"
"bllamo.com/registry/apitype"
)
type Client struct {
BaseURL string
HTTPClient *http.Client
}
func (c *Client) oclient() *ollama.Client {
return (*ollama.Client)(c)
}
type PushParams struct {
Uploaded []apitype.CompletePart
}
// Push pushes a manifest to the server.
func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushParams) ([]apitype.Requirement, error) {
p = cmp.Or(p, &PushParams{})
// TODO(bmizerany): backoff
v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
Ref: ref,
Manifest: manifest,
Uploaded: p.Uploaded,
})
if err != nil {
return nil, err
}
return v.Requirements, nil
}
func PushLayer(ctx context.Context, dstURL string, off, size int64, file io.ReaderAt) (etag string, err error) {
sr := io.NewSectionReader(file, off, size)
req, err := http.NewRequestWithContext(ctx, "PUT", dstURL, sr)
if err != nil {
return "", err
}
req.ContentLength = size
res, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
if res.StatusCode != 200 {
return "", parseS3Error(res)
}
return res.Header.Get("ETag"), nil
}
type s3Error struct {
XMLName xml.Name `xml:"Error"`
Code string `xml:"Code"`
Message string `xml:"Message"`
Resource string `xml:"Resource"`
RequestId string `xml:"RequestId"`
}
func (e *s3Error) Error() string {
return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message)
}
// parseS3Error parses an XML error response from S3.
func parseS3Error(res *http.Response) error {
var se *s3Error
if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
return err
}
return se
}
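// End to end, a push is two Push calls bracketing a round of PushLayer
// uploads (see testPush in server_test.go for a complete version): the first
// Push reports missing chunks, PushLayer uploads each chunk and yields an
// ETag, and the second Push carries the collected CompleteParts back to the
// server.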

233
x/registry/server.go Normal file
View File

@ -0,0 +1,233 @@
// Package registry implements an Ollama registry client and server.
package registry
import (
"bytes"
"cmp"
"context"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"os"
"path"
"strconv"
"time"
"bllamo.com/build/blob"
"bllamo.com/client/ollama"
"bllamo.com/oweb"
"bllamo.com/registry/apitype"
"bllamo.com/utils/upload"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
// Defaults
const (
DefaultUploadChunkSize = 50 * 1024 * 1024
)
// TODO(bmizerany): move all env things to package envknobs?
var defaultLibrary = cmp.Or(os.Getenv("OLLAMA_REGISTRY"), "registry.ollama.ai/library")
func DefaultLibrary() string {
return defaultLibrary
}
type Server struct {
UploadChunkSize int64 // default is DefaultUploadChunkSize
minioClient *minio.Client
}
func New(mc *minio.Client) *Server {
return &Server{minioClient: mc}
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := s.serveHTTP(w, r); err != nil {
log.Printf("error: %v", err) // TODO(bmizerany): take a slog.Logger
var e *ollama.Error
if !errors.As(err, &e) {
e = oweb.ErrInternal
}
w.WriteHeader(cmp.Or(e.Status, 400))
if err := oweb.EncodeJSON(w, e); err != nil {
log.Printf("error encoding error: %v", err)
}
}
}
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
switch r.URL.Path {
case "/v1/push":
return s.handlePush(w, r)
case "/v1/pull":
return s.handlePull(w, r)
default:
return oweb.ErrNotFound
}
}
func (s *Server) uploadChunkSize() int64 {
return cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize)
}
func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
const bucketTODO = "test"
pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body)
if err != nil {
return err
}
ref := blob.ParseRef(pr.Ref)
if !ref.Complete() {
return oweb.Mistake("invalid", "name", "must be complete")
}
m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest))
if err != nil {
return err
}
mcc := &minio.Core{Client: s.mc()}
// TODO(bmizerany): complete uploads before stats for any with ETag
type completeParts struct {
key string
parts []minio.CompletePart
}
completePartsByUploadID := make(map[string]completeParts)
for _, pu := range pr.Uploaded {
// parse the URL
u, err := url.Parse(pu.URL)
if err != nil {
return err
}
q := u.Query()
uploadID := q.Get("UploadId")
if uploadID == "" {
return oweb.Mistake("invalid", "url", "missing UploadId")
}
partNumber, err := strconv.Atoi(q.Get("PartNumber"))
if err != nil {
return oweb.Mistake("invalid", "url", "invalid or missing PartNumber")
}
etag := pu.ETag
if etag == "" {
return oweb.Mistake("invalid", "etag", "missing")
}
cp, ok := completePartsByUploadID[uploadID]
if !ok {
cp = completeParts{key: u.Path}
completePartsByUploadID[uploadID] = cp
}
cp.parts = append(cp.parts, minio.CompletePart{
PartNumber: partNumber,
ETag: etag,
})
fmt.Println("uploadID", uploadID, "partNumber", partNumber, "etag", etag)
completePartsByUploadID[uploadID] = cp
}
for uploadID, cp := range completePartsByUploadID {
var zeroOpts minio.PutObjectOptions
_, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, cp.key, uploadID, cp.parts, zeroOpts)
if err != nil {
// log and continue; put backpressure on the client
log.Printf("error completing upload: %v", err)
}
}
var requirements []apitype.Requirement
for _, l := range m.Layers {
// TODO(bmizerany): do in parallel
if l.Size == 0 {
continue
}
// TODO(bmizerany): "global" throttle of rate of transfer
pushed, err := s.statObject(r.Context(), l.Digest)
if err != nil {
return err
}
if !pushed {
key := path.Join("blobs", l.Digest)
uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{})
if err != nil {
return err
}
for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) {
const timeToStartUpload = 15 * time.Minute
signedURL, err := s.mc().Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{
"UploadId": []string{uploadID},
"PartNumber": []string{strconv.Itoa(partNumber)},
"ContentLength": []string{strconv.FormatInt(c.Size, 10)},
})
if err != nil {
return err
}
requirements = append(requirements, apitype.Requirement{
Digest: l.Digest,
Offset: c.Offset,
Size: c.Size,
URL: signedURL.String(),
})
}
}
}
if len(requirements) == 0 {
// Commit the manifest
body := bytes.NewReader(pr.Manifest)
path := path.Join("manifests", path.Join(ref.Parts()...))
_, err := s.mc().PutObject(r.Context(), bucketTODO, path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{})
if err != nil {
return err
}
}
return oweb.EncodeJSON(w, &apitype.PushResponse{Requirements: requirements})
}
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
// lookup manifest
panic("TODO")
}
func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, err error) {
// HEAD the object
path := path.Join("blobs", digest)
_, err = s.mc().StatObject(ctx, "test", path, minio.StatObjectOptions{})
if err != nil {
if isNoSuchKey(err) {
err = nil
}
return false, err
}
return true, nil
}
func isNoSuchKey(err error) bool {
var e minio.ErrorResponse
return errors.As(err, &e) && e.Code == "NoSuchKey"
}
func (s *Server) mc() *minio.Client {
if s.minioClient != nil {
return s.minioClient
}
mc, err := minio.New("localhost:9000", &minio.Options{
Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
Secure: false,
})
if err != nil {
panic(err)
}
return mc
}

253
x/registry/server_test.go Normal file
View File

@ -0,0 +1,253 @@
package registry
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http/httptest"
"os"
"os/exec"
"strings"
"testing"
"time"
"bllamo.com/registry/apitype"
"bllamo.com/utils/backoff"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"kr.dev/diff"
)
const abc = "abcdefghijklmnopqrstuvwxyz"
func testPush(t *testing.T, chunkSize int64) {
t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) {
mc := startMinio(t, false)
manifest := []byte(`{
"layers": [
{"digest": "sha256-1", "size": 1},
{"digest": "sha256-2", "size": 2},
{"digest": "sha256-3", "size": 3}
]
}`)
const ref = "registry.ollama.ai/x/y:latest+Z"
hs := httptest.NewServer(&Server{
minioClient: mc,
UploadChunkSize: chunkSize,
})
t.Cleanup(hs.Close)
c := &Client{BaseURL: hs.URL}
requirements, err := c.Push(context.Background(), ref, manifest, nil)
if err != nil {
t.Fatal(err)
}
if len(requirements) < 3 {
t.Logf("requirements: %v", requirements)
t.Fatalf("expected at least 3 requirements; got %d", len(requirements))
}
var uploaded []apitype.CompletePart
for i, r := range requirements {
t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
body := strings.NewReader(abc)
etag, err := PushLayer(context.Background(), r.URL, r.Offset, r.Size, body)
if err != nil {
t.Fatal(err)
}
uploaded = append(uploaded, apitype.CompletePart{
URL: r.URL,
ETag: etag,
})
}
requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
Uploaded: uploaded,
})
if err != nil {
t.Fatal(err)
}
if len(requirements) != 0 {
t.Fatalf("unexpected requirements: %v", requirements)
}
var paths []string
keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
Recursive: true,
})
for k := range keys {
paths = append(paths, k.Key)
}
t.Logf("paths: %v", paths)
diff.Test(t, t.Errorf, paths, []string{
"blobs/sha256-1",
"blobs/sha256-2",
"blobs/sha256-3",
"manifests/registry.ollama.ai/x/y/latest/Z",
})
obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
if err != nil {
t.Fatal(err)
}
defer obj.Close()
var gotM apitype.Manifest
if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
t.Fatal(err)
}
diff.Test(t, t.Errorf, gotM, apitype.Manifest{
Layers: []apitype.Layer{
{Digest: "sha256-1", Size: 1},
{Digest: "sha256-2", Size: 2},
{Digest: "sha256-3", Size: 3},
},
})
// checksum the blobs
for i, l := range gotM.Layers {
obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
if err != nil {
t.Fatal(err)
}
defer obj.Close()
info, err := obj.Stat()
if err != nil {
t.Fatal(err)
}
t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
data, err := io.ReadAll(obj)
if err != nil {
t.Fatal(err)
}
got := string(data)
want := abc[:l.Size]
if got != want {
t.Errorf("[%d] got layer data = %q; want %q", i, got, want)
}
}
})
}
func TestPush(t *testing.T) {
testPush(t, 0)
testPush(t, 1)
}
func availableAddr() string {
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
panic(err)
}
defer l.Close()
return l.Addr().String()
}
func startMinio(t *testing.T, debug bool) *minio.Client {
t.Helper()
dir := t.TempDir()
t.Logf(">> minio data dir: %s", dir)
addr := availableAddr()
cmd := exec.Command("minio", "server", "--address", addr, dir)
cmd.Env = os.Environ()
if debug {
stdout, err := cmd.StdoutPipe()
if err != nil {
t.Fatal(err)
}
doneLogging := make(chan struct{})
t.Cleanup(func() {
<-doneLogging
})
go func() {
defer close(doneLogging)
sc := bufio.NewScanner(stdout)
for sc.Scan() {
t.Logf("minio: %s", sc.Text())
}
}()
}
// TODO(bmizerany): wait delay etc...
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
cmd.Process.Kill()
if err := cmd.Wait(); err != nil {
var e *exec.ExitError
if errors.As(err, &e) && e.Exited() {
t.Logf("minio stderr: %s", e.Stderr)
t.Logf("minio exit status: %v", e.ExitCode())
t.Logf("minio exited: %v", e.Exited())
t.Error(err)
}
}
})
mc, err := minio.New(addr, &minio.Options{
Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
Secure: false,
})
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
deadline, ok := t.Deadline()
if ok {
ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond))
defer cancel()
}
// wait for server to start with exponential backoff
for _, err := range backoff.Upto(ctx, 1*time.Second) {
if err != nil {
t.Fatal(err)
}
if mc.IsOnline() {
break
}
}
if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
t.Fatal(err)
}
return mc
}
// contextForTest returns a context that is canceled when the test deadline,
// if any, is reached. The returned doneLogging function must be called
// after all Log/Error/Fatal calls have finished, before the test returns.
func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
done := make(chan struct{})
deadline, ok := t.Deadline()
if !ok {
return context.Background(), func() {}
}
ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond))
t.Cleanup(func() {
cancel()
<-done
})
return ctx, func() { close(done) }
}

4
x/types/empty/message.go Normal file
View File

@ -0,0 +1,4 @@
package empty
// Message is a placeholder type used when encoding json messages.
type Message struct{}

View File

@ -0,0 +1,15 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package structs contains the Incomparable type.
package structs
// Incomparable is a zero-width incomparable type. If added as the
// first field in a struct, it marks that struct as not comparable
// (can't do == or be a map key) and usually doesn't add any width to
// the struct (unless the struct has only small fields).
//
// By making a struct incomparable, you can prevent misuse (prevent
// people from using ==), but also you can shrink generated binaries,
// as the compiler can omit equality funcs from the binary.
type Incomparable [0]func()
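// For example (a sketch; Handle is a hypothetical type), embedding
// Incomparable as the first field rejects == at compile time:
//
// type Handle struct {
// _ structs.Incomparable
// ID string
// }
//
// var _ = Handle{} == Handle{} // compile error: invalid operation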

12
x/types/they/want.go Normal file
View File

@ -0,0 +1,12 @@
package they
import (
"net/http"
"strings"
)
// Want returns true if the request method is method and the request path
// starts with pathPrefix.
func Want(r *http.Request, method string, pathPrefix string) bool {
return r.Method == method && strings.HasPrefix(r.URL.Path, pathPrefix)
}
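// A sketch of intended use inside a ServeHTTP method; the route and the
// handler it dispatches to are illustrative, not part of this package:
//
// if they.Want(r, "GET", "/v1/models") {
// return s.handleModelList(w, r)
// }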

View File

@ -0,0 +1,58 @@
package backoff
import (
"context"
"errors"
"iter"
"math/rand"
"time"
)
// Errors
var (
// ErrMaxAttempts is not used by backoff but is available for use by
// callers that want to signal that a maximum number of retries has
// been exceeded. This should eliminate the need for callers to invent
// their own error.
ErrMaxAttempts = errors.New("max retries exceeded")
)
// Upto implements a backoff strategy that yields nil errors until the
// context is canceled or yield returns false.
//
// The backoff strategy is a simple exponential backoff with a maximum
// backoff of maxBackoff. The backoff is randomized between 0.5-1.5 times
// the current backoff, in order to prevent accidental "thundering herd"
// problems.
func Upto(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
var n int
return func(yield func(int, error) bool) {
for {
if ctx.Err() != nil {
yield(n, ctx.Err())
return
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := time.Duration(n*n) * 10 * time.Millisecond
if d > maxBackoff {
d = maxBackoff
}
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
t := time.NewTimer(d)
select {
case <-ctx.Done():
t.Stop()
case <-t.C:
if !yield(n, nil) {
return
}
}
}
}
}
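// upToExample is a sketch of the intended retry loop: Upto yields attempt
// numbers with nil errors until ctx is done, and breaking out of the range
// stops the sequence. The retried operation and the 5-attempt cap are
// hypothetical.
func upToExample(ctx context.Context, do func() error) error {
for n, err := range Upto(ctx, 1*time.Second) {
if err != nil {
return err // context canceled or deadline exceeded
}
if err := do(); err == nil {
return nil
}
if n >= 5 {
return ErrMaxAttempts
}
}
return nil
}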

27
x/utils/upload/upload.go Normal file
View File

@ -0,0 +1,27 @@
package upload
import (
"iter"
"golang.org/x/exp/constraints"
)
type Chunk[I constraints.Integer] struct {
Offset I
Size I
}
// Chunks yields a sequence of part numbers and Chunks. Each Chunk holds the
// offset and size of that chunk. The last chunk may be smaller than chunkSize
// if size is not a multiple of chunkSize.
//
// The first part number is 1 and increases monotonically.
func Chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, Chunk[I]] {
return func(yield func(int, Chunk[I]) bool) {
var n int
for off := I(0); off < size; off += chunkSize {
n++
if !yield(n, Chunk[I]{off, min(chunkSize, size-off)}) {
return
}
}
}
}
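// For instance (arbitrary sizes), Chunks[int64](100, 30) yields
// (1, {0, 30}), (2, {30, 30}), (3, {60, 30}), (4, {90, 10});
// see TestChunks for a longer walk of the same pattern.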

View File

@ -0,0 +1,37 @@
package upload
import (
"testing"
"kr.dev/diff"
)
func TestChunks(t *testing.T) {
const size = 101
const chunkSize = 10
var got []Chunk[int]
var lastN int
for n, c := range Chunks(size, chunkSize) {
if n != lastN+1 {
t.Errorf("n = %d; want %d", n, lastN+1)
}
got = append(got, c)
lastN = n
}
want := []Chunk[int]{
{0, 10},
{10, 10},
{20, 10},
{30, 10},
{40, 10},
{50, 10},
{60, 10},
{70, 10},
{80, 10},
{90, 10},
{100, 1},
}
diff.Test(t, t.Errorf, got, want)
}