Add 'x/' from commit 'a10a11b9d371f36b7c3510da32a1d70b74e27bd1'
git-subtree-dir: x git-subtree-mainline: 7d05a6ee8f44b314fa697a427439e5fa4d78c3d7 git-subtree-split: a10a11b9d371f36b7c3510da32a1d70b74e27bd1
This commit is contained in:
commit
adc23d5f96
116
x/api/api.go
Normal file
116
x/api/api.go
Normal file
@ -0,0 +1,116 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"bllamo.com/build"
|
||||
"bllamo.com/client/ollama/apitype"
|
||||
"bllamo.com/oweb"
|
||||
"bllamo.com/registry"
|
||||
regtype "bllamo.com/registry/apitype"
|
||||
)
|
||||
|
||||
// Common API Errors
|
||||
var (
|
||||
errUnqualifiedRef = oweb.Mistake("invalid", "name", "must be fully qualified")
|
||||
errRefNotFound = oweb.Mistake("not_found", "name", "no such model")
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Build *build.Server
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
oweb.Serve(s.serveHTTP, w, r)
|
||||
}
|
||||
|
||||
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
switch r.URL.Path {
|
||||
case "/v1/push":
|
||||
return s.handlePush(w, r)
|
||||
default:
|
||||
return oweb.ErrNotFound
|
||||
}
|
||||
}
|
||||
|
||||
// want reports whether r has exactly the given method and URL path.
func want(r *http.Request, method, path string) bool {
	if r.Method != method {
		return false
	}
	return r.URL.Path == path
}
|
||||
|
||||
func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "POST" {
|
||||
return oweb.ErrMethodNotAllowed
|
||||
}
|
||||
|
||||
params, err := oweb.DecodeJSON[apitype.PushRequest](r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if params.Name == "" {
|
||||
return oweb.Missing("name")
|
||||
}
|
||||
|
||||
const registryURLTODO = "http://localhost:8888"
|
||||
|
||||
man, err := s.Build.ManifestData(params.Name)
|
||||
if err != nil {
|
||||
if errors.Is(err, build.ErrNotFound) {
|
||||
return errRefNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c := registry.Client{BaseURL: registryURLTODO}
|
||||
requirements, err := c.Push(r.Context(), params.Name, man, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var uploads []regtype.CompletePart
|
||||
for _, rq := range requirements {
|
||||
l, err := s.Build.LayerFile(rq.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = func() error {
|
||||
f, err := os.Open(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
etag, err := registry.PushLayer(r.Context(), rq.URL, rq.Offset, rq.Size, f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uploads = append(uploads, regtype.CompletePart{
|
||||
URL: rq.URL,
|
||||
ETag: etag,
|
||||
})
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// commit the manifest to the registry
|
||||
requirements, err = c.Push(r.Context(), params.Name, man, ®istry.PushParams{
|
||||
Uploaded: uploads,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, r := range requirements {
|
||||
err = errors.Join(err, fmt.Errorf("push failed for %q", r.Digest))
|
||||
}
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
return oweb.ErrNotFound
|
||||
}
|
295
x/build/blob/ref.go
Normal file
295
x/build/blob/ref.go
Normal file
@ -0,0 +1,295 @@
|
||||
package blob
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Levels of concreteness. The values double as indexes into the slice
// returned by Ref.Parts and as the part selector accepted by with.
const (
	domain    = iota // index of the domain part
	namespace        // index of the namespace part
	name             // index of the name part
	tag              // index of the tag part
	build            // index of the build part
)
|
||||
|
||||
// Ref is an opaque reference to a blob.
//
// It consists only of string fields, so it is comparable and can be used as
// a map key.
//
// Users of Ref must check Valid before using it.
type Ref struct {
	domain    string // optional
	namespace string // optional
	name      string // required for the Ref to be Valid
	tag       string // optional
	build     string // optional
}
|
||||
|
||||
// WithDomain returns a copy of r with the provided domain. If the provided
|
||||
// domain is empty, it returns the short, unqualified copy of r.
|
||||
func (r Ref) WithDomain(s string) Ref {
|
||||
return with(r, domain, s)
|
||||
}
|
||||
|
||||
// WithNamespace returns a copy of r with the provided namespace. If the
|
||||
// provided namespace is empty, it returns the short, unqualified copy of r.
|
||||
func (r Ref) WithNamespace(s string) Ref {
|
||||
return with(r, namespace, s)
|
||||
}
|
||||
|
||||
func (r Ref) WithTag(s string) Ref {
|
||||
return with(r, tag, s)
|
||||
}
|
||||
|
||||
// WithBuild returns a copy of r with the provided build. If the provided
|
||||
// build is empty, it returns the short, unqualified copy of r.
|
||||
func (r Ref) WithBuild(s string) Ref {
|
||||
return with(r, build, s)
|
||||
}
|
||||
|
||||
func with(r Ref, part int, value string) Ref {
|
||||
if value != "" && !isValidPart(value) {
|
||||
return Ref{}
|
||||
}
|
||||
switch part {
|
||||
case domain:
|
||||
r.domain = value
|
||||
case namespace:
|
||||
r.namespace = value
|
||||
case name:
|
||||
r.name = value
|
||||
case tag:
|
||||
r.tag = value
|
||||
case build:
|
||||
r.build = value
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid completeness: %d", part))
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Format returns a string representation of the ref with the given
|
||||
// concreteness. If a part is missing, it is replaced with a loud
|
||||
// placeholder.
|
||||
func (r Ref) Full() string {
|
||||
r.domain = cmp.Or(r.domain, "!(MISSING DOMAIN)")
|
||||
r.namespace = cmp.Or(r.namespace, "!(MISSING NAMESPACE)")
|
||||
r.name = cmp.Or(r.name, "!(MISSING NAME)")
|
||||
r.tag = cmp.Or(r.tag, "!(MISSING TAG)")
|
||||
r.build = cmp.Or(r.build, "!(MISSING BUILD)")
|
||||
return r.String()
|
||||
}
|
||||
|
||||
func (r Ref) NameAndTag() string {
|
||||
r.domain = ""
|
||||
r.namespace = ""
|
||||
r.build = ""
|
||||
return r.String()
|
||||
}
|
||||
|
||||
func (r Ref) NameTagAndBuild() string {
|
||||
r.domain = ""
|
||||
r.namespace = ""
|
||||
return r.String()
|
||||
}
|
||||
|
||||
// String returns the fully qualified ref string.
|
||||
func (r Ref) String() string {
|
||||
var b strings.Builder
|
||||
if r.domain != "" {
|
||||
b.WriteString(r.domain)
|
||||
b.WriteString("/")
|
||||
}
|
||||
if r.namespace != "" {
|
||||
b.WriteString(r.namespace)
|
||||
b.WriteString("/")
|
||||
}
|
||||
b.WriteString(r.name)
|
||||
if r.tag != "" {
|
||||
b.WriteString(":")
|
||||
b.WriteString(r.tag)
|
||||
}
|
||||
if r.build != "" {
|
||||
b.WriteString("+")
|
||||
b.WriteString(r.build)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Complete returns true if the ref is valid and has no empty parts.
|
||||
func (r Ref) Complete() bool {
|
||||
return r.Valid() && !slices.Contains(r.Parts(), "")
|
||||
}
|
||||
|
||||
func (r Ref) CompleteWithoutBuild() bool {
|
||||
return r.Valid() && !slices.Contains(r.Parts()[:tag], "")
|
||||
}
|
||||
|
||||
// Less returns true if r is less concrete than o; false otherwise.
|
||||
func (r Ref) Less(o Ref) bool {
|
||||
rp := r.Parts()
|
||||
op := o.Parts()
|
||||
for i := range rp {
|
||||
if rp[i] < op[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Parts returns the parts of the ref in order of concreteness.
|
||||
//
|
||||
// The length of the returned slice is always 5.
|
||||
func (r Ref) Parts() []string {
|
||||
return []string{
|
||||
domain: r.domain,
|
||||
namespace: r.namespace,
|
||||
name: r.name,
|
||||
tag: r.tag,
|
||||
build: r.build,
|
||||
}
|
||||
}
|
||||
|
||||
func (r Ref) Domain() string { return r.namespace }
|
||||
func (r Ref) Namespace() string { return r.namespace }
|
||||
func (r Ref) Name() string { return r.name }
|
||||
func (r Ref) Tag() string { return r.tag }
|
||||
func (r Ref) Build() string { return r.build }
|
||||
|
||||
// ParseRef parses a ref string into a Ref. A ref string is a name, an
|
||||
// optional tag, and an optional build, separated by colons and pluses.
|
||||
//
|
||||
// The name must be valid ascii [a-zA-Z0-9_].
|
||||
// The tag must be valid ascii [a-zA-Z0-9_].
|
||||
// The build must be valid ascii [a-zA-Z0-9_].
|
||||
//
|
||||
// It returns then zero value if the ref is invalid.
|
||||
//
|
||||
// // Valid Examples:
|
||||
// ParseRef("mistral:latest") returns ("mistral", "latest", "")
|
||||
// ParseRef("mistral") returns ("mistral", "", "")
|
||||
// ParseRef("mistral:30B") returns ("mistral", "30B", "")
|
||||
// ParseRef("mistral:7b") returns ("mistral", "7b", "")
|
||||
// ParseRef("mistral:7b+Q4_0") returns ("mistral", "7b", "Q4_0")
|
||||
// ParseRef("mistral+KQED") returns ("mistral", "latest", "KQED")
|
||||
// ParseRef(".x.:7b+Q4_0:latest") returns (".x.", "7b", "Q4_0")
|
||||
// ParseRef("-grok-f.oo:7b+Q4_0") returns ("-grok-f.oo", "7b", "Q4_0")
|
||||
//
|
||||
// // Invalid Examples:
|
||||
// ParseRef("m stral") returns ("", "", "") // zero
|
||||
// ParseRef("... 129 chars ...") returns ("", "", "") // zero
|
||||
func ParseRef(s string) Ref {
|
||||
if len(s) > 128 {
|
||||
return Ref{}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(s, "http://") {
|
||||
s = s[len("http://"):]
|
||||
}
|
||||
if strings.HasPrefix(s, "https://") {
|
||||
s = s[len("https://"):]
|
||||
}
|
||||
|
||||
var r Ref
|
||||
|
||||
state, j := build, len(s)
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
c := s[i]
|
||||
switch c {
|
||||
case '+':
|
||||
switch state {
|
||||
case build:
|
||||
r.build = s[i+1 : j]
|
||||
if r.build == "" {
|
||||
return Ref{}
|
||||
}
|
||||
r.build = strings.ToUpper(r.build)
|
||||
state, j = tag, i
|
||||
default:
|
||||
return Ref{}
|
||||
}
|
||||
case ':':
|
||||
switch state {
|
||||
case build, tag:
|
||||
r.tag = s[i+1 : j]
|
||||
if r.tag == "" {
|
||||
return Ref{}
|
||||
}
|
||||
state, j = name, i
|
||||
default:
|
||||
return Ref{}
|
||||
}
|
||||
case '/':
|
||||
switch state {
|
||||
case name, tag, build:
|
||||
r.name = s[i+1 : j]
|
||||
state, j = namespace, i
|
||||
case namespace:
|
||||
r.namespace = s[i+1 : j]
|
||||
state, j = domain, i
|
||||
default:
|
||||
return Ref{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle the first part based on final state
|
||||
switch state {
|
||||
case domain:
|
||||
r.domain = s[:j]
|
||||
case namespace:
|
||||
r.namespace = s[:j]
|
||||
default:
|
||||
r.name = s[:j]
|
||||
}
|
||||
|
||||
if !r.Valid() {
|
||||
return Ref{}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r Ref) Valid() bool {
|
||||
// Name is required
|
||||
if !isValidPart(r.name) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Optional parts must be valid if present
|
||||
if r.domain != "" && !isValidPart(r.domain) {
|
||||
return false
|
||||
}
|
||||
if r.namespace != "" && !isValidPart(r.namespace) {
|
||||
return false
|
||||
}
|
||||
if r.tag != "" && !isValidPart(r.tag) {
|
||||
return false
|
||||
}
|
||||
if r.build != "" && !isValidPart(r.build) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isValidPart reports whether s is non-empty and consists solely of the
// ASCII characters [a-zA-Z0-9_.-].
func isValidPart(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case c >= 'a' && c <= 'z',
			c >= 'A' && c <= 'Z',
			c >= '0' && c <= '9',
			c == '_', c == '.', c == '-':
			// Allowed; keep scanning. Every byte must be checked, so a
			// '.' or '-' must not end the scan early.
		default:
			return false
		}
	}
	return true
}
|
80
x/build/blob/ref_test.go
Normal file
80
x/build/blob/ref_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
package blob
|
||||
|
||||
import "testing"
|
||||
|
||||
// Test refs.

// refTooLong is a long run of 'x' characters for tests that need an
// over-length ref string.
const refTooLong = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
|
||||
var testRefs = map[string]Ref{
|
||||
"mistral:latest": {name: "mistral", tag: "latest"},
|
||||
"mistral": {name: "mistral"},
|
||||
"mistral:30B": {name: "mistral", tag: "30B"},
|
||||
"mistral:7b": {name: "mistral", tag: "7b"},
|
||||
"mistral:7b+Q4_0": {name: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"mistral+KQED": {name: "mistral", build: "KQED"},
|
||||
"mistral.x-3:7b+Q4_0": {name: "mistral.x-3", tag: "7b", build: "Q4_0"},
|
||||
"mistral:7b+q4_0": {name: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"llama2": {name: "llama2"},
|
||||
|
||||
// invalid
|
||||
"mistral:7b+Q4_0:latest": {},
|
||||
"mi tral": {},
|
||||
}
|
||||
|
||||
func TestRefParts(t *testing.T) {
|
||||
const wantNumParts = 5
|
||||
var ref Ref
|
||||
if len(ref.Parts()) != wantNumParts {
|
||||
t.Errorf("Parts() = %d; want %d", len(ref.Parts()), wantNumParts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRef(t *testing.T) {
|
||||
for s, want := range testRefs {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
got := ParseRef(s)
|
||||
if got != want {
|
||||
t.Errorf("ParseRef(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
// test round-trip
|
||||
if ParseRef(got.String()) != got {
|
||||
t.Errorf("String() = %q; want %q", got.String(), s)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefFull(t *testing.T) {
|
||||
const empty = "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/!(MISSING NAME):!(MISSING TAG)+!(MISSING BUILD)"
|
||||
|
||||
cases := []struct {
|
||||
in string
|
||||
wantFull string
|
||||
}{
|
||||
{"", empty},
|
||||
{"example.com/mistral:7b+x", "!(MISSING DOMAIN)/example.com/mistral:7b+X"},
|
||||
{"example.com/mistral:7b+Q4_0", "!(MISSING DOMAIN)/example.com/mistral:7b+Q4_0"},
|
||||
{"example.com/x/mistral:latest", "example.com/x/mistral:latest+!(MISSING BUILD)"},
|
||||
{"example.com/x/mistral:latest+Q4_0", "example.com/x/mistral:latest+Q4_0"},
|
||||
|
||||
{"mistral:7b+x", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+X"},
|
||||
{"mistral:7b+q4_0", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+Q4_0"},
|
||||
{"mistral:7b+Q4_0", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:7b+Q4_0"},
|
||||
{"mistral:latest", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:latest+!(MISSING BUILD)"},
|
||||
{"mistral", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:!(MISSING TAG)+!(MISSING BUILD)"},
|
||||
{"mistral:30b", "!(MISSING DOMAIN)/!(MISSING NAMESPACE)/mistral:30b+!(MISSING BUILD)"},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
ref := ParseRef(tt.in)
|
||||
t.Logf("ParseRef(%q) = %#v", tt.in, ref)
|
||||
if g := ref.Full(); g != tt.wantFull {
|
||||
t.Errorf("Full(%q) = %q; want %q", tt.in, g, tt.wantFull)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
210
x/build/build.go
Normal file
210
x/build/build.go
Normal file
@ -0,0 +1,210 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"bllamo.com/build/blob"
|
||||
"bllamo.com/build/internal/blobstore"
|
||||
"bllamo.com/model"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by this package. Compare with errors.Is; they
// may be wrapped with additional context.
var (
	// ErrIncompleteRef is returned by Build when the ref is not complete
	// (ignoring the build part).
	ErrIncompleteRef = errors.New("unqualified ref")

	ErrBuildPresentInRef = errors.New("build present in ref")

	// ErrUnsupportedModelFormat is returned when the FROM source is not in
	// a model format that can be imported.
	ErrUnsupportedModelFormat = errors.New("unsupported model format")

	// ErrMissingFileType is returned when an imported GGUF file reports a
	// zero file type.
	ErrMissingFileType = errors.New("missing 'general.file_type' key")

	// ErrNotFound is returned when a manifest or layer file does not exist.
	ErrNotFound = errors.New("not found")
)
|
||||
|
||||
// mediaType is the media type string recorded for a layer in a manifest.
type mediaType string

// Known media types.
const (
	// mediaTypeModel marks the layer that holds the model weights.
	mediaTypeModel mediaType = "application/vnd.ollama.image.model"
)
|
||||
|
||||
type Server struct {
|
||||
st *blobstore.Store
|
||||
}
|
||||
|
||||
// Open starts a new build server that uses dir as the base directory for all
|
||||
// build artifacts. If dir is empty, DefaultDir is used.
|
||||
//
|
||||
// It returns an error if the provided or default dir cannot be initialized.
|
||||
func Open(dir string) (*Server, error) {
|
||||
if dir == "" {
|
||||
var err error
|
||||
dir, err = DefaultDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
st, err := blobstore.Open(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Server{st: st}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Build(ref string, f model.File) error {
|
||||
br := blob.ParseRef(ref)
|
||||
if !br.CompleteWithoutBuild() {
|
||||
return fmt.Errorf("%w: %q", ErrIncompleteRef, ref)
|
||||
}
|
||||
|
||||
// 1. Resolve FROM
|
||||
// a. If it's a local file (gguf), hash it and add it to the store.
|
||||
// c. If it's a remote file (http), refuse.
|
||||
// 2. Turn other pragmas into layers, and add them to the store.
|
||||
// 3. Create a manifest from the layers.
|
||||
// 4. Store the manifest in the manifest cache
|
||||
// 5. Done.
|
||||
|
||||
if f.From == "" {
|
||||
return &model.Error{Pragma: "FROM", Message: "missing"}
|
||||
}
|
||||
|
||||
var layers []layerJSON
|
||||
|
||||
id, info, size, err := s.importModel(f.From)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
layers = append(layers, layerJSON{
|
||||
ID: id,
|
||||
MediaType: mediaTypeModel,
|
||||
Size: size,
|
||||
})
|
||||
|
||||
id, size, err = blobstore.PutString(s.st, f.License)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
layers = append(layers, layerJSON{
|
||||
ID: id,
|
||||
MediaType: "text/plain",
|
||||
Size: size,
|
||||
})
|
||||
|
||||
data, err := json.Marshal(manifestJSON{Layers: layers})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.setManifestData(
|
||||
br.WithBuild(info.FileType.String()),
|
||||
data,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) LayerFile(digest string) (string, error) {
|
||||
fileName := s.st.OutputFilename(blobstore.ParseID(digest))
|
||||
_, err := os.Stat(fileName)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return "", fmt.Errorf("%w: %q", ErrNotFound, digest)
|
||||
}
|
||||
return fileName, nil
|
||||
}
|
||||
|
||||
func (s *Server) ManifestData(ref string) ([]byte, error) {
|
||||
data, _, err := s.resolve(blob.ParseRef(ref))
|
||||
return data, err
|
||||
}
|
||||
|
||||
// WeightFile returns the absolute path to the weights file for the given model ref.
|
||||
func (s *Server) WeightsFile(ref string) (string, error) {
|
||||
m, err := s.getManifest(blob.ParseRef(ref))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
if l.MediaType == mediaTypeModel {
|
||||
return s.st.OutputFilename(l.ID), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("missing weights layer for %q", ref)
|
||||
}
|
||||
|
||||
// resolve returns the data for the given ref, if any.
|
||||
//
|
||||
// TODO: This should ideally return an ID, but the current on
|
||||
// disk layout is that the actual manifest is stored in the "ref" instead of
|
||||
// a pointer to a content-addressed blob. I (bmizerany) think we should
|
||||
// change the on-disk layout to store the manifest in a content-addressed
|
||||
// blob, and then have the ref point to that blob. This would simplify the
|
||||
// code, allow us to have integrity checks on the manifest, and clean up
|
||||
// this interface.
|
||||
func (s *Server) resolve(ref blob.Ref) (data []byte, path string, err error) {
|
||||
path, err = s.refFileName(ref)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
data, err = os.ReadFile(path)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref)
|
||||
}
|
||||
if err != nil {
|
||||
// do not wrap the error here, as it is likely an I/O error
|
||||
// and we want to preserve the absraction since we may not
|
||||
// be on disk later.
|
||||
return nil, "", fmt.Errorf("manifest read error: %v", err)
|
||||
}
|
||||
return data, path, nil
|
||||
}
|
||||
|
||||
func (s *Server) SetManifestData(ref string, data []byte) error {
|
||||
return s.setManifestData(blob.ParseRef(ref), data)
|
||||
}
|
||||
|
||||
// Set sets the data for the given ref.
|
||||
func (s *Server) setManifestData(br blob.Ref, data []byte) error {
|
||||
path, err := s.refFileName(br)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0666); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) refFileName(ref blob.Ref) (string, error) {
|
||||
if !ref.Complete() {
|
||||
return "", fmt.Errorf("ref not fully qualified: %q", ref)
|
||||
}
|
||||
return filepath.Join(s.st.Dir(), "manifests", filepath.Join(ref.Parts()...)), nil
|
||||
}
|
||||
|
||||
type manifestJSON struct {
|
||||
// Layers is the list of layers in the manifest.
|
||||
Layers []layerJSON `json:"layers"`
|
||||
}
|
||||
|
||||
// Layer is a layer in a model manifest.
|
||||
type layerJSON struct {
|
||||
// ID is the ID of the layer.
|
||||
ID blobstore.ID `json:"digest"`
|
||||
MediaType mediaType `json:"mediaType"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
func (s *Server) getManifest(ref blob.Ref) (manifestJSON, error) {
|
||||
data, path, err := s.resolve(ref)
|
||||
if err != nil {
|
||||
return manifestJSON{}, err
|
||||
}
|
||||
var m manifestJSON
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err}
|
||||
}
|
||||
return m, nil
|
||||
}
|
163
x/build/build_test.go
Normal file
163
x/build/build_test.go
Normal file
@ -0,0 +1,163 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"bllamo.com/encoding/gguf"
|
||||
"bllamo.com/model"
|
||||
)
|
||||
|
||||
// qualifiedRef is a ref with a namespace, name, tag and build, used by the
// tests below wherever Build needs a ref it will accept.
const qualifiedRef = "x/y/z:latest+Q4_0"
|
||||
|
||||
func TestServerBuildErrors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
s, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("unqualified ref", func(t *testing.T) {
|
||||
err := s.Build("x", model.File{})
|
||||
if !errors.Is(err, ErrIncompleteRef) {
|
||||
t.Fatalf("Build() err = %v; want unqualified ref", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM pragma missing", func(t *testing.T) {
|
||||
err := s.Build(qualifiedRef, model.File{})
|
||||
var e *model.Error
|
||||
if !errors.As(err, &e) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if e.Pragma != "FROM" {
|
||||
t.Errorf("e.Pragma = %s; want FROM", e.Pragma)
|
||||
}
|
||||
if e.Message != "missing" {
|
||||
t.Errorf("e.Message = %s; want missing", e.Message)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM file not found", func(t *testing.T) {
|
||||
err := s.Build(qualifiedRef, model.File{From: "bar"})
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("Build() err = %v; want file not found", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM gguf", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
// Write a gguf file without general.file_type metadata.
|
||||
w.write("gguf", ""+
|
||||
"GGUF"+ // magic
|
||||
"\x03\x00\x00\x00"+ // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
|
||||
"",
|
||||
)
|
||||
|
||||
err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")})
|
||||
if !errors.Is(err, ErrMissingFileType) {
|
||||
t.Fatalf("Build() err = %#v; want missing file type", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM obscure dir", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
w.mkdirAll("unknown")
|
||||
if err := s.Build(qualifiedRef, model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat {
|
||||
t.Fatalf("Build() err = %#v; want unsupported model type", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM unsupported model type", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
from := w.write("unknown", "unknown content")
|
||||
err := s.Build(qualifiedRef, model.File{From: from})
|
||||
if !errors.Is(err, ErrUnsupportedModelFormat) {
|
||||
t.Fatalf("Build() err = %#v; want unsupported model type", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildBasicGGUF(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
w.write("gguf", ""+
|
||||
"GGUF"+ // magic
|
||||
"\x03\x00\x00\x00"+ // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
|
||||
|
||||
// general.file_type key
|
||||
"\x11\x00\x00\x00\x00\x00\x00\x00"+ // key length
|
||||
"general.file_type"+ // key
|
||||
"\x04\x00\x00\x00"+ // type (uint32)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00"+ // uint32 value
|
||||
"",
|
||||
)
|
||||
|
||||
dir := t.TempDir()
|
||||
s, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
filepath.Walk(dir, func(p string, info os.FileInfo, err error) error {
|
||||
t.Logf("file: %s", p)
|
||||
return nil
|
||||
})
|
||||
|
||||
_, err = s.WeightsFile("unknown/y/z:latest+Q4_0")
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Fatalf("WeightsFile() err = %v; want not found", err)
|
||||
}
|
||||
|
||||
path, err := s.WeightsFile("x/y/z:latest+Q4_0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
info, err := gguf.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if info.FileType != gguf.TypeQ4_0 {
|
||||
t.Errorf("info.FileType = %d; want 1", info.FileType)
|
||||
}
|
||||
}
|
||||
|
||||
// work is a scratch directory for a test, with helpers that fail the test
// on any file system error.
type work struct {
	t   testing.TB
	dir string
}

// newWorkDir returns a work rooted at a fresh t.TempDir().
func newWorkDir(t *testing.T) work {
	tmp := t.TempDir()
	return work{t: t, dir: tmp}
}

// write creates (or overwrites) the file name inside the work directory
// with the given content and mode 0644, and returns its path.
func (w work) write(name, content string) (path string) {
	w.t.Helper()
	path = w.fileName(name)
	err := os.WriteFile(path, []byte(content), 0644)
	if err != nil {
		w.t.Fatal(err)
	}
	return path
}

// fileName returns the path of name inside the work directory. It does not
// touch the file system.
func (w work) fileName(name string) string {
	w.t.Helper()
	full := filepath.Join(w.dir, name)
	return full
}

// mkdirAll creates the directory path (and any missing parents) inside the
// work directory with mode 0755.
func (w work) mkdirAll(path string) {
	w.t.Helper()
	target := filepath.Join(w.dir, path)
	if err := os.MkdirAll(target, 0755); err != nil {
		w.t.Fatal(err)
	}
}
|
12
x/build/convert.go
Normal file
12
x/build/convert.go
Normal file
@ -0,0 +1,12 @@
|
||||
package build
|
||||
|
||||
func convertSafeTensorToGGUF(path string) (ggufPath string, err error) {
|
||||
// TODO: decine on hueristic for converting safetensor to gguf and
|
||||
// the errors that can be returned. For now, we just say
|
||||
// "unsupported", however it may be intended to be a valid safe
|
||||
// tensor but we hit an error in the conversion.
|
||||
//
|
||||
// I (bmizernay) think this will naturally evolve as we implement
|
||||
// the conversion.
|
||||
return "", ErrUnsupportedModelFormat
|
||||
}
|
28
x/build/default.go
Normal file
28
x/build/default.go
Normal file
@ -0,0 +1,28 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// lookupDefaultDir computes the default models directory: the value of the
// OLLAMA_MODELS environment variable if it is non-empty, otherwise
// ".ollama/models" under the user's home directory.
func lookupDefaultDir() (string, error) {
	if env := os.Getenv("OLLAMA_MODELS"); env != "" {
		return env, nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(home, ".ollama", "models"), nil
}

// defaultDir runs lookupDefaultDir once and caches its result, including
// any error, for all later calls.
var defaultDir = sync.OnceValues(lookupDefaultDir)

// DefaultDir returns the default directory for models. It returns the value
// of the OLLAMA_MODELS environment variable if set; otherwise it returns
// "$HOME/.ollama/models".
//
// The result is computed on the first call and cached, so later changes to
// the environment are not observed.
func DefaultDir() (string, error) {
	return defaultDir()
}
|
59
x/build/import.go
Normal file
59
x/build/import.go
Normal file
@ -0,0 +1,59 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"bllamo.com/build/internal/blobstore"
|
||||
"bllamo.com/encoding/gguf"
|
||||
)
|
||||
|
||||
func importError(err error) (blobstore.ID, gguf.Info, int64, error) {
|
||||
return blobstore.ID{}, gguf.Info{}, 0, err
|
||||
}
|
||||
|
||||
func (s *Server) importModel(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return s.importSafeTensor(path)
|
||||
} else {
|
||||
return s.importGGUF(path)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) importGGUF(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
info, err := gguf.StatReader(f)
|
||||
if errors.Is(err, gguf.ErrBadMagic) {
|
||||
return importError(ErrUnsupportedModelFormat)
|
||||
}
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
|
||||
if info.FileType == 0 {
|
||||
return importError(fmt.Errorf("%w: %q", ErrMissingFileType, path))
|
||||
}
|
||||
id, size, err := s.st.Put(f)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
return id, info, size, nil
|
||||
}
|
||||
|
||||
func (s *Server) importSafeTensor(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
path, err := convertSafeTensorToGGUF(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
return s.importGGUF(path)
|
||||
}
|
329
x/build/internal/blobstore/blob.go
Normal file
329
x/build/internal/blobstore/blob.go
Normal file
@ -0,0 +1,329 @@
|
||||
// Package blobstore implements a blob store.
|
||||
package blobstore
|
||||
|
||||
import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"time"

	"bllamo.com/types/structs"
)
|
||||
|
||||
// Errors.
var (
	// ErrInvalidID reports that a blob ID is not valid.
	ErrInvalidID = errors.New("invalid ID")
)
|
||||
|
||||
// HashSize is the size in bytes of the hash stored in an ID (a SHA-256
// digest is 32 bytes).
const HashSize = 32
|
||||
|
||||
// An ID is a blob output key, the hash of an output of a computation.
|
||||
type ID struct {
|
||||
a [HashSize]byte
|
||||
}
|
||||
|
||||
func (id ID) MarshalText() ([]byte, error) {
|
||||
return []byte(id.String()), nil
|
||||
}
|
||||
|
||||
func (id *ID) UnmarshalText(text []byte) error {
|
||||
*id = ParseID(string(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseID(s string) ID {
|
||||
const prefix = "sha256-"
|
||||
h, ok := strings.CutPrefix(s, prefix)
|
||||
if !ok {
|
||||
return ID{}
|
||||
}
|
||||
|
||||
if len(h) != HashSize*2 {
|
||||
return ID{}
|
||||
}
|
||||
|
||||
var b []byte
|
||||
_, err := fmt.Sscanf(h, "%x", &b)
|
||||
if err != nil {
|
||||
return ID{}
|
||||
}
|
||||
|
||||
var id ID
|
||||
copy(id.a[:], b)
|
||||
return id
|
||||
}
|
||||
|
||||
func (id ID) String() string {
|
||||
if !id.Valid() {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("sha256-%x", id.a[:])
|
||||
}
|
||||
|
||||
func (id ID) Valid() bool {
|
||||
return id != ID{}
|
||||
}
|
||||
|
||||
func (id ID) Match(h [HashSize]byte) bool {
|
||||
return id.a == h
|
||||
}
|
||||
|
||||
// A Store is a blob store, backed by a file system directory tree.
type Store struct {
	// dir is the root directory of the store.
	dir string

	// now returns the current time; Open sets it to time.Now.
	now func() time.Time
}
|
||||
|
||||
// Open opens and returns the store in the given directory.
|
||||
//
|
||||
// It is safe for multiple processes on a single machine to use the
|
||||
// same store directory in a local file system simultaneously.
|
||||
// They will coordinate using operating system file locks and may
|
||||
// duplicate effort but will not corrupt the store.
|
||||
//
|
||||
// However, it is NOT safe for multiple processes on different machines
|
||||
// to share a store directory (for example, if the directory were stored
|
||||
// in a network file system). File locking is notoriously unreliable in
|
||||
// network file systems and may not suffice to protect the store.
|
||||
func Open(dir string) (*Store, error) {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, &fs.PathError{Op: "open", Path: dir, Err: fmt.Errorf("not a directory")}
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(dir, "blobs"), 0777); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &Store{
|
||||
dir: dir,
|
||||
now: time.Now,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *Store) Dir() string {
|
||||
return s.dir
|
||||
}
|
||||
|
||||
// fileName returns the name of the blob file corresponding to the given id.
|
||||
func (s *Store) fileName(id ID) string {
|
||||
return filepath.Join(s.dir, "blobs", fmt.Sprintf("sha256-%x", id.a[:]))
|
||||
}
|
||||
|
||||
// An entryNotFoundError reports that a store entry was not found. Err, when
// non-nil, is the underlying reason.
type entryNotFoundError struct {
	Err error
}

// Error returns "store entry not found", followed by the underlying reason
// when there is one.
func (e *entryNotFoundError) Error() string {
	const msg = "store entry not found"
	if e.Err != nil {
		return fmt.Sprintf(msg+": %v", e.Err)
	}
	return msg
}

// Unwrap returns the underlying reason, which may be nil.
func (e *entryNotFoundError) Unwrap() error {
	return e.Err
}
|
||||
|
||||
type Entry struct {
|
||||
_ structs.Incomparable
|
||||
|
||||
ID ID
|
||||
Size int64
|
||||
Time time.Time // when added to store
|
||||
}
|
||||
|
||||
// GetFile looks up the blob ID in the store and returns
|
||||
// the name of the corresponding data file.
|
||||
func GetFile(s *Store, id ID) (file string, entry Entry, err error) {
|
||||
entry, err = s.Get(id)
|
||||
if err != nil {
|
||||
return "", Entry{}, err
|
||||
}
|
||||
file = s.OutputFilename(entry.ID)
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return "", Entry{}, &entryNotFoundError{Err: err}
|
||||
}
|
||||
if info.Size() != entry.Size {
|
||||
return "", Entry{}, &entryNotFoundError{Err: errors.New("file incomplete")}
|
||||
}
|
||||
return file, entry, nil
|
||||
}
|
||||
|
||||
// GetBytes looks up the blob ID in the store and returns
|
||||
// the corresponding output bytes.
|
||||
// GetBytes should only be used for data that can be expected to fit in memory.
|
||||
func GetBytes(s *Store, id ID) ([]byte, Entry, error) {
|
||||
entry, err := s.Get(id)
|
||||
if err != nil {
|
||||
return nil, entry, err
|
||||
}
|
||||
data, _ := os.ReadFile(s.OutputFilename(entry.ID))
|
||||
if entry.ID.Match(sha256.Sum256(data)) {
|
||||
return nil, entry, &entryNotFoundError{Err: errors.New("bad checksum")}
|
||||
}
|
||||
return data, entry, nil
|
||||
}
|
||||
|
||||
// OutputFilename returns the name of the blob file for the given ID.
|
||||
func (s *Store) OutputFilename(id ID) string {
|
||||
file := s.fileName(id)
|
||||
// TODO(bmizerany): touch as "used" for cache trimming. (see
|
||||
// cache.go in cmd/go/internal/cache for the full reference implementation to go off of.
|
||||
return file
|
||||
}
|
||||
|
||||
// Get looks up the blob ID in the store,
|
||||
// returning the corresponding output ID and file size, if any.
|
||||
// Note that finding an output ID does not guarantee that the
|
||||
// saved file for that output ID is still available.
|
||||
func (s *Store) Get(id ID) (Entry, error) {
|
||||
file := s.fileName(id)
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return Entry{}, &entryNotFoundError{Err: err}
|
||||
}
|
||||
return Entry{
|
||||
ID: id,
|
||||
Size: info.Size(),
|
||||
Time: info.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error {
|
||||
// TODO(bmizerany): return c.Trim()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put stores the data read from the given file into the store as ID.
|
||||
//
|
||||
// It may read file twice. The content of file must not change between the
|
||||
// two passes.
|
||||
func (s *Store) Put(file io.ReadSeeker) (ID, int64, error) {
|
||||
return s.put(file)
|
||||
}
|
||||
|
||||
func PutBytes(s *Store, data []byte) (ID, int64, error) {
|
||||
return s.Put(bytes.NewReader(data))
|
||||
}
|
||||
|
||||
func PutString(s *Store, data string) (ID, int64, error) {
|
||||
return s.Put(strings.NewReader(data))
|
||||
}
|
||||
|
||||
func (s *Store) put(file io.ReadSeeker) (ID, int64, error) {
|
||||
// Compute output ID.
|
||||
h := sha256.New()
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
return ID{}, 0, err
|
||||
}
|
||||
size, err := io.Copy(h, file)
|
||||
if err != nil {
|
||||
return ID{}, 0, err
|
||||
}
|
||||
var out ID
|
||||
h.Sum(out.a[:0])
|
||||
|
||||
// Copy to blob file (if not already present).
|
||||
if err := s.copyFile(file, out, size); err != nil {
|
||||
return out, size, err
|
||||
}
|
||||
|
||||
// TODO: Add to manifest index.
|
||||
return out, size, nil
|
||||
}
|
||||
|
||||
// copyFile copies file into the store, expecting it to have the given
|
||||
// output ID and size, if that file is not present already.
|
||||
func (s *Store) copyFile(file io.ReadSeeker, out ID, size int64) error {
|
||||
name := s.fileName(out)
|
||||
println("name", name)
|
||||
info, err := os.Stat(name)
|
||||
if err == nil && info.Size() == size {
|
||||
// Check hash.
|
||||
if f, err := os.Open(name); err == nil {
|
||||
h := sha256.New()
|
||||
io.Copy(h, f)
|
||||
f.Close()
|
||||
var out2 ID
|
||||
h.Sum(out2.a[:0])
|
||||
if out == out2 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Hash did not match. Fall through and rewrite file.
|
||||
}
|
||||
|
||||
// Copy file to blobs directory.
|
||||
mode := os.O_RDWR | os.O_CREATE
|
||||
if err == nil && info.Size() > size { // shouldn't happen but fix in case
|
||||
mode |= os.O_TRUNC
|
||||
}
|
||||
f, err := os.OpenFile(name, mode, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if size == 0 {
|
||||
// File now exists with correct size.
|
||||
// Only one possible zero-length file, so contents are OK too.
|
||||
// Early return here makes sure there's a "last byte" for code below.
|
||||
return nil
|
||||
}
|
||||
|
||||
// From here on, if any of the I/O writing the file fails,
|
||||
// we make a best-effort attempt to truncate the file f
|
||||
// before returning, to avoid leaving bad bytes in the file.
|
||||
|
||||
// Copy file to f, but also into h to double-check hash.
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
h := sha256.New()
|
||||
w := io.MultiWriter(f, h)
|
||||
if _, err := io.CopyN(w, file, size-1); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
// Check last byte before writing it; writing it will make the size match
|
||||
// what other processes expect to find and might cause them to start
|
||||
// using the file.
|
||||
buf := make([]byte, 1)
|
||||
if _, err := file.Read(buf); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
h.Write(buf)
|
||||
sum := h.Sum(nil)
|
||||
if !bytes.Equal(sum, out.a[:]) {
|
||||
f.Truncate(0)
|
||||
return fmt.Errorf("file content changed underfoot")
|
||||
}
|
||||
|
||||
// Commit manifest entry.
|
||||
if _, err := f.Write(buf); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
// Data might not have been written,
|
||||
// but file may look like it is the right size.
|
||||
// To be extra careful, remove stored file.
|
||||
os.Remove(name)
|
||||
return err
|
||||
}
|
||||
os.Chtimes(name, s.now(), s.now()) // mainly for tests
|
||||
|
||||
return nil
|
||||
}
|
54
x/build/internal/blobstore/blob_test.go
Normal file
54
x/build/internal/blobstore/blob_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseID(t *testing.T) {
|
||||
const valid = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
var invalid = strings.Repeat("\x00", HashSize*2)
|
||||
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", invalid},
|
||||
{"sha256-", invalid},
|
||||
{"sha256-" + valid, valid},
|
||||
|
||||
{"" + valid, invalid}, // no prefix
|
||||
{"sha123-" + valid, invalid}, // invalid prefix
|
||||
{"sha256-" + valid[1:], invalid}, // too short
|
||||
{"sha256-" + valid + "a", invalid}, // too long
|
||||
{"sha256-!" + valid[1:], invalid}, // invalid hex
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
// sanity check
|
||||
if len(tt.want) > HashSize*2 {
|
||||
panic("invalid test")
|
||||
}
|
||||
|
||||
got := ParseID(tt.in)
|
||||
|
||||
wantValid := tt.want != invalid
|
||||
if wantValid {
|
||||
if !got.Valid() {
|
||||
t.Errorf("ParseID(%q).Valid() = false; want true", tt.in)
|
||||
}
|
||||
if got.String() != "sha256-"+tt.want {
|
||||
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "sha256-"+tt.want)
|
||||
}
|
||||
} else {
|
||||
if got.Valid() {
|
||||
t.Errorf("ParseID(%q).Valid() = true; want false", tt.in)
|
||||
}
|
||||
if got.String() != "" {
|
||||
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
128
x/build/internal/blobstore/store_test.go
Normal file
128
x/build/internal/blobstore/store_test.go
Normal file
@ -0,0 +1,128 @@
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"iter"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"bllamo.com/build/blob"
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
// blobNameHello is the blob name the tests expect for the content "hello".
const blobNameHello = "sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
|
||||
|
||||
func TestStoreBasicBlob(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
checkDir(t, dir, nil)
|
||||
|
||||
st, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
st.now = func() time.Time { return now }
|
||||
|
||||
checkDir(t, dir, []string{
|
||||
"blobs/",
|
||||
})
|
||||
|
||||
id, size, err := PutBytes(st, []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if id != ParseID(blobNameHello) {
|
||||
t.Errorf("unexpected ID: %s", id)
|
||||
}
|
||||
if size != 5 {
|
||||
t.Errorf("unexpected size: %d", size)
|
||||
}
|
||||
|
||||
checkDir(t, dir, []string{
|
||||
"blobs/",
|
||||
"blobs/" + blobNameHello,
|
||||
})
|
||||
|
||||
got, err := st.Get(id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, got, Entry{
|
||||
ID: id,
|
||||
Size: 5,
|
||||
Time: now,
|
||||
})
|
||||
|
||||
file := st.OutputFilename(id)
|
||||
wantFile := filepath.Join(dir, "blobs", blobNameHello)
|
||||
if file != wantFile {
|
||||
t.Errorf("unexpected file: %s", file)
|
||||
}
|
||||
|
||||
// Check tags
|
||||
ref := blob.ParseRef("registry.ollama.ai/library/test:latest+KQED")
|
||||
|
||||
t.Logf("RESOLVING: %q", ref.Parts())
|
||||
|
||||
}
|
||||
|
||||
// checkDir checks that the directory at dir contains the files in want. The
|
||||
// files in want must be relative to dir.
|
||||
//
|
||||
// direcotories are suffixed with a slash (e.g. "foo/" instead of "foo").
|
||||
//
|
||||
// want must be in lexicographic order.
|
||||
func checkDir(t testing.TB, dir string, want []string) {
|
||||
t.Helper()
|
||||
|
||||
var matches []string
|
||||
for path, err := range walkDir(dir) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("found %s", path)
|
||||
if path == "./" {
|
||||
continue
|
||||
}
|
||||
path = filepath.ToSlash(path)
|
||||
matches = append(matches, path)
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, matches, want)
|
||||
}
|
||||
|
||||
var errStop = errors.New("stop")
|
||||
|
||||
func walkDir(dir string) iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
err := filepath.WalkDir(dir, func(path string, info os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path, err = filepath.Rel(dir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path = filepath.ToSlash(path)
|
||||
if info.IsDir() {
|
||||
path += "/"
|
||||
}
|
||||
if !yield(path, nil) {
|
||||
return errStop
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if !errors.Is(err, errStop) && err != nil {
|
||||
yield("", err)
|
||||
}
|
||||
}
|
||||
}
|
31
x/client/ollama/apitype/apitype.go
Normal file
31
x/client/ollama/apitype/apitype.go
Normal file
@ -0,0 +1,31 @@
|
||||
package apitype
|
||||
|
||||
import "time"
|
||||
|
||||
// Message is a single message with a role and its content, encoded in JSON
// as {"role": ..., "content": ...}.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
||||
|
||||
// Model describes a model as reported by the API.
type Model struct {
	Ref    string `json:"ref"`
	Digest string `json:"digest"`
	Size   int64  `json:"size"`

	// ModifiedAt is interpreted by Modifed as nanoseconds since the Unix epoch.
	ModifiedAt int64 `json:"modified"`
}

// Modifed returns ModifiedAt as a time.Time, treating the value as
// nanoseconds since the Unix epoch.
func (m Model) Modifed() time.Time {
	const sec = 0
	return time.Unix(sec, m.ModifiedAt)
}
|
||||
|
||||
// PushRequest is the JSON body of a push request.
type PushRequest struct {
	// Name identifies the model to push. Ref is the official term; "name"
	// is kept for backward compatibility with existing clients.
	Name string `json:"name"`

	Insecure bool `json:"insecure"`
	Stream   bool `json:"stream"`
}
|
||||
|
||||
// PushStatus is a push status report with a status string, a digest, and a
// total, encoded in JSON under "status", "digest", and "total".
type PushStatus struct {
	Status string `json:"status"`
	Digest string `json:"digest"`
	Total  int64  `json:"total"`
}
|
162
x/client/ollama/ollama.go
Normal file
162
x/client/ollama/ollama.go
Normal file
@ -0,0 +1,162 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"bllamo.com/client/ollama/apitype"
|
||||
"bllamo.com/types/empty"
|
||||
)
|
||||
|
||||
// TODO(bmizerany): PROGRESS INDICATORS!!!!
|
||||
|
||||
// DefaultBaseURL is the base URL used when OLLAMA_BASE_URL is not set.
const DefaultBaseURL = "http://localhost:11434"
|
||||
|
||||
// envBaseURL is the OLLAMA_BASE_URL environment variable if it is non-empty,
// and DefaultBaseURL otherwise. It is evaluated once, at package init.
var envBaseURL = cmp.Or(os.Getenv("OLLAMA_BASE_URL"), DefaultBaseURL)
|
||||
|
||||
// Default returns a new client with the default base URL.
|
||||
func Default() *Client {
|
||||
return &Client{BaseURL: envBaseURL}
|
||||
}
|
||||
|
||||
// I_Acknowledge_This_API_Is_Under_Development is a flag that must be set to
|
||||
// true for any instance of Client to work.
|
||||
var I_Acknowledge_This_API_Is_Under_Development bool // NOTE(review): nothing in this file reads this flag; confirm where (or whether) it is enforced
|
||||
|
||||
// Client is a client for the Ollama API.
type Client struct {
	// BaseURL is the base URL of the Ollama API. Request paths are appended
	// to it verbatim.
	BaseURL string

	// HTTPClient is the HTTP client to use. If nil, http.DefaultClient is used.
	HTTPClient *http.Client
}
|
||||
|
||||
// Build requests the remote Ollama service to build a model. It uploads any
|
||||
// source files the server needs.
|
||||
func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source fs.FS) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// Push requests the remote Ollama service to push a model to the server.
|
||||
func (c *Client) Push(ctx context.Context, ref string) error {
|
||||
_, err := Do[empty.Message](ctx, c, "POST", "/v1/push", apitype.PushRequest{Name: ref})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) Pull(ctx context.Context, ref string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) List(ctx context.Context) iter.Seq2[apitype.Model, error] {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Show(ctx context.Context, ref string) (*apitype.Model, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Remove(ctx context.Context, ref string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Copy(ctx context.Context, dstRef, srcRef string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// Error is the error value decoded from a non-2xx response body.
type Error struct {
	// Status is the HTTP status code returned by the server.
	Status int `json:"status"`

	// Code specifies a machine readable code indicating the class of
	// error this error is. See http://docs.ollama.com/errors for a full
	// list of error codes.
	Code string `json:"code"`

	// Message is a human readable message that describes the error. It
	// may change across versions of the API, so it should not be used for
	// programmatic decisions.
	Message string `json:"message,omitempty"`

	// Field is the field in the request that caused the error, if any.
	Field string `json:"field,omitempty"`
}

// Error returns "ollama: <code>", followed by ": <message>" when Message is
// not empty.
func (e *Error) Error() string {
	msg := "ollama: " + e.Code
	if e.Message != "" {
		msg += ": " + e.Message
	}
	return msg
}
|
||||
|
||||
// Do encodes in and sends it in a request to the Ollama server and decodes
|
||||
// the response into Res, or an error response (non-2xx) into an *Error, or
|
||||
// any error encounted decoding the response.
|
||||
func Do[Res any](ctx context.Context, c *Client, method, path string, in any) (*Res, error) {
|
||||
var body bytes.Buffer
|
||||
// TODO(bmizerany): pool and reuse this buffer AND the encoder
|
||||
if err := encodeJSON(&body, in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
urlStr := c.BaseURL + path
|
||||
req, err := http.NewRequestWithContext(ctx, method, urlStr, &body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hc := cmp.Or(c.HTTPClient, http.DefaultClient)
|
||||
res, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
var buf bytes.Buffer
|
||||
body := io.TeeReader(res.Body, &buf)
|
||||
e, err := decodeJSON[Error](body)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("ollama: invalid error response from server (status %d): %q", res.StatusCode, buf.String())
|
||||
return nil, err
|
||||
}
|
||||
return nil, e
|
||||
}
|
||||
|
||||
return decodeJSON[Res](res.Body)
|
||||
}
|
||||
|
||||
// decodeJSON decodes one JSON value from r into a new value of type T.
//
// NOTE: This function and encodeJSON are copied from oweb.go. Please do not
// try to consolidate them, so we can keep ollama/client free from
// dependencies which are moving targets and not pulling enough weight to
// justify their inclusion.
func decodeJSON[T any](r io.Reader) (*T, error) {
	out := new(T)
	if err := json.NewDecoder(r).Decode(out); err != nil {
		return nil, err
	}
	return out, nil
}
|
||||
|
||||
// encodeJSON writes v to w as JSON followed by a newline.
//
// NOTE: see the NOTE above decodeJSON.
func encodeJSON(w io.Writer, v any) error {
	// TODO(bmizerany): pool and reuse encoder
	enc := json.NewEncoder(w)
	return enc.Encode(v)
}
|
100
x/cmd/bllamo/bllamo.go
Normal file
100
x/cmd/bllamo/bllamo.go
Normal file
@ -0,0 +1,100 @@
|
||||
// Bllamo is a (new) tool for managing Ollama models.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// bllamo <command> [arguments]
|
||||
//
|
||||
// The commands are:
|
||||
//
|
||||
// build build a model from a Modelfile
|
||||
// list list all models
|
||||
// push push a model to an ollama registry
|
||||
// pull pull a model from an ollama registry
|
||||
// delete delete a model from an ollama registry
|
||||
// help display help for a command
|
||||
package main
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"bllamo.com/api"
|
||||
"bllamo.com/build"
|
||||
"bllamo.com/client/ollama"
|
||||
"bllamo.com/registry"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
args := flag.Args()
|
||||
if len(args) < 1 {
|
||||
fmt.Fprintln(os.Stderr, "bllamo: no command provided")
|
||||
os.Exit(2)
|
||||
}
|
||||
if err := Main(flag.Args()...); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// TODOUsage is the placeholder error commands return when they are invoked
// with the wrong number of arguments.
var TODOUsage = fmt.Errorf("TODO: usage")
|
||||
|
||||
var commands = map[string]func(ctx context.Context, args ...string) error{
|
||||
"build": cmdBuild,
|
||||
"push": cmdPush,
|
||||
"serve": cmdServe,
|
||||
"registry": cmdRegistry,
|
||||
}
|
||||
|
||||
// Main is the entry point for the blammo command.
|
||||
func Main(args ...string) error {
|
||||
cmd := args[0]
|
||||
args = args[1:]
|
||||
if f, ok := commands[cmd]; ok {
|
||||
ctx := context.TODO()
|
||||
return f(ctx, args...)
|
||||
}
|
||||
return fmt.Errorf("blammo: unknown command %q", cmd)
|
||||
}
|
||||
|
||||
func cmdBuild(ctx context.Context, args ...string) error {
|
||||
var v struct {
|
||||
Modelfile string `flag:"f,the Modelfile to use"`
|
||||
}
|
||||
|
||||
fs := readFlags("build", args, &v)
|
||||
if fs.NArg() != 1 {
|
||||
return TODOUsage
|
||||
}
|
||||
|
||||
modelfile, err := os.ReadFile(cmp.Or(v.Modelfile, "Modelfile"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ollama.Default().Build(ctx, args[0], modelfile, os.DirFS("."))
|
||||
}
|
||||
|
||||
func cmdRegistry(_ context.Context, _ ...string) error {
|
||||
var s registry.Server
|
||||
return http.ListenAndServe(":8888", &s)
|
||||
}
|
||||
|
||||
func cmdServe(ctx context.Context, args ...string) error {
|
||||
bs, err := build.Open("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return http.ListenAndServe(":11434", &api.Server{Build: bs})
|
||||
}
|
||||
|
||||
func cmdPush(ctx context.Context, args ...string) error {
|
||||
fs := readFlags("push", args, nil)
|
||||
if fs.NArg() != 1 {
|
||||
return TODOUsage
|
||||
}
|
||||
return ollama.Default().Push(ctx, fs.Arg(0))
|
||||
}
|
59
x/cmd/bllamo/flags.go
Normal file
59
x/cmd/bllamo/flags.go
Normal file
@ -0,0 +1,59 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// readFlags parses args using a flag.FlagSet named name that is built
// dynamically, using reflection, from v. It returns the parsed FlagSet.
//
// v must be nil or a pointer to a struct. Each settable field of the struct
// that has a "flag" tag defines one flag. The tag has the form
// "name[,usage]". Only string and bool fields are supported; any other
// tagged field kind panics. If v is nil, no flags are defined.
//
// The FlagSet uses flag.ExitOnError, so a parse error exits the process.
//
// Example usage:
//
//	func main() {
//		var flags struct {
//			Modelfile string `flag:"f,path to the Modelfile"`
//		}
//
//		fs := readFlags("build", os.Args[1:], &flags)
//		fmt.Println(flags.Modelfile, fs.Args())
//	}
func readFlags(name string, args []string, v any) *flag.FlagSet {
	fs := flag.NewFlagSet(name, flag.ExitOnError)
	if v != nil {
		// Indirect through the pointer: the fields are only settable (and
		// NumField only valid) on the struct the pointer refers to.
		rv := reflect.Indirect(reflect.ValueOf(v))
		if rv.Kind() != reflect.Struct {
			panic(fmt.Sprintf("readFlags: v must be a pointer to a struct, got %T", v))
		}
		rt := rv.Type()
		for i := 0; i < rv.NumField(); i++ {
			f := rv.Field(i)
			if !f.CanSet() {
				continue
			}

			// The tag lives on the struct type's field descriptor, not on
			// the field's own type.
			tag := rt.Field(i).Tag.Get("flag")
			if tag == "" {
				continue
			}
			flagName, usage, _ := strings.Cut(tag, ",")

			// TODO(bmizerany): add more types as needed
			switch f.Kind() {
			case reflect.String:
				fs.StringVar(f.Addr().Interface().(*string), flagName, "", usage)
			case reflect.Bool:
				fs.BoolVar(f.Addr().Interface().(*bool), flagName, false, usage)
			default:
				panic(fmt.Sprintf("unsupported type %v", f.Kind()))
			}
		}
	}
	fs.Parse(args) // ExitOnError: exits the process on failure
	return fs
}
|
97
x/cmd/gguf/gguf.go
Normal file
97
x/cmd/gguf/gguf.go
Normal file
@ -0,0 +1,97 @@
|
||||
// Gguf is a tool for learning about GGUF files.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// gguf [flags] <file>
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"text/tabwriter"
|
||||
|
||||
"bllamo.com/encoding/gguf"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := Main(os.Stdout, os.Args[1:]...); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Main(stdout io.Writer, args ...string) error {
|
||||
fs := flag.NewFlagSet("gguf", flag.ExitOnError)
|
||||
flagGPU := fs.Uint64("gpu", 0, "use N bytes of GPU memory (default is 0)")
|
||||
|
||||
fs.Usage = func() {
|
||||
io.WriteString(stdout, "Gguf is a tool for learning about GGUF files.\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
io.WriteString(stdout, "Usage:\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
io.WriteString(stdout, "\tgguf [flags] <file>\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
var numFlags int
|
||||
fs.VisitAll(func(*flag.Flag) { numFlags++ })
|
||||
if numFlags > 0 {
|
||||
io.WriteString(stdout, "Flags:\n")
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
}
|
||||
fs.Parse(args)
|
||||
|
||||
if fs.NArg() != 1 {
|
||||
fs.Usage()
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
file := fs.Arg(0)
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
g, err := gguf.ReadFile(f)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tw := tabwriter.NewWriter(stdout, 0, 2, 2, ' ', 0)
|
||||
defer tw.Flush()
|
||||
|
||||
fmt.Fprintf(tw, "version:\t%d\n", g.Version())
|
||||
|
||||
for m, err := range g.Metadata {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if len(m.Values) > 5 {
|
||||
fmt.Fprintf(tw, "meta:\t%q: ... (%d values)\n", m.Key, len(m.Values))
|
||||
} else {
|
||||
fmt.Fprintf(tw, "meta:\t%q: %v\n", m.Key, m.Values)
|
||||
}
|
||||
}
|
||||
|
||||
var i int
|
||||
var totalLayerBytes uint64
|
||||
var offGPU bool
|
||||
for t, err := range g.Tensors {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
totalLayerBytes += t.Size
|
||||
if totalLayerBytes > *flagGPU {
|
||||
offGPU = true
|
||||
}
|
||||
|
||||
const msg = "tensor (layer %000d):\t%q\t%s\tdims=%v\toffset=%d\tsize=%d\tonGPU=%v\n"
|
||||
fmt.Fprintf(tw, msg, i, t.Name, t.Type, t.Dimensions, t.Offset, t.Size, !offGPU)
|
||||
|
||||
i++
|
||||
}
|
||||
return nil
|
||||
}
|
376
x/encoding/gguf/gguf.go
Normal file
376
x/encoding/gguf/gguf.go
Normal file
@ -0,0 +1,376 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"bllamo.com/types/structs"
|
||||
)
|
||||
|
||||
// MaxDimensions is the maximum number of dimensions a tensor can have.
//
// TODO(bmizerany): determine a more reasonable value for MaxDimensions.
const MaxDimensions uint32 = 1e6
|
||||
|
||||
// Errors
var (
	// ErrBadMagic is returned when the magic bytes at the start of the
	// file are not the expected ones. This is useful for detecting that a
	// file is not a gguf file.
	ErrBadMagic = errors.New("gguf: bad magic")

	// ErrUnsupportedVersion has the message "gguf: unsupported version".
	ErrUnsupportedVersion = errors.New("gguf: unsupported version")

	// ErrMangled has the message "gguf: mangled data".
	ErrMangled = errors.New("gguf: mangled data")
)
|
||||
|
||||
// Type is the numeric type code of a tensor (see TensorInfo.Type).
type Type uint32

// Known Type values. Values 4 and 5 are not assigned here.
const (
	TypeF32   Type = 0
	TypeF16   Type = 1
	TypeQ4_0  Type = 2
	TypeQ4_1  Type = 3
	TypeQ5_0  Type = 6
	TypeQ5_1  Type = 7
	TypeQ8_0  Type = 8
	TypeQ8_1  Type = 9
	TypeQ2_K  Type = 10
	TypeQ3_K  Type = 11
	TypeQ4_K  Type = 12
	TypeQ5_K  Type = 13
	TypeQ6_K  Type = 14
	TypeQ8_K  Type = 15
	TypeI8    Type = 16
	TypeI16   Type = 17
	TypeI32   Type = 18
	TypeCount Type = 19
)

// typeNames maps each known Type to its display name.
var typeNames = map[Type]string{
	TypeF32:   "F32",
	TypeF16:   "F16",
	TypeQ4_0:  "Q4_0",
	TypeQ4_1:  "Q4_1",
	TypeQ5_0:  "Q5_0",
	TypeQ5_1:  "Q5_1",
	TypeQ8_0:  "Q8_0",
	TypeQ8_1:  "Q8_1",
	TypeQ2_K:  "Q2_K",
	TypeQ3_K:  "Q3_K",
	TypeQ4_K:  "Q4_K",
	TypeQ5_K:  "Q5_K",
	TypeQ6_K:  "Q6_K",
	TypeQ8_K:  "Q8_K",
	TypeI8:    "I8",
	TypeI16:   "I16",
	TypeI32:   "I32",
	TypeCount: "COUNT",
}

// String returns the display name of t, or "(!unknown_type N!)" for a value
// with no known name.
func (t Type) String() string {
	name, ok := typeNames[t]
	if !ok || name == "" {
		return fmt.Sprintf("(!unknown_type %d!)", t)
	}
	return name
}
|
||||
|
||||
// ValueType is the type of a metadata value.
type ValueType uint32

// String returns the display name of t, or "(!unknown_value_type N!)" for a
// value with no known name.
func (t ValueType) String() string {
	name, ok := metaTypeNames[t]
	if !ok || name == "" {
		return fmt.Sprintf("(!unknown_value_type %d!)", t)
	}
	return name
}

// Known ValueType values.
const (
	ValueTypeUint8   ValueType = 0
	ValueTypeInt8    ValueType = 1
	ValueTypeUint16  ValueType = 2
	ValueTypeInt16   ValueType = 3
	ValueTypeUint32  ValueType = 4
	ValueTypeInt32   ValueType = 5
	ValueTypeFloat32 ValueType = 6
	ValueTypeBool    ValueType = 7
	ValueTypeString  ValueType = 8
	ValueTypeArray   ValueType = 9
	ValueTypeUint64  ValueType = 10
	ValueTypeInt64   ValueType = 11
	ValueTypeFloat64 ValueType = 12
)

// metaTypeNames maps each known ValueType to its display name.
var metaTypeNames = map[ValueType]string{
	ValueTypeUint8:   "uint8",
	ValueTypeInt8:    "int8",
	ValueTypeUint16:  "uint16",
	ValueTypeInt16:   "int16",
	ValueTypeUint32:  "uint32",
	ValueTypeInt32:   "int32",
	ValueTypeFloat32: "float32",
	ValueTypeBool:    "bool",
	ValueTypeString:  "string",
	ValueTypeArray:   "array",
	ValueTypeUint64:  "uint64",
	ValueTypeInt64:   "int64",
	ValueTypeFloat64: "float64",
}
|
||||
|
||||
type TensorInfo struct {
|
||||
Name string
|
||||
Dimensions []uint64
|
||||
Type Type
|
||||
Offset uint64
|
||||
Size uint64
|
||||
}
|
||||
|
||||
type MetaValue struct {
|
||||
Type ValueType
|
||||
Value []byte
|
||||
}
|
||||
|
||||
func (v MetaValue) String() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(v.Type.String())
|
||||
b.WriteString("(")
|
||||
switch v.Type {
|
||||
case ValueTypeArray:
|
||||
b.WriteString("[...]")
|
||||
case ValueTypeString:
|
||||
b.WriteString(strconv.Quote(string(v.Value)))
|
||||
case ValueTypeBool:
|
||||
if len(v.Value) == 0 {
|
||||
b.WriteString("(!invalid bool)")
|
||||
}
|
||||
switch v.Value[0] {
|
||||
case 0:
|
||||
b.WriteString("false")
|
||||
case 1:
|
||||
b.WriteString("true")
|
||||
default:
|
||||
b.WriteString("!invalid bool")
|
||||
}
|
||||
case ValueTypeUint8, ValueTypeInt8, ValueTypeUint16, ValueTypeInt16, ValueTypeUint32, ValueTypeInt32, ValueTypeUint64, ValueTypeInt64, ValueTypeFloat32, ValueTypeFloat64:
|
||||
var buf [8]byte
|
||||
if len(v.Value) < 8 {
|
||||
copy(buf[:], v.Value)
|
||||
}
|
||||
fmt.Fprintf(&b, "%v", binary.LittleEndian.Uint64(buf[:]))
|
||||
default:
|
||||
fmt.Fprintf(&b, "%v", v.Value)
|
||||
}
|
||||
b.WriteString(")")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type MetaEntry struct {
|
||||
Key string
|
||||
Type ValueType
|
||||
Values []MetaValue
|
||||
}
|
||||
|
||||
func (e MetaEntry) String() string {
|
||||
if len(e.Values) == 0 {
|
||||
return ""
|
||||
}
|
||||
return string(e.Values[0].Value)
|
||||
}
|
||||
|
||||
func (e MetaEntry) Uint32() uint32 {
|
||||
if len(e.Values) == 0 {
|
||||
return 0
|
||||
}
|
||||
return binary.LittleEndian.Uint32(e.Values[0].Value)
|
||||
}
|
||||
|
||||
func (e MetaEntry) FileType() Type {
|
||||
if len(e.Values) == 0 {
|
||||
return TypeCount
|
||||
}
|
||||
return Type(e.Uint32())
|
||||
}
|
||||
|
||||
func (e MetaEntry) GoString() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(e.Key)
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.Type.String())
|
||||
b.WriteString("(")
|
||||
for i, v := range e.Values {
|
||||
if i > 0 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
b.WriteString(v.String())
|
||||
}
|
||||
b.WriteString(")")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type Info struct {
|
||||
_ structs.Incomparable // prevent comparison of Info values so we can change the implementation later
|
||||
|
||||
Version int
|
||||
FileType Type
|
||||
}
|
||||
|
||||
func Stat(path string) (Info, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
defer f.Close()
|
||||
return StatReader(f)
|
||||
}
|
||||
|
||||
// StatReader reads the header information from r and returns an Info
|
||||
// struct with the version and file type.
|
||||
//
|
||||
// It returns an error if any.
|
||||
//
|
||||
// As a special case, it returns ErrBadMagic if the file does not start with
|
||||
// the magic bytes. This can be used to detect if the file is not a GGUF
|
||||
// file.
|
||||
func StatReader(r io.ReadSeeker) (Info, error) {
|
||||
if _, err := r.Seek(0, 0); err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
f, err := ReadFile(r)
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
info := Info{Version: f.Version()}
|
||||
for m, err := range f.Metadata {
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
if m.Key == "general.file_type" {
|
||||
if m.Type != ValueTypeUint32 {
|
||||
return Info{}, fmt.Errorf("unexpected type for metadata key %q: %v, want %v", m.Key, m.Type, ValueTypeUint32)
|
||||
}
|
||||
info.FileType = m.FileType()
|
||||
}
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// File is a GGUF file whose header has been read. Its metadata and tensor
// descriptions are decoded lazily from the underlying reader by the
// Metadata and Tensors iterators.
type File struct {
	// Header fields, in the form they were read.
	version       uint32
	numMetaValues uint64
	numTensors    uint64

	// gr is the reader positioned just after the header; the iterators
	// consume from it.
	gr *ggufReader
}
|
||||
|
||||
// ReadFile reads header information from r and returns a File, ready for
|
||||
// iteration over Metadata and Tensors.
|
||||
func ReadFile(r io.Reader) (*File, error) {
|
||||
f, err := readFile(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *File) Version() int {
|
||||
return int(f.version)
|
||||
}
|
||||
|
||||
// Metadata iterates over the metadata in the file. It must be exhausted
|
||||
// before calling Tensors.
|
||||
//
|
||||
// It is not resumable.
|
||||
func (f *File) Metadata(yield func(MetaEntry, error) bool) {
|
||||
var n int
|
||||
for range f.numMetaValues {
|
||||
meta, err := f.gr.readMetaEntry()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading metadata entry %d: %w", n, err)
|
||||
yield(MetaEntry{}, err)
|
||||
return
|
||||
}
|
||||
if !yield(meta, nil) {
|
||||
return
|
||||
}
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
// Tensors iterates over the tensors in the file. It must only be called
|
||||
// after exhausting the metadata iterator.
|
||||
//
|
||||
// It is not resumable.
|
||||
func (f *File) Tensors(yield func(TensorInfo, error) bool) {
|
||||
var last TensorInfo
|
||||
for range f.numTensors {
|
||||
info, err := f.gr.readTensorInfo()
|
||||
|
||||
// If the last tensor had a valid offset, yield it.
|
||||
//
|
||||
// NOTE: No tensor should have an offset of 0 because the
|
||||
// offset is the start of the tensor data which is always
|
||||
// afer the magic bytes, version, numMetaValues, and
|
||||
// numTensors, which MUST all be non-zero bytes as per the
|
||||
// GGUF spec.
|
||||
if last.Offset > 0 {
|
||||
if !yield(last, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
yield(TensorInfo{}, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Tensor data does not include size, so we need to
|
||||
// calculate it based on the offset of the previous tensor
|
||||
// offset to the current.
|
||||
offset0 := last.Offset
|
||||
last = info
|
||||
last.Size = info.Offset - offset0
|
||||
}
|
||||
if last.Offset > 0 {
|
||||
yield(last, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// magicBytes is the ASCII string "GGUF"; readFile requires every input to
// begin with these four bytes.
var magicBytes = []byte{0x47, 0x47, 0x55, 0x46}
|
||||
|
||||
func readFile(r io.Reader) (*File, error) {
|
||||
gr := &ggufReader{r: &reader{r: r}}
|
||||
magic, err := gr.next(4)
|
||||
if err != nil {
|
||||
return nil, errors.Join(err, ErrBadMagic)
|
||||
}
|
||||
if !bytes.Equal(magic, magicBytes) {
|
||||
return nil, ErrBadMagic
|
||||
}
|
||||
version, err := gr.readUint32()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != 3 {
|
||||
return nil, fmt.Errorf("%w: %d", ErrUnsupportedVersion, version)
|
||||
}
|
||||
numTensors, err := gr.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
numMetaValues, err := gr.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info := &File{
|
||||
version: version,
|
||||
|
||||
numMetaValues: numMetaValues,
|
||||
numTensors: numTensors,
|
||||
gr: gr,
|
||||
}
|
||||
return info, nil
|
||||
}
|
345
x/encoding/gguf/gguf_test.go
Normal file
345
x/encoding/gguf/gguf_test.go
Normal file
@ -0,0 +1,345 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
// TestStat feeds hand-built byte strings to StatReader and checks either
// the returned Info or that the returned error matches the expected
// sentinel.
func TestStat(t *testing.T) {
	cases := []struct {
		name     string
		data     string
		wantInfo Info
		wantErr  error
	}{
		{
			// No bytes at all: the magic cannot be read.
			name:    "empty",
			wantErr: ErrBadMagic,
		},
		{
			name:    "bad magic",
			data:    "\xBB\xAA\xDD\x00",
			wantErr: ErrBadMagic,
		},
		{
			name: "bad version",
			data: string(magicBytes) +
				"\x02\x00\x00\x00" + // version
				"",
			wantErr: ErrUnsupportedVersion,
		},
		{
			name: "valid general.file_type",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// general.file_type key
				"\x11\x00\x00\x00\x00\x00\x00\x00" + // key length
				"general.file_type" + // key
				"\x04\x00\x00\x00" + // type (uint32)
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value (8 bytes supplied; a uint32 is 4, so the trailing 4 appear to go unread)
				"",
			wantInfo: Info{
				Version:  3,
				FileType: 1,
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			info, err := StatReader(strings.NewReader(tt.data))
			if tt.wantErr != nil {
				// Error cases only check the error, not the Info.
				if !errors.Is(err, tt.wantErr) {
					t.Fatalf("err = %v; want %q", err, tt.wantErr)
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			diff.Test(t, t.Errorf, info, tt.wantInfo)
		})
	}
}
|
||||
|
||||
func TestReadInfo(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
data string
|
||||
|
||||
wantMeta []MetaEntry
|
||||
wantTensor []TensorInfo
|
||||
wantReadErr error
|
||||
wantMetaErr error
|
||||
wantTensorErr error
|
||||
wantInfo Info
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
wantReadErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "bad magic",
|
||||
data: "\xBB\xAA\xDD\x00",
|
||||
wantReadErr: ErrBadMagic,
|
||||
},
|
||||
{
|
||||
name: "bad version",
|
||||
data: string(magicBytes) +
|
||||
"\x02\x00\x00\x00" + // version
|
||||
"",
|
||||
wantReadErr: ErrUnsupportedVersion,
|
||||
},
|
||||
{
|
||||
name: "no metadata or tensors",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"",
|
||||
wantReadErr: nil,
|
||||
},
|
||||
{
|
||||
name: "good metadata",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"K" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"VV" + // string value
|
||||
"",
|
||||
wantMeta: []MetaEntry{
|
||||
{Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "good metadata with multiple values",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// MetaEntry 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"x" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"XX" + // string value
|
||||
|
||||
// MetaEntry 2
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"y" + // key
|
||||
"\x04\x00\x00\x00" + // type (uint32)
|
||||
"\x99\x88\x77\x66" + // uint32 value
|
||||
"",
|
||||
wantMeta: []MetaEntry{
|
||||
{Key: "x", Type: ValueTypeString, Values: []MetaValue{{
|
||||
Type: ValueTypeString,
|
||||
Value: []byte("XX"),
|
||||
}}},
|
||||
{Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{
|
||||
Type: ValueTypeUint32,
|
||||
Value: []byte{0x99, 0x88, 0x77, 0x66},
|
||||
}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "negative string length in meta key",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length
|
||||
"K" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"VV" + // string value
|
||||
"",
|
||||
wantMetaErr: ErrMangled,
|
||||
},
|
||||
|
||||
// Tensor tests
|
||||
{
|
||||
name: "good tensor",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// Tensor 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
|
||||
// dimensions
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
|
||||
"",
|
||||
wantTensor: []TensorInfo{
|
||||
{
|
||||
Name: "t",
|
||||
Dimensions: []uint64{1},
|
||||
Type: TypeQ4_1,
|
||||
Offset: 256,
|
||||
Size: 256,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "too many dimensions",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// Tensor 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
|
||||
"\x00\x00\x00\x01" + // dimensions length
|
||||
"",
|
||||
wantTensorErr: ErrMangled,
|
||||
},
|
||||
{
|
||||
name: "size computed",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// Tensor 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
|
||||
|
||||
// Tensor 2
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x00\x03\x00\x00\x00\x00\x00\x00" + // offset
|
||||
"",
|
||||
wantTensor: []TensorInfo{
|
||||
{
|
||||
Name: "t",
|
||||
Dimensions: []uint64{1},
|
||||
Type: TypeQ4_1,
|
||||
Offset: 256,
|
||||
Size: 256,
|
||||
},
|
||||
{
|
||||
Name: "t",
|
||||
Dimensions: []uint64{1},
|
||||
Type: TypeQ4_1,
|
||||
Offset: 768,
|
||||
Size: 512,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f, err := ReadFile(strings.NewReader(tt.data))
|
||||
if err != nil {
|
||||
if !errors.Is(err, tt.wantReadErr) {
|
||||
t.Fatalf("unexpected ReadFile error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var got []MetaEntry
|
||||
for meta, err := range f.Metadata {
|
||||
if !errors.Is(err, tt.wantMetaErr) {
|
||||
t.Fatalf("err = %v; want %v", err, ErrMangled)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
got = append(got, meta)
|
||||
}
|
||||
diff.Test(t, t.Errorf, got, tt.wantMeta)
|
||||
|
||||
var gotT []TensorInfo
|
||||
for tinfo, err := range f.Tensors {
|
||||
if !errors.Is(err, tt.wantTensorErr) {
|
||||
t.Fatalf("err = %v; want %v", err, tt.wantTensorErr)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
gotT = append(gotT, tinfo)
|
||||
}
|
||||
diff.Test(t, t.Errorf, gotT, tt.wantTensor)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzReadInfo(f *testing.F) {
|
||||
f.Add(string(magicBytes))
|
||||
f.Add(string(magicBytes) +
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"")
|
||||
f.Add(string(magicBytes) +
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"K" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"VV" + // string value
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x05\x00\x00\x00\x00\x00\x00\x00" + // offset
|
||||
"")
|
||||
|
||||
f.Fuzz(func(t *testing.T, data string) {
|
||||
gf, err := ReadFile(strings.NewReader(data))
|
||||
if err != nil {
|
||||
t.Logf("ReadFile error: %v", err)
|
||||
t.Skip()
|
||||
}
|
||||
for _, err := range gf.Metadata {
|
||||
if err != nil {
|
||||
t.Logf("metadata error: %v", err)
|
||||
t.Skip()
|
||||
}
|
||||
}
|
||||
for tinfo, err := range gf.Tensors {
|
||||
if err != nil {
|
||||
t.Logf("tensor error: %v", err)
|
||||
t.Skip()
|
||||
}
|
||||
if tinfo.Offset <= 0 {
|
||||
t.Logf("invalid tensor offset: %+v", t)
|
||||
t.Skip()
|
||||
}
|
||||
if tinfo.Size <= 0 {
|
||||
t.Logf("invalid tensor size: %+v", t)
|
||||
t.Skip()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
195
x/encoding/gguf/ggufio.go
Normal file
195
x/encoding/gguf/ggufio.go
Normal file
@ -0,0 +1,195 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
)
|
||||
|
||||
// ggufReader decodes GGUF primitives (fixed-width integers, strings,
// metadata entries and tensor descriptions) from a sliding-window reader.
type ggufReader struct {
	// r is the buffered window over the underlying input.
	r *reader

	// n is the total number of bytes handed out by next so far.
	n int
}
|
||||
|
||||
func (r *ggufReader) readMetaEntry() (MetaEntry, error) {
|
||||
key, err := r.readString()
|
||||
if err != nil {
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
typ, err := r.readValueType()
|
||||
if err != nil {
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
var values []MetaValue
|
||||
for v, err := range r.readMetaValues(typ) {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("(key=%q type=%s): %w", key, typ, err)
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
values = append(values, v)
|
||||
}
|
||||
return MetaEntry{
|
||||
Key: string(key),
|
||||
Type: typ,
|
||||
Values: values,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaValue(typ ValueType) (MetaValue, error) {
|
||||
var value []byte
|
||||
var err error
|
||||
switch typ {
|
||||
case ValueTypeUint8, ValueTypeInt8:
|
||||
value, err = r.next(1)
|
||||
case ValueTypeUint16, ValueTypeInt16:
|
||||
value, err = r.next(2)
|
||||
case ValueTypeUint32, ValueTypeInt32, ValueTypeFloat32:
|
||||
value, err = r.next(4)
|
||||
case ValueTypeUint64, ValueTypeInt64, ValueTypeFloat64:
|
||||
value, err = r.next(8)
|
||||
case ValueTypeBool:
|
||||
value, err = r.next(1)
|
||||
case ValueTypeString:
|
||||
value, err = r.readString()
|
||||
case ValueTypeArray:
|
||||
err = fmt.Errorf("nested arrays are not supported")
|
||||
default:
|
||||
err = fmt.Errorf("unsupported metadata type: %d", typ)
|
||||
}
|
||||
if err != nil {
|
||||
return MetaValue{}, err
|
||||
}
|
||||
return MetaValue{
|
||||
Type: typ,
|
||||
Value: bytes.Clone(value),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaValues(typ ValueType) iter.Seq2[MetaValue, error] {
|
||||
return func(yield func(MetaValue, error) bool) {
|
||||
if typ == ValueTypeArray {
|
||||
atyp, err := r.readValueType()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid type: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
n, err := r.readUint64()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid length: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
for i := range n {
|
||||
v, err := r.readMetaValue(atyp)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid entry (type=%s) %d: %w", atyp, i, err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
if !yield(v, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
v, err := r.readMetaValue(typ)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading metadata value: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
yield(v, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ggufReader) readValueType() (ValueType, error) {
|
||||
typ, err := r.readUint32()
|
||||
return ValueType(typ), err
|
||||
}
|
||||
|
||||
func (r *ggufReader) readTensorInfo() (TensorInfo, error) {
|
||||
name, err := r.readString()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
numDimensions, err := r.readUint32()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
if numDimensions > MaxDimensions {
|
||||
return TensorInfo{}, fmt.Errorf("%w: dimensions length (%d) exceeds %d", ErrMangled, numDimensions, MaxDimensions)
|
||||
}
|
||||
|
||||
dims := make([]uint64, numDimensions)
|
||||
for i := range dims {
|
||||
d, err := r.readUint64()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
dims[i] = d
|
||||
}
|
||||
typ, err := r.readUint32()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
offset, err := r.readUint64()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
// TODO(bmizerany): check offset is multiple of ALIGNMENT
|
||||
return TensorInfo{
|
||||
Name: string(name),
|
||||
Dimensions: dims,
|
||||
Type: Type(typ),
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) next(n int) ([]byte, error) {
|
||||
if n < 0 {
|
||||
return nil, errors.Join(fmt.Errorf("invalid read length: %d", n), ErrMangled)
|
||||
}
|
||||
w := r.r.window()
|
||||
for len(w) < n {
|
||||
if r.r.extend() == 0 {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
w = r.r.window()
|
||||
}
|
||||
r.r.release(n)
|
||||
r.n += n
|
||||
return w[:n], nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readString() ([]byte, error) {
|
||||
n, err := r.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(bmizerany): limit max string length
|
||||
return r.next(int(n))
|
||||
}
|
||||
|
||||
func (r *ggufReader) readUint32() (uint32, error) {
|
||||
b, err := r.next(4)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := binary.LittleEndian.Uint32(b)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readUint64() (uint64, error) {
|
||||
b, err := r.next(8)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := binary.LittleEndian.Uint64(b)
|
||||
return n, nil
|
||||
}
|
70
x/encoding/gguf/reader.go
Normal file
70
x/encoding/gguf/reader.go
Normal file
@ -0,0 +1,70 @@
|
||||
package gguf
|
||||
|
||||
import "io"
|
||||
|
||||
// reader is a sliding window over an io.Reader. Bytes are pulled from the
// source with extend, inspected through window, and dropped from the front
// with release.
type reader struct {
	data   []byte    // buffered bytes; data[offset:] is the current window
	offset int       // start of the window within data
	r      io.Reader // underlying source
	err    error     // last result of r.Read; once non-nil, extend stops reading
}

// release drops n bytes from the front of the window.
func (b *reader) release(n int) { b.offset += n }

// window returns the bytes currently buffered and not yet released.
// The returned slice is invalidated by calls to release or extend.
func (b *reader) window() []byte { return b.data[b.offset:] }

// Buffer sizing used by extend and grow.
const (
	newBufferSize = 8 << 10            // smallest buffer grow will allocate
	minReadSize   = newBufferSize >> 2 // minimum free space wanted before a read
)

// extend performs one Read on the underlying source, appending to the
// window, and returns the number of bytes added. It returns 0 without
// reading once a previous Read has returned an error.
func (b *reader) extend() int {
	if b.err != nil {
		return 0
	}

	unread := len(b.data) - b.offset
	if unread == 0 {
		// Nothing buffered: restart at the front of the buffer.
		b.data, b.offset = b.data[:0], 0
	}

	switch {
	case cap(b.data)-len(b.data) >= minReadSize:
		// Enough room already exists between len and cap.
	case cap(b.data)-unread >= minReadSize:
		// Enough room once the unread bytes are moved to the front.
		b.compact()
	default:
		// Otherwise a bigger buffer is needed.
		b.grow()
	}

	end := unread + b.offset
	n, err := b.r.Read(b.data[end:cap(b.data)])
	// Trim the length to the unread bytes plus what was just read.
	b.data = b.data[:end+n]
	b.err = err
	return n
}

// grow replaces the buffer with one of at least double the capacity (and
// never less than newBufferSize), moving the window to its front.
func (b *reader) grow() {
	size := max(2*cap(b.data), newBufferSize)
	bigger := make([]byte, size)
	copy(bigger, b.window())
	b.data, b.offset = bigger, 0
}

// compact slides the window to the front of the existing buffer.
func (b *reader) compact() {
	copy(b.data, b.window())
	b.offset = 0
}
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x02\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xfd\xff\xff\xff\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00K\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00VV\x01\x00\x00\x00\x00\\x00\\x00\\x00\\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x0000000000000000000000000\xe5")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x0000000000\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x000\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x0000\x01\x00\x00\x00\x00\x00\x00\x000\x01\x00\x001\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\a")
|
30
x/go.mod
Normal file
30
x/go.mod
Normal file
@ -0,0 +1,30 @@
|
||||
module bllamo.com
|
||||
|
||||
go 1.22.1
|
||||
|
||||
require kr.dev/diff v0.3.0
|
||||
|
||||
require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.17.6 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.6 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/minio-go/v7 v7.0.69 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect
|
||||
github.com/rogpeppe/go-internal v1.9.0 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
golang.org/x/crypto v0.19.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sync v0.6.0 // indirect
|
||||
golang.org/x/sys v0.17.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
63
x/go.sum
Normal file
63
x/go.sum
Normal file
@ -0,0 +1,63 @@
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
|
||||
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.69 h1:l8AnsQFyY1xiwa/DaQskY4NXSLA2yrGsW5iD9nRPVS0=
|
||||
github.com/minio/minio-go/v7 v7.0.69/go.mod h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ=
|
||||
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
|
||||
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
|
||||
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e h1:iWVPgObh6F4UDtjBLK51zsy5UHTPLQwCmsNjCsbKhQ0=
|
||||
golang.org/x/exp v0.0.0-20220218215828-6cf2b201936e/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
kr.dev/diff v0.3.0 h1:o/T8/tkAq9IuRIuFqCupyKPC5iSY3WXpVZ2p6ZK3Emw=
|
||||
kr.dev/diff v0.3.0/go.mod h1:XiTaLOg2/PD0cmXY7WQXUR8RAF3RwWpqIQEj910J2NY=
|
126
x/model/file.go
Normal file
126
x/model/file.go
Normal file
@ -0,0 +1,126 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"iter"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Param is one key/value pair collected from a PARAMETER pragma.
type Param struct {
	Key, Value string
}
|
||||
|
||||
// Message is one chat message collected from a MESSAGE pragma: the role
// it is attributed to and its content.
type Message struct {
	Role, Content string
}
|
||||
|
||||
type File struct {
|
||||
// From is a required pragma that specifies the source of the model,
|
||||
// either on disk, or by reference (see blob.ParseRef).
|
||||
From string
|
||||
|
||||
// Optional
|
||||
Params []Param
|
||||
Template string
|
||||
System string
|
||||
Adapter string
|
||||
Messages []Message
|
||||
|
||||
License string
|
||||
}
|
||||
|
||||
// Error pairs a pragma name with a message describing what is wrong.
type Error struct {
	Pragma  string
	Message string
}

// Error implements the error interface. The result has the form
// "<Pragma>: <Message>".
func (e *Error) Error() string {
	return strings.Join([]string{e.Pragma, e.Message}, ": ")
}
|
||||
|
||||
type Pragma struct {
|
||||
// The pragma name
|
||||
Name string
|
||||
|
||||
// Args contains the user-defined arguments for the pragma. If no
|
||||
// arguments were provided, it is nil.
|
||||
Args []string
|
||||
}
|
||||
|
||||
func (p Pragma) Arg(i int) string {
|
||||
if i >= len(p.Args) {
|
||||
return ""
|
||||
}
|
||||
return p.Args[i]
|
||||
}
|
||||
|
||||
func Pragmas(r io.Reader) iter.Seq2[Pragma, error] {
|
||||
return func(yield func(Pragma, error) bool) {
|
||||
sc := bufio.NewScanner(r)
|
||||
for sc.Scan() {
|
||||
line := sc.Text()
|
||||
|
||||
// TODO(bmizerany): set a max num fields/args to
|
||||
// prevent mem bloat
|
||||
args := strings.Fields(line)
|
||||
if len(args) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
p := Pragma{
|
||||
Name: strings.ToUpper(args[0]),
|
||||
}
|
||||
if p.Name == "MESSAGE" {
|
||||
// handle special case where message content
|
||||
// is space separated on the _rest_ of the
|
||||
// line like: `MESSAGE user Is Ontario in
|
||||
// Canada?`
|
||||
panic("TODO")
|
||||
}
|
||||
if len(args) > 1 {
|
||||
p.Args = args[1:]
|
||||
}
|
||||
if !yield(p, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if sc.Err() != nil {
|
||||
yield(Pragma{}, sc.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Decode(r io.Reader) (File, error) {
|
||||
var f File
|
||||
for p, err := range Pragmas(r) {
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
switch p.Name {
|
||||
case "FROM":
|
||||
f.From = p.Arg(0)
|
||||
case "PARAMETER":
|
||||
f.Params = append(f.Params, Param{
|
||||
Key: strings.ToLower(p.Arg(0)),
|
||||
Value: p.Arg(1),
|
||||
})
|
||||
case "TEMPLATE":
|
||||
f.Template = p.Arg(0)
|
||||
case "SYSTEM":
|
||||
f.System = p.Arg(0)
|
||||
case "ADAPTER":
|
||||
f.Adapter = p.Arg(0)
|
||||
case "MESSAGE":
|
||||
f.Messages = append(f.Messages, Message{
|
||||
Role: p.Arg(0),
|
||||
Content: p.Arg(1),
|
||||
})
|
||||
case "LICENSE":
|
||||
f.License = p.Arg(0)
|
||||
}
|
||||
}
|
||||
return f, nil
|
||||
}
|
86
x/oweb/oweb.go
Normal file
86
x/oweb/oweb.go
Normal file
@ -0,0 +1,86 @@
|
||||
package oweb
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"bllamo.com/client/ollama"
|
||||
)
|
||||
|
||||
func Missing(field string) error {
|
||||
return &ollama.Error{
|
||||
Status: 400,
|
||||
Code: "missing",
|
||||
Field: field,
|
||||
Message: fmt.Sprintf("%s is required", field),
|
||||
}
|
||||
}
|
||||
|
||||
func Mistake(code, field, message string) error {
|
||||
return &ollama.Error{
|
||||
Status: 400,
|
||||
Code: code,
|
||||
Field: field,
|
||||
Message: fmt.Sprintf("%s: %s", field, message),
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience errors
|
||||
var (
|
||||
ErrNotFound = &ollama.Error{Status: 404, Code: "not_found"}
|
||||
ErrInternal = &ollama.Error{Status: 500, Code: "internal_error"}
|
||||
ErrMethodNotAllowed = &ollama.Error{Status: 405, Code: "method_not_allowed"}
|
||||
)
|
||||
|
||||
// HandlerFunc is an HTTP handler that reports failure by returning an
// error instead of writing the error response itself; see Serve.
type HandlerFunc func(w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||
if err := h(w, r); err != nil {
|
||||
// TODO: take a slog.Logger
|
||||
log.Printf("error: %v", err)
|
||||
var oe *ollama.Error
|
||||
if !errors.As(err, &oe) {
|
||||
oe = ErrInternal
|
||||
}
|
||||
oe.Status = cmp.Or(oe.Status, 400)
|
||||
w.WriteHeader(oe.Status)
|
||||
if err := EncodeJSON(w, oe); err != nil {
|
||||
log.Printf("error encoding error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) {
|
||||
v, err := DecodeJSON[T](r)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
}
|
||||
var msg string
|
||||
var e *json.SyntaxError
|
||||
if errors.As(err, &e) {
|
||||
msg = e.Error()
|
||||
}
|
||||
var se *json.UnmarshalTypeError
|
||||
if errors.As(err, &se) {
|
||||
msg = fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type)
|
||||
}
|
||||
return nil, Mistake("invalid_json", field, msg)
|
||||
}
|
||||
|
||||
// DecodeJSON decodes a single JSON value of type T from r.
//
// The returned pointer is never nil: a JSON null decodes to the zero
// value of T, and on error a pointer to a fresh zero value is returned
// alongside the error. Callers may therefore dereference the result
// without a nil check.
func DecodeJSON[T any](r io.Reader) (*T, error) {
	// Decode into a value rather than a *T: decoding JSON null into a
	// pointer sets it to nil without reporting an error.
	var v T
	if err := json.NewDecoder(r).Decode(&v); err != nil {
		// v may be partially populated; do not hand it back.
		return new(T), err
	}
	return &v, nil
}
|
||||
|
||||
// EncodeJSON writes v to w as JSON followed by a newline.
func EncodeJSON(w io.Writer, v any) error {
	enc := json.NewEncoder(w)
	return enc.Encode(v)
}
|
46
x/registry/apitype/apitype.go
Normal file
46
x/registry/apitype/apitype.go
Normal file
@ -0,0 +1,46 @@
|
||||
package apitype
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type Manifest struct {
|
||||
Layers []Layer `json:"layers"`
|
||||
}
|
||||
|
||||
// CompletePart identifies one upload part the client has finished
// sending.
type CompletePart struct {
	// URL contains PartNumber and UploadID from server.
	URL string `json:"url"`
	// ETag is sent back to the server together with URL.
	ETag string `json:"etag"`
}
|
||||
|
||||
// Layer describes one blob referenced by a Manifest.
type Layer struct {
	// Digest identifies the blob.
	Digest string `json:"digest"`
	// MediaType is the layer's media type string.
	MediaType string `json:"mediaType"`
	// Size is the layer size. NOTE(review): presumably bytes; confirm.
	Size int64 `json:"size"`
}
|
||||
|
||||
type PushRequest struct {
|
||||
Ref string `json:"ref"`
|
||||
Manifest json.RawMessage `json:"manifest"`
|
||||
|
||||
// Parts is a list of upload parts that the client upload in the previous
|
||||
// push.
|
||||
Uploaded []CompletePart `json:"part_uploads"`
|
||||
}
|
||||
|
||||
// Requirement describes one chunk of a layer that the client still has
// to upload before the manifest can be committed.
type Requirement struct {
	// Digest identifies the layer the chunk belongs to.
	Digest string `json:"digest"`
	// Offset is where the chunk starts within the layer.
	Offset int64 `json:"offset"`
	// Size is the length of the chunk. The key is lower-case like every
	// other key in this package; encoding/json matches keys
	// case-insensitively when decoding, so a "Size" key is still read.
	Size int64 `json:"size"`

	// URL is the url to PUT the layer to.
	//
	// Clients must include it as the URL, along with the ETag in the
	// response headers from the PUT request, in the next push request
	// in the Uploaded field.
	URL string `json:"url"`
}
|
||||
|
||||
type PushResponse struct {
|
||||
// Requirements is a list of digests that the client needs to push before
|
||||
// repushing the manifest.
|
||||
Requirements []Requirement `json:"requirements,omitempty"`
|
||||
}
|
81
x/registry/client.go
Normal file
81
x/registry/client.go
Normal file
@ -0,0 +1,81 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"bllamo.com/client/ollama"
|
||||
"bllamo.com/registry/apitype"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
BaseURL string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
func (c *Client) oclient() *ollama.Client {
|
||||
return (*ollama.Client)(c)
|
||||
}
|
||||
|
||||
type PushParams struct {
|
||||
Uploaded []apitype.CompletePart
|
||||
}
|
||||
|
||||
// Push pushes a manifest to the server.
|
||||
func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushParams) ([]apitype.Requirement, error) {
|
||||
p = cmp.Or(p, &PushParams{})
|
||||
// TODO(bmizerany): backoff
|
||||
v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
|
||||
Ref: ref,
|
||||
Manifest: manifest,
|
||||
Uploaded: p.Uploaded,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v.Requirements, nil
|
||||
}
|
||||
|
||||
func PushLayer(ctx context.Context, dstURL string, off, size int64, file io.ReaderAt) (etag string, err error) {
|
||||
sr := io.NewSectionReader(file, off, size)
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", dstURL, sr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.ContentLength = size
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
return "", parseS3Error(res)
|
||||
}
|
||||
return res.Header.Get("ETag"), nil
|
||||
}
|
||||
|
||||
// s3Error is the XML <Error> document returned in the body of a failed
// S3 request.
type s3Error struct {
	XMLName   xml.Name `xml:"Error"`
	Code      string   `xml:"Code"`
	Message   string   `xml:"Message"`
	Resource  string   `xml:"Resource"`
	RequestId string   `xml:"RequestId"`
}

// Error implements the error interface, formatting e as
// "S3 (<RequestId>): <Resource>: <Code>: <Message>".
func (e *s3Error) Error() string {
	return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message)
}

// parseS3Error parses an XML error response from S3.
//
// It always returns a non-nil error: an *s3Error when the body decodes,
// otherwise an error that wraps the decoding failure and records the
// HTTP status, so the status is not lost when the body is empty or not
// XML.
func parseS3Error(res *http.Response) error {
	se := new(s3Error)
	if err := xml.NewDecoder(res.Body).Decode(se); err != nil {
		return fmt.Errorf("unexpected status %s: decoding S3 error response: %w", res.Status, err)
	}
	return se
}
|
233
x/registry/server.go
Normal file
233
x/registry/server.go
Normal file
@ -0,0 +1,233 @@
|
||||
// Package registry implements an Ollama registry client and server.
|
||||
package registry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"bllamo.com/build/blob"
|
||||
"bllamo.com/client/ollama"
|
||||
"bllamo.com/oweb"
|
||||
"bllamo.com/registry/apitype"
|
||||
"bllamo.com/utils/upload"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
)
|
||||
|
||||
// Defaults
const (
	// DefaultUploadChunkSize is the upload chunk size used when
	// Server.UploadChunkSize is not set: 50 MiB.
	DefaultUploadChunkSize = 50 << 20
)
|
||||
|
||||
// TODO(bmizerany): move all env things to package envkobs?
|
||||
var defaultLibrary = cmp.Or(os.Getenv("OLLAMA_REGISTRY"), "registry.ollama.ai/library")
|
||||
|
||||
func DefaultLibrary() string {
|
||||
return defaultLibrary
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
UploadChunkSize int64 // default is DefaultUploadChunkSize
|
||||
minioClient *minio.Client
|
||||
}
|
||||
|
||||
func New(mc *minio.Client) *Server {
|
||||
return &Server{minioClient: mc}
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if err := s.serveHTTP(w, r); err != nil {
|
||||
log.Printf("error: %v", err) // TODO(bmizerany): take a slog.Logger
|
||||
var e *ollama.Error
|
||||
if !errors.As(err, &e) {
|
||||
e = oweb.ErrInternal
|
||||
}
|
||||
w.WriteHeader(cmp.Or(e.Status, 400))
|
||||
if err := oweb.EncodeJSON(w, e); err != nil {
|
||||
log.Printf("error encoding error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
switch r.URL.Path {
|
||||
case "/v1/push":
|
||||
return s.handlePush(w, r)
|
||||
case "/v1/pull":
|
||||
return s.handlePull(w, r)
|
||||
default:
|
||||
return oweb.ErrNotFound
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) uploadChunkSize() int64 {
|
||||
return cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize)
|
||||
}
|
||||
|
||||
func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
|
||||
const bucketTODO = "test"
|
||||
|
||||
pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ref := blob.ParseRef(pr.Ref)
|
||||
if !ref.Complete() {
|
||||
return oweb.Mistake("invalid", "name", "must be complete")
|
||||
}
|
||||
|
||||
m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mcc := &minio.Core{Client: s.mc()}
|
||||
// TODO(bmizerany): complete uploads before stats for any with ETag
|
||||
|
||||
type completeParts struct {
|
||||
key string
|
||||
parts []minio.CompletePart
|
||||
}
|
||||
|
||||
completePartsByUploadID := make(map[string]completeParts)
|
||||
for _, pu := range pr.Uploaded {
|
||||
// parse the URL
|
||||
u, err := url.Parse(pu.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
q := u.Query()
|
||||
uploadID := q.Get("UploadId")
|
||||
if uploadID == "" {
|
||||
return oweb.Mistake("invalid", "url", "missing UploadId")
|
||||
}
|
||||
partNumber, err := strconv.Atoi(q.Get("PartNumber"))
|
||||
if err != nil {
|
||||
return oweb.Mistake("invalid", "url", "invalid or missing PartNumber")
|
||||
}
|
||||
etag := pu.ETag
|
||||
if etag == "" {
|
||||
return oweb.Mistake("invalid", "etag", "missing")
|
||||
}
|
||||
cp, ok := completePartsByUploadID[uploadID]
|
||||
if !ok {
|
||||
cp = completeParts{key: u.Path}
|
||||
completePartsByUploadID[uploadID] = cp
|
||||
}
|
||||
cp.parts = append(cp.parts, minio.CompletePart{
|
||||
PartNumber: partNumber,
|
||||
ETag: etag,
|
||||
})
|
||||
fmt.Println("uploadID", uploadID, "partNumber", partNumber, "etag", etag)
|
||||
completePartsByUploadID[uploadID] = cp
|
||||
}
|
||||
|
||||
for uploadID, cp := range completePartsByUploadID {
|
||||
var zeroOpts minio.PutObjectOptions
|
||||
_, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, cp.key, uploadID, cp.parts, zeroOpts)
|
||||
if err != nil {
|
||||
// log and continue; put backpressure on the client
|
||||
log.Printf("error completing upload: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var requirements []apitype.Requirement
|
||||
for _, l := range m.Layers {
|
||||
// TODO(bmizerany): do in parallel
|
||||
if l.Size == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(bmizerany): "global" throttle of rate of transfer
|
||||
pushed, err := s.statObject(r.Context(), l.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !pushed {
|
||||
key := path.Join("blobs", l.Digest)
|
||||
uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) {
|
||||
const timeToStartUpload = 15 * time.Minute
|
||||
|
||||
signedURL, err := s.mc().Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{
|
||||
"UploadId": []string{uploadID},
|
||||
"PartNumber": []string{strconv.Itoa(partNumber)},
|
||||
"ContentLength": []string{strconv.FormatInt(c.Size, 10)},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requirements = append(requirements, apitype.Requirement{
|
||||
Digest: l.Digest,
|
||||
Offset: c.Offset,
|
||||
Size: c.Size,
|
||||
URL: signedURL.String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(requirements) == 0 {
|
||||
// Commit the manifest
|
||||
body := bytes.NewReader(pr.Manifest)
|
||||
path := path.Join("manifests", path.Join(ref.Parts()...))
|
||||
_, err := s.mc().PutObject(r.Context(), bucketTODO, path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return oweb.EncodeJSON(w, &apitype.PushResponse{Requirements: requirements})
|
||||
}
|
||||
|
||||
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
// lookup manifest
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, err error) {
|
||||
// HEAD the object
|
||||
path := path.Join("blobs", digest)
|
||||
_, err = s.mc().StatObject(ctx, "test", path, minio.StatObjectOptions{})
|
||||
if err != nil {
|
||||
if isNoSuchKey(err) {
|
||||
err = nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isNoSuchKey(err error) bool {
|
||||
var e minio.ErrorResponse
|
||||
return errors.As(err, &e) && e.Code == "NoSuchKey"
|
||||
}
|
||||
|
||||
func (s *Server) mc() *minio.Client {
|
||||
if s.minioClient != nil {
|
||||
return s.minioClient
|
||||
}
|
||||
mc, err := minio.New("localhost:9000", &minio.Options{
|
||||
Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
|
||||
Secure: false,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return mc
|
||||
}
|
253
x/registry/server_test.go
Normal file
253
x/registry/server_test.go
Normal file
@ -0,0 +1,253 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"bllamo.com/registry/apitype"
|
||||
"bllamo.com/utils/backoff"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
// abc is the source data for test layers: a layer of size n holds its
// first n bytes.
const abc = "abcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
func testPush(t *testing.T, chunkSize int64) {
|
||||
t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) {
|
||||
mc := startMinio(t, false)
|
||||
|
||||
manifest := []byte(`{
|
||||
"layers": [
|
||||
{"digest": "sha256-1", "size": 1},
|
||||
{"digest": "sha256-2", "size": 2},
|
||||
{"digest": "sha256-3", "size": 3}
|
||||
]
|
||||
}`)
|
||||
|
||||
const ref = "registry.ollama.ai/x/y:latest+Z"
|
||||
|
||||
hs := httptest.NewServer(&Server{
|
||||
minioClient: mc,
|
||||
UploadChunkSize: chunkSize,
|
||||
})
|
||||
t.Cleanup(hs.Close)
|
||||
c := &Client{BaseURL: hs.URL}
|
||||
|
||||
requirements, err := c.Push(context.Background(), ref, manifest, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(requirements) < 3 {
|
||||
t.Fatalf("expected at least 3 requirements; got %d", len(requirements))
|
||||
t.Logf("requirements: %v", requirements)
|
||||
}
|
||||
|
||||
var uploaded []apitype.CompletePart
|
||||
for i, r := range requirements {
|
||||
t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
|
||||
|
||||
body := strings.NewReader(abc)
|
||||
etag, err := PushLayer(context.Background(), r.URL, r.Offset, r.Size, body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
uploaded = append(uploaded, apitype.CompletePart{
|
||||
URL: r.URL,
|
||||
ETag: etag,
|
||||
})
|
||||
}
|
||||
|
||||
requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
|
||||
Uploaded: uploaded,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(requirements) != 0 {
|
||||
t.Fatalf("unexpected requirements: %v", requirements)
|
||||
}
|
||||
|
||||
var paths []string
|
||||
keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
|
||||
Recursive: true,
|
||||
})
|
||||
for k := range keys {
|
||||
paths = append(paths, k.Key)
|
||||
}
|
||||
|
||||
t.Logf("paths: %v", paths)
|
||||
|
||||
diff.Test(t, t.Errorf, paths, []string{
|
||||
"blobs/sha256-1",
|
||||
"blobs/sha256-2",
|
||||
"blobs/sha256-3",
|
||||
"manifests/registry.ollama.ai/x/y/latest/Z",
|
||||
})
|
||||
|
||||
obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer obj.Close()
|
||||
|
||||
var gotM apitype.Manifest
|
||||
if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, gotM, apitype.Manifest{
|
||||
Layers: []apitype.Layer{
|
||||
{Digest: "sha256-1", Size: 1},
|
||||
{Digest: "sha256-2", Size: 2},
|
||||
{Digest: "sha256-3", Size: 3},
|
||||
},
|
||||
})
|
||||
|
||||
// checksum the blobs
|
||||
for i, l := range gotM.Layers {
|
||||
obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer obj.Close()
|
||||
|
||||
info, err := obj.Stat()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
|
||||
|
||||
data, err := io.ReadAll(obj)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := string(data)
|
||||
want := abc[:l.Size]
|
||||
if got != want {
|
||||
t.Errorf("[%d] got layer data = %q; want %q", i, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPush(t *testing.T) {
|
||||
testPush(t, 0)
|
||||
testPush(t, 1)
|
||||
}
|
||||
|
||||
// availableAddr returns a localhost address with a port that was free at
// the time of the call. It finds one by listening on port 0 and closing
// the listener again, so the port is not reserved for the caller. It
// panics if listening fails.
func availableAddr() string {
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		panic(err)
	}
	addr := ln.Addr().String()
	ln.Close()
	return addr
}
|
||||
|
||||
func startMinio(t *testing.T, debug bool) *minio.Client {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
t.Logf(">> minio data dir: %s", dir)
|
||||
addr := availableAddr()
|
||||
cmd := exec.Command("minio", "server", "--address", addr, dir)
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
if debug {
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
doneLogging := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
<-doneLogging
|
||||
})
|
||||
go func() {
|
||||
defer close(doneLogging)
|
||||
sc := bufio.NewScanner(stdout)
|
||||
for sc.Scan() {
|
||||
t.Logf("minio: %s", sc.Text())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// TODO(bmizerany): wait delay etc...
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cmd.Process.Kill()
|
||||
if err := cmd.Wait(); err != nil {
|
||||
var e *exec.ExitError
|
||||
if errors.As(err, &e) && e.Exited() {
|
||||
t.Logf("minio stderr: %s", e.Stderr)
|
||||
t.Logf("minio exit status: %v", e.ExitCode())
|
||||
t.Logf("minio exited: %v", e.Exited())
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
mc, err := minio.New(addr, &minio.Options{
|
||||
Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
|
||||
Secure: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
deadline, ok := t.Deadline()
|
||||
if ok {
|
||||
ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond))
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// wait for server to start with exponential backoff
|
||||
for _, err := range backoff.Upto(ctx, 1*time.Second) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mc.IsOnline() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return mc
|
||||
}
|
||||
|
||||
// contextForTest returns a context that expires 100ms before the test
// deadline, if the test has one; otherwise it returns
// context.Background() and a no-op doneLogging.
//
// When there is a deadline, a cleanup is registered that cancels the
// context and then blocks until doneLogging has been called, so callers
// must call doneLogging (exactly once) after their last Log/Error/Fatalf
// call and before the test returns.
func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
	deadline, hasDeadline := t.Deadline()
	if !hasDeadline {
		return context.Background(), func() {}
	}

	logged := make(chan struct{})
	ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond))
	t.Cleanup(func() {
		cancel()
		<-logged
	})
	return ctx, func() { close(logged) }
}
|
4
x/types/empty/message.go
Normal file
4
x/types/empty/message.go
Normal file
@ -0,0 +1,4 @@
|
||||
package empty
|
||||
|
||||
// Message is a placeholder type used when encoding json messages. It has
// no fields, so it encodes as the empty JSON object.
type Message struct{}
|
15
x/types/structs/structs.go
Normal file
15
x/types/structs/structs.go
Normal file
@ -0,0 +1,15 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package structs contains the Incomparable type.
|
||||
package structs
|
||||
|
||||
// Incomparable is a zero-width incomparable type. If added as the
// first field in a struct, it marks that struct as not comparable
// (can't do == or be a map key) and usually doesn't add any width to
// the struct (unless the struct has only small fields).
//
// By making a struct incomparable, you can prevent misuse (prevent
// people from using ==), but also you can shrink generated binaries,
// as the compiler can omit equality funcs from the binary.
//
// It works because func values cannot be compared, and an array is
// comparable only if its element type is.
type Incomparable [0]func()
|
12
x/types/they/want.go
Normal file
12
x/types/they/want.go
Normal file
@ -0,0 +1,12 @@
|
||||
package they
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Want reports whether r uses the given method and has a URL path that
// begins with pathPrefix. The path is only inspected when the method
// matches.
func Want(r *http.Request, method string, pathPrefix string) bool {
	if r.Method != method {
		return false
	}
	return strings.HasPrefix(r.URL.Path, pathPrefix)
}
|
58
x/utils/backoff/backoff.go
Normal file
58
x/utils/backoff/backoff.go
Normal file
@ -0,0 +1,58 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"iter"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Errors

// ErrMaxAttempts is not used by backoff but is available for use by
// callers that want to signal that a maximum number of retries has
// been exceeded. This should eliminate the need for callers to invent
// their own error.
var ErrMaxAttempts = errors.New("max retries exceeded")
|
||||
|
||||
// Upto implements a backoff strategy that yields nil errors until the
|
||||
// context is canceled, the maxRetries is exceeded, or yield returns false.
|
||||
//
|
||||
// The backoff strategy is a simple exponential backoff with a maximum
|
||||
// backoff of maxBackoff. The backoff is randomized between 0.5-1.5 times
|
||||
// the current backoff, in order to prevent accidental "thundering herd"
|
||||
// problems.
|
||||
func Upto(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
|
||||
var n int
|
||||
return func(yield func(int, error) bool) {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
yield(n, ctx.Err())
|
||||
return
|
||||
}
|
||||
|
||||
n++
|
||||
|
||||
// n^2 backoff timer is a little smoother than the
|
||||
// common choice of 2^n.
|
||||
d := time.Duration(n*n) * 10 * time.Millisecond
|
||||
if d > maxBackoff {
|
||||
d = maxBackoff
|
||||
}
|
||||
// Randomize the delay between 0.5-1.5 x msec, in order
|
||||
// to prevent accidental "thundering herd" problems.
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
t := time.NewTimer(d)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
case <-t.C:
|
||||
if !yield(n, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
27
x/utils/upload/upload.go
Normal file
27
x/utils/upload/upload.go
Normal file
@ -0,0 +1,27 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"iter"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type Chunk[I constraints.Integer] struct {
|
||||
Offset I
|
||||
Size I
|
||||
}
|
||||
|
||||
// Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset
|
||||
// and size of the chunk. The last chunk may be smaller than chunkSize if size is
|
||||
// not a multiple of chunkSize.
|
||||
//
|
||||
// The first part number is 1 and increases monotonically.
|
||||
func Chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, Chunk[I]] {
|
||||
return func(yield func(int, Chunk[I]) bool) {
|
||||
var n int
|
||||
for off := I(0); off < size; off += chunkSize {
|
||||
n++
|
||||
yield(n, Chunk[I]{off, min(chunkSize, size-off)})
|
||||
}
|
||||
}
|
||||
}
|
37
x/utils/upload/upload_test.go
Normal file
37
x/utils/upload/upload_test.go
Normal file
@ -0,0 +1,37 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
func TestChunks(t *testing.T) {
|
||||
const size = 101
|
||||
const chunkSize = 10
|
||||
var got []Chunk[int]
|
||||
var lastN int
|
||||
for n, c := range Chunks(size, chunkSize) {
|
||||
if n != lastN+1 {
|
||||
t.Errorf("n = %d; want %d", n, lastN+1)
|
||||
}
|
||||
got = append(got, c)
|
||||
lastN = n
|
||||
}
|
||||
|
||||
want := []Chunk[int]{
|
||||
{0, 10},
|
||||
{10, 10},
|
||||
{20, 10},
|
||||
{30, 10},
|
||||
{40, 10},
|
||||
{50, 10},
|
||||
{60, 10},
|
||||
{70, 10},
|
||||
{80, 10},
|
||||
{90, 10},
|
||||
{100, 1},
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, got, want)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user