ollama/x/registry/server.go
2024-04-03 15:58:04 -07:00

248 lines
6.0 KiB
Go

// Package registry implements an Ollama registry client and server.
package registry
import (
	"bytes"
	"cmp"
	"context"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"os"
	"path"
	"slices"
	"strconv"
	"time"

	"bllamo.com/build/blob"
	"bllamo.com/client/ollama"
	"bllamo.com/oweb"
	"bllamo.com/registry/apitype"
	"bllamo.com/utils/upload"
	"github.com/minio/minio-go/v7"
	"github.com/minio/minio-go/v7/pkg/credentials"
)
// DefaultUploadChunkSize is the chunk size, in bytes (50 MiB), that a Server
// uses for multipart uploads when its UploadChunkSize field is zero.
const DefaultUploadChunkSize = 50 << 20
// TODO(bmizerany): move all env things to package envkobs?
var defaultLibrary = cmp.Or(os.Getenv("OLLAMA_REGISTRY"), "registry.ollama.ai/library")
func DefaultLibrary() string {
return defaultLibrary
}
// Server is an http.Handler (see ServeHTTP) that serves the registry
// endpoints /v1/push and /v1/pull, using a MinIO client for object storage.
type Server struct {
// UploadChunkSize is the chunk size, in bytes, used when splitting a layer
// into multipart upload parts. A zero value selects DefaultUploadChunkSize.
UploadChunkSize int64 // default is DefaultUploadChunkSize
// minioClient is the object-storage client set by New. When it is nil, the
// mc method falls back to building a client for a local MinIO instance.
minioClient *minio.Client
}
// New returns a Server that uses mc for object storage. UploadChunkSize is
// left at its zero value, so DefaultUploadChunkSize applies until the caller
// sets it. A nil mc is accepted; the server then falls back to the client
// built by its mc method.
func New(mc *minio.Client) *Server {
	srv := new(Server)
	srv.minioClient = mc
	return srv
}
// ServeHTTP implements http.Handler. It delegates to serveHTTP and, when that
// returns an error, logs it and writes it to the client as JSON.
//
// Errors that are not (and do not wrap) an *ollama.Error are reported as
// oweb.ErrInternal. The response status is the error's Status, or 400 when
// Status is zero.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	err := s.serveHTTP(w, r)
	if err == nil {
		return
	}
	log.Printf("error: %v", err) // TODO(bmizerany): take a slog.Logger

	var apiErr *ollama.Error
	if matched := errors.As(err, &apiErr); !matched {
		apiErr = oweb.ErrInternal
	}

	status := apiErr.Status
	if status == 0 {
		status = 400
	}
	w.WriteHeader(status)

	if encErr := oweb.EncodeJSON(w, apiErr); encErr != nil {
		log.Printf("error encoding error: %v", encErr)
	}
}
// serveHTTP routes the request by exact URL path: /v1/push goes to handlePush
// and /v1/pull goes to handlePull. Any other path yields oweb.ErrNotFound.
// The handler's error, if any, is returned for ServeHTTP to report.
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
	route := r.URL.Path
	if route == "/v1/push" {
		return s.handlePush(w, r)
	}
	if route == "/v1/pull" {
		return s.handlePull(w, r)
	}
	return oweb.ErrNotFound
}
// uploadChunkSize returns the configured UploadChunkSize, or
// DefaultUploadChunkSize when the field is zero.
func (s *Server) uploadChunkSize() int64 {
	if size := s.UploadChunkSize; size != 0 {
		return size
	}
	return DefaultUploadChunkSize
}
// handlePush serves /v1/push.
//
// The request body is decoded as an apitype.PushRequest carrying a ref, a
// manifest, and the list of uploads the client has already performed. The
// handler then:
//
//  1. Completes every multipart upload referenced by pr.Uploaded (entries
//     whose URL carries an "uploadId" query parameter). Completion failures
//     are logged, not returned; the affected layer simply shows up again in
//     the requirements below.
//  2. Checks each non-empty manifest layer against the object store and, for
//     every layer not yet present, adds presigned PUT URLs to the response:
//     one URL for layers smaller than the S3 multipart minimum (5 MiB), or
//     one URL per chunk of uploadChunkSize bytes otherwise.
//  3. Stores the manifest under "manifests/<ref parts>" when there is nothing
//     left to upload.
//
// The response is an apitype.PushResponse with the outstanding requirements.
// Invalid client input is returned as an oweb.Mistake; object-store failures
// are returned wrapped with context.
func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
	const bucketTODO = "test"
	const minimumMultipartSize = 5 * 1024 * 1024 // S3 spec
	const timeToStartUpload = 15 * time.Minute   // lifetime of presigned URLs

	pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body)
	if err != nil {
		return err
	}

	ref := blob.ParseRef(pr.Ref)
	if !ref.Complete() {
		return oweb.Mistake("invalid", "name", "must be complete")
	}

	m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest))
	if err != nil {
		return err
	}

	// Resolve the client once; s.mc() may construct a new client per call.
	mc := s.mc()
	mcc := &minio.Core{Client: mc}

	// TODO(bmizerany): complete uploads before stats for any with ETag
	type completeParts struct {
		key   string
		parts []minio.CompletePart
	}

	// Group the uploaded parts reported by the client by multipart upload ID.
	completePartsByUploadID := make(map[string]completeParts)
	for _, pu := range pr.Uploaded {
		u, err := url.Parse(pu.URL)
		if err != nil {
			// The URL comes from the request body, so this is the client's
			// mistake rather than an internal failure.
			return oweb.Mistake("invalid", "url", "must be a valid URL")
		}
		q := u.Query()
		uploadID := q.Get("uploadId")
		if uploadID == "" {
			// not a part upload
			continue
		}
		partNumber, err := strconv.Atoi(q.Get("partNumber"))
		if err != nil {
			return oweb.Mistake("invalid", "url", "invalid or missing PartNumber")
		}
		etag := pu.ETag
		if etag == "" {
			return oweb.Mistake("invalid", "etag", "missing")
		}

		cp, seen := completePartsByUploadID[uploadID]
		if !seen {
			// NOTE(review): u.Path is used verbatim as the object key; confirm
			// it does not include a leading slash or bucket prefix.
			cp.key = u.Path
		}
		cp.parts = append(cp.parts, minio.CompletePart{
			PartNumber: partNumber,
			ETag:       etag,
		})
		completePartsByUploadID[uploadID] = cp
		log.Printf("uploadID %q partNumber %d etag %q", uploadID, partNumber, etag)
	}

	for uploadID, cp := range completePartsByUploadID {
		// S3 requires the part list in ascending part-number order, and the
		// client may report parts in any order.
		slices.SortFunc(cp.parts, func(a, b minio.CompletePart) int {
			return cmp.Compare(a.PartNumber, b.PartNumber)
		})
		var zeroOpts minio.PutObjectOptions
		_, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, cp.key, uploadID, cp.parts, zeroOpts)
		if err != nil {
			// log and continue; put backpressure on the client
			log.Printf("error completing upload: %v", err)
		}
	}

	var requirements []apitype.Requirement
	for _, l := range m.Layers {
		// TODO(bmizerany): do in parallel
		if l.Size == 0 {
			continue
		}

		// TODO(bmizerany): "global" throttle of rate of transfer
		pushed, err := s.statObject(r.Context(), l.Digest)
		if err != nil {
			return fmt.Errorf("checking blob %s: %w", l.Digest, err)
		}
		if pushed {
			continue
		}

		key := path.Join("blobs", l.Digest)
		if l.Size < minimumMultipartSize {
			// single part upload
			signedURL, err := mc.PresignedPutObject(r.Context(), bucketTODO, key, timeToStartUpload)
			if err != nil {
				return fmt.Errorf("presigning upload of %s: %w", key, err)
			}
			requirements = append(requirements, apitype.Requirement{
				Digest: l.Digest,
				Size:   l.Size,
				URL:    signedURL.String(),
			})
			continue
		}

		// multipart upload: one presigned URL per chunk
		uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{})
		if err != nil {
			return fmt.Errorf("starting multipart upload of %s: %w", key, err)
		}
		for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) {
			signedURL, err := mc.Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{
				"uploadId":   []string{uploadID},
				"partNumber": []string{strconv.Itoa(partNumber)},
			})
			if err != nil {
				return fmt.Errorf("presigning part %d of %s: %w", partNumber, key, err)
			}
			requirements = append(requirements, apitype.Requirement{
				Digest: l.Digest,
				Offset: c.Offset,
				Size:   c.N,
				URL:    signedURL.String(),
			})
		}
	}

	if len(requirements) == 0 {
		// Every layer is present: commit the manifest.
		manifestKey := path.Join("manifests", path.Join(ref.Parts()...))
		body := bytes.NewReader(pr.Manifest)
		_, err := mc.PutObject(r.Context(), bucketTODO, manifestKey, body, int64(len(pr.Manifest)), minio.PutObjectOptions{})
		if err != nil {
			return fmt.Errorf("committing manifest %s: %w", manifestKey, err)
		}
	}

	return oweb.EncodeJSON(w, &apitype.PushResponse{Requirements: requirements})
}
// handlePull serves /v1/pull. It is not implemented yet.
//
// It returns a plain error instead of panicking, because the route is
// reachable by any client. The error is not an *ollama.Error, so ServeHTTP
// logs it and reports oweb.ErrInternal to the client.
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
	// TODO: lookup manifest
	return errors.New("registry: pull is not implemented")
}
// statObject reports whether the blob for digest already exists in the object
// store. It stats the object "blobs/<digest>" in the "test" bucket.
//
// A NoSuchKey response means "not pushed" and is returned as (false, nil).
// Any other stat failure is returned as (false, err).
func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, err error) {
	key := path.Join("blobs", digest)
	_, statErr := s.mc().StatObject(ctx, "test", key, minio.StatObjectOptions{})
	switch {
	case statErr == nil:
		return true, nil
	case isNoSuchKey(statErr):
		return false, nil
	default:
		return false, statErr
	}
}
// isNoSuchKey reports whether err is, or wraps, a minio.ErrorResponse whose
// Code is "NoSuchKey".
func isNoSuchKey(err error) bool {
	var resp minio.ErrorResponse
	if !errors.As(err, &resp) {
		return false
	}
	return resp.Code == "NoSuchKey"
}
// mc returns the MinIO client the server should use.
//
// When the server was given a client (see New), that client is returned.
// Otherwise a new client is built for "localhost:9000" with the static
// "minioadmin"/"minioadmin" credentials and Secure set to false. The fallback
// client is not stored on the server, so each call builds a fresh one.
//
// mc panics if the fallback client cannot be constructed.
func (s *Server) mc() *minio.Client {
	if configured := s.minioClient; configured != nil {
		return configured
	}

	opts := &minio.Options{
		Creds:  credentials.NewStaticV4("minioadmin", "minioadmin", ""),
		Secure: false,
	}
	client, err := minio.New("localhost:9000", opts)
	if err != nil {
		panic(err)
	}
	return client
}