registry: initial work on multipart pushes

This commit is contained in:
Blake Mizerany 2024-04-03 10:39:30 -07:00
parent 94befe366a
commit a10a11b9d3
6 changed files with 378 additions and 123 deletions

View File

@ -10,6 +10,7 @@ import (
"bllamo.com/client/ollama/apitype"
"bllamo.com/oweb"
"bllamo.com/registry"
regtype "bllamo.com/registry/apitype"
)
// Common API Errors
@ -64,11 +65,12 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
}
c := registry.Client{BaseURL: registryURLTODO}
requirements, err := c.Push(r.Context(), params.Name, man)
requirements, err := c.Push(r.Context(), params.Name, man, nil)
if err != nil {
return err
}
var uploads []regtype.CompletePart
for _, rq := range requirements {
l, err := s.Build.LayerFile(rq.Digest)
if err != nil {
@ -80,7 +82,15 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
return err
}
defer f.Close()
return registry.PushLayer(r.Context(), rq.URL, rq.Size, f)
etag, err := registry.PushLayer(r.Context(), rq.URL, rq.Offset, rq.Size, f)
if err != nil {
return err
}
uploads = append(uploads, regtype.CompletePart{
URL: rq.URL,
ETag: etag,
})
return nil
}()
if err != nil {
return err
@ -88,7 +98,9 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
}
// commit the manifest to the registry
requirements, err = c.Push(r.Context(), params.Name, man)
requirements, err = c.Push(r.Context(), params.Name, man, &registry.PushParams{
Uploaded: uploads,
})
if err != nil {
return err
}

View File

@ -6,6 +6,11 @@ type Manifest struct {
Layers []Layer `json:"layers"`
}
// CompletePart records one successfully uploaded part of a multipart
// push: the presigned URL the part was PUT to and the ETag the storage
// backend returned for it.
type CompletePart struct {
	URL  string `json:"url"`  // contains PartNumber and UploadID from server
	ETag string `json:"etag"` // ETag from the PUT response headers
}
type Layer struct {
Digest string `json:"digest"`
MediaType string `json:"mediaType"`
@ -13,15 +18,25 @@ type Layer struct {
}
type PushRequest struct {
Ref string `json:"ref"`
Manifest json.RawMessage
Ref string `json:"ref"`
Manifest json.RawMessage `json:"manifest"`
	// Uploaded is the list of parts the client uploaded in the previous
	// push.
Uploaded []CompletePart `json:"part_uploads"`
}
type Requirement struct {
Digest string `json:"digest"`
Offset int64 `json:"offset"`
Size int64 `json:"Size"`
URL string `json:"url"`
// URL is the url to PUT the layer to.
//
	// Clients must include it as the URL, along with the ETag from the
	// PUT response headers, in the Uploaded field of the next push
	// request.
URL string `json:"url"`
}
type PushResponse struct {

View File

@ -1,7 +1,10 @@
package registry
import (
"cmp"
"context"
"encoding/xml"
"fmt"
"io"
"net/http"
@ -18,12 +21,18 @@ func (c *Client) oclient() *ollama.Client {
return (*ollama.Client)(c)
}
// PushParams carries optional parameters for Client.Push.
type PushParams struct {
	// Uploaded lists parts completed during a previous push so the
	// server can finish the corresponding multipart uploads before
	// recomputing the remaining requirements.
	Uploaded []apitype.CompletePart
}
// Push pushes a manifest to the server.
func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]apitype.Requirement, error) {
func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushParams) ([]apitype.Requirement, error) {
p = cmp.Or(p, &PushParams{})
// TODO(bmizerany): backoff
v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
Ref: ref,
Manifest: manifest,
Uploaded: p.Uploaded,
})
if err != nil {
return nil, err
@ -31,26 +40,42 @@ func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]apity
return v.Requirements, nil
}
func PushLayer(ctx context.Context, dstURL string, size int64, file io.Reader) error {
req, err := http.NewRequest("PUT", dstURL, file)
func PushLayer(ctx context.Context, dstURL string, off, size int64, file io.ReaderAt) (etag string, err error) {
sr := io.NewSectionReader(file, off, size)
req, err := http.NewRequestWithContext(ctx, "PUT", dstURL, sr)
if err != nil {
return err
return "", err
}
req.ContentLength = size
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
return "", err
}
defer res.Body.Close()
if res.StatusCode != 200 {
e := &ollama.Error{Status: res.StatusCode}
msg, err := io.ReadAll(res.Body)
if err != nil {
return err
}
// TODO(bmizerany): format error message
e.Message = string(msg)
return "", parseS3Error(res)
}
return nil
return res.Header.Get("ETag"), nil
}
type s3Error struct {
XMLName xml.Name `xml:"Error"`
Code string `xml:"Code"`
Message string `xml:"Message"`
Resource string `xml:"Resource"`
RequestId string `xml:"RequestId"`
}
func (e *s3Error) Error() string {
return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message)
}
// parseS3Error parses an XML error response from S3.
func parseS3Error(res *http.Response) error {
var se *s3Error
if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
return err
}
return se
}

View File

@ -6,8 +6,8 @@ import (
"cmp"
"context"
"errors"
"fmt"
"log"
"math/rand"
"net/http"
"net/url"
"os"
@ -70,7 +70,13 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
}
}
// uploadChunkSize reports the chunk size to use for multipart uploads,
// falling back to DefaultUploadChunkSize when none is configured.
func (s *Server) uploadChunkSize() int64 {
	if s.UploadChunkSize != 0 {
		return s.UploadChunkSize
	}
	return DefaultUploadChunkSize
}
func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
const bucketTODO = "test"
pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body)
if err != nil {
return err
@ -78,7 +84,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
ref := blob.ParseRef(pr.Ref)
if !ref.Complete() {
return oweb.Mistake("invalid", "name", "must be fully qualified")
return oweb.Mistake("invalid", "name", "must be complete")
}
m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest))
@ -86,28 +92,80 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
return err
}
// TODO(bmizerany): parallelize
mcc := &minio.Core{Client: s.mc()}
// TODO(bmizerany): complete uploads before stats for any with ETag
type completeParts struct {
key string
parts []minio.CompletePart
}
completePartsByUploadID := make(map[string]completeParts)
for _, pu := range pr.Uploaded {
// parse the URL
u, err := url.Parse(pu.URL)
if err != nil {
return err
}
q := u.Query()
uploadID := q.Get("UploadId")
if uploadID == "" {
return oweb.Mistake("invalid", "url", "missing UploadId")
}
partNumber, err := strconv.Atoi(q.Get("PartNumber"))
if err != nil {
return oweb.Mistake("invalid", "url", "invalid or missing PartNumber")
}
etag := pu.ETag
if etag == "" {
return oweb.Mistake("invalid", "etag", "missing")
}
cp, ok := completePartsByUploadID[uploadID]
if !ok {
cp = completeParts{key: u.Path}
completePartsByUploadID[uploadID] = cp
}
cp.parts = append(cp.parts, minio.CompletePart{
PartNumber: partNumber,
ETag: etag,
})
fmt.Println("uploadID", uploadID, "partNumber", partNumber, "etag", etag)
completePartsByUploadID[uploadID] = cp
}
for uploadID, cp := range completePartsByUploadID {
var zeroOpts minio.PutObjectOptions
_, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, cp.key, uploadID, cp.parts, zeroOpts)
if err != nil {
// log and continue; put backpressure on the client
log.Printf("error completing upload: %v", err)
}
}
var requirements []apitype.Requirement
for _, l := range m.Layers {
// TODO(bmizerany): do in parallel
if l.Size == 0 {
continue
}
// TODO(bmizerany): "global" throttle of rate of transfer
pushed, err := s.statObject(r.Context(), l.Digest)
if err != nil {
return err
}
if !pushed {
uploadID := generateUploadID()
for n, c := range upload.Chunks(l.Size, cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize)) {
const expires = 15 * time.Minute
key := path.Join("blobs", l.Digest)
uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{})
if err != nil {
return err
}
for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) {
const timeToStartUpload = 15 * time.Minute
key := path.Join("blobs", l.Digest)
signedURL, err := s.mc().Presign(r.Context(), "PUT", "test", key, expires, url.Values{
signedURL, err := s.mc().Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{
"UploadId": []string{uploadID},
"PartNumber": []string{strconv.Itoa(n)},
"PartNumber": []string{strconv.Itoa(partNumber)},
"ContentLength": []string{strconv.FormatInt(c.Size, 10)},
})
if err != nil {
@ -118,9 +176,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
Digest: l.Digest,
Offset: c.Offset,
Size: c.Size,
// TODO(bmizerany): use signed+temp urls
URL: signedURL.String(),
URL: signedURL.String(),
})
}
}
@ -130,7 +186,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
// Commit the manifest
body := bytes.NewReader(pr.Manifest)
path := path.Join("manifests", path.Join(ref.Parts()...))
_, err := s.mc().PutObject(r.Context(), "test", path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{})
_, err := s.mc().PutObject(r.Context(), bucketTODO, path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{})
if err != nil {
return err
}
@ -175,12 +231,3 @@ func (s *Server) mc() *minio.Client {
}
return mc
}
// generateUploadID returns a random 32-character lowercase hex string.
func generateUploadID() string {
	const digits = "0123456789abcdef"
	id := make([]byte, 32)
	for i := range id {
		id[i] = digits[rand.Intn(len(digits))]
	}
	return string(id)
}

View File

@ -1,118 +1,189 @@
package registry
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http/httptest"
"os"
"os/exec"
"strings"
"testing"
"time"
"bllamo.com/registry/apitype"
"github.com/kr/pretty"
"bllamo.com/utils/backoff"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"kr.dev/diff"
)
func TestPush(t *testing.T) {
startMinio(t)
const abc = "abcdefghijklmnopqrstuvwxyz"
s := &Server{}
hs := httptest.NewServer(s)
t.Cleanup(hs.Close)
c := &Client{BaseURL: hs.URL}
func testPush(t *testing.T, chunkSize int64) {
t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) {
mc := startMinio(t, false)
manifest := []byte(`{
"layers": [
{"digest": "sha256-1", "size": 1},
{"digest": "sha256-2", "size": 2},
{"digest": "sha256-3", "size": 3}
]
}`)
manifest := []byte(`{
"layers": [
{"digest": "sha256-1", "size": 1},
{"digest": "sha256-2", "size": 2},
{"digest": "sha256-3", "size": 3}
]
}`)
const ref = "registry.ollama.ai/x/y:latest+Z"
const ref = "registry.ollama.ai/x/y:latest+Z"
got, err := c.Push(context.Background(), ref, manifest)
if err != nil {
t.Fatal(err)
}
hs := httptest.NewServer(&Server{
minioClient: mc,
UploadChunkSize: chunkSize,
})
t.Cleanup(hs.Close)
c := &Client{BaseURL: hs.URL}
diff.Test(t, t.Errorf, got, []apitype.Requirement{
{Digest: "sha256-1", Size: 1},
{Digest: "sha256-2", Size: 2},
{Digest: "sha256-3", Size: 3},
}, diff.ZeroFields[apitype.Requirement]("URL"))
for _, r := range got {
body := io.Reader(strings.NewReader(strings.Repeat("x", int(r.Size))))
if err := PushLayer(context.Background(), r.URL, r.Size, body); err != nil {
requirements, err := c.Push(context.Background(), ref, manifest, nil)
if err != nil {
t.Fatal(err)
}
}
got, err = c.Push(context.Background(), ref, manifest)
if err != nil {
t.Fatal(err)
}
if len(requirements) < 3 {
t.Fatalf("expected at least 3 requirements; got %d", len(requirements))
t.Logf("requirements: %v", requirements)
}
if len(got) != 0 {
t.Fatalf("unexpected requirements: % #v", pretty.Formatter(got))
}
var uploaded []apitype.CompletePart
for i, r := range requirements {
t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
mc, err := minio.New("localhost:9000", &minio.Options{
Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
Secure: false,
})
if err != nil {
t.Fatal(err)
}
body := strings.NewReader(abc)
etag, err := PushLayer(context.Background(), r.URL, r.Offset, r.Size, body)
if err != nil {
t.Fatal(err)
}
uploaded = append(uploaded, apitype.CompletePart{
URL: r.URL,
ETag: etag,
})
}
var paths []string
keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
Recursive: true,
})
for k := range keys {
paths = append(paths, k.Key)
}
requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
Uploaded: uploaded,
})
if err != nil {
t.Fatal(err)
}
if len(requirements) != 0 {
t.Fatalf("unexpected requirements: %v", requirements)
}
t.Logf("paths: %v", paths)
var paths []string
keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
Recursive: true,
})
for k := range keys {
paths = append(paths, k.Key)
}
diff.Test(t, t.Errorf, paths, []string{
"blobs/sha256-1",
"blobs/sha256-2",
"blobs/sha256-3",
"manifests/registry.ollama.ai/x/y/latest/Z",
})
t.Logf("paths: %v", paths)
obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
if err != nil {
t.Fatal(err)
}
defer obj.Close()
diff.Test(t, t.Errorf, paths, []string{
"blobs/sha256-1",
"blobs/sha256-2",
"blobs/sha256-3",
"manifests/registry.ollama.ai/x/y/latest/Z",
})
var gotM apitype.Manifest
if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
t.Fatal(err)
}
obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
if err != nil {
t.Fatal(err)
}
defer obj.Close()
diff.Test(t, t.Errorf, gotM, apitype.Manifest{
Layers: []apitype.Layer{
{Digest: "sha256-1", Size: 1},
{Digest: "sha256-2", Size: 2},
{Digest: "sha256-3", Size: 3},
},
var gotM apitype.Manifest
if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
t.Fatal(err)
}
diff.Test(t, t.Errorf, gotM, apitype.Manifest{
Layers: []apitype.Layer{
{Digest: "sha256-1", Size: 1},
{Digest: "sha256-2", Size: 2},
{Digest: "sha256-3", Size: 3},
},
})
// checksum the blobs
for i, l := range gotM.Layers {
obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
if err != nil {
t.Fatal(err)
}
defer obj.Close()
info, err := obj.Stat()
if err != nil {
t.Fatal(err)
}
t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
data, err := io.ReadAll(obj)
if err != nil {
t.Fatal(err)
}
got := string(data)
want := abc[:l.Size]
if got != want {
t.Errorf("[%d] got layer data = %q; want %q", i, got, want)
}
}
})
}
func startMinio(t *testing.T) {
func TestPush(t *testing.T) {
testPush(t, 0)
testPush(t, 1)
}
func availableAddr() string {
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
panic(err)
}
defer l.Close()
return l.Addr().String()
}
func startMinio(t *testing.T, debug bool) *minio.Client {
t.Helper()
dir := t.TempDir()
cmd := exec.Command("minio", "server", "--address", "localhost:9000", dir)
t.Logf(">> minio data dir: %s", dir)
addr := availableAddr()
cmd := exec.Command("minio", "server", "--address", addr, dir)
cmd.Env = os.Environ()
if debug {
stdout, err := cmd.StdoutPipe()
if err != nil {
t.Fatal(err)
}
doneLogging := make(chan struct{})
t.Cleanup(func() {
<-doneLogging
})
go func() {
defer close(doneLogging)
sc := bufio.NewScanner(stdout)
for sc.Scan() {
t.Logf("minio: %s", sc.Text())
}
}()
}
// TODO(bmizerany): wait delay etc...
if err := cmd.Start(); err != nil {
@ -131,7 +202,7 @@ func startMinio(t *testing.T) {
}
})
mc, err := minio.New("localhost:9000", &minio.Options{
mc, err := minio.New(addr, &minio.Options{
Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
Secure: false,
})
@ -139,17 +210,44 @@ func startMinio(t *testing.T) {
t.Fatal(err)
}
// wait for server to start
// TODO(bmizerany): use backoff
for {
_, err := mc.ListBuckets(context.Background())
if err == nil {
ctx, cancel := context.WithCancel(context.Background())
deadline, ok := t.Deadline()
if ok {
ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond))
defer cancel()
}
// wait for server to start with exponential backoff
for _, err := range backoff.Upto(ctx, 1*time.Second) {
if err != nil {
t.Fatal(err)
}
if mc.IsOnline() {
break
}
time.Sleep(100 * time.Millisecond)
}
if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
t.Fatal(err)
}
return mc
}
// contextForTest returns a context that is canceled shortly before the
// test deadline, if any, is reached. The returned doneLogging function
// should be called after all Log/Error/Fatalf calls are done before the
// test returns.
func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
	deadline, hasDeadline := t.Deadline()
	if !hasDeadline {
		// No deadline: nothing to cancel, nothing to wait for.
		return context.Background(), func() {}
	}
	finished := make(chan struct{})
	// Cancel 100ms before the real deadline so cleanup can still log.
	ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond))
	t.Cleanup(func() {
		cancel()
		<-finished
	})
	return ctx, func() { close(finished) }
}

58
utils/backoff/backoff.go Normal file
View File

@ -0,0 +1,58 @@
package backoff
import (
"context"
"errors"
"iter"
"math/rand"
"time"
)
// Errors
var (
	// ErrMaxAttempts is not returned by this package itself; it is a
	// ready-made sentinel for callers that want to signal that their
	// maximum number of attempts has been exceeded, so they need not
	// invent their own error value.
	ErrMaxAttempts = errors.New("max retries exceeded")
)
// Upto implements a backoff strategy that yields nil errors until the
// context is canceled, the maxRetries is exceeded, or yield returns false.
//
// The backoff strategy is a simple exponential backoff with a maximum
// backoff of maxBackoff. The backoff is randomized between 0.5-1.5 times
// the current backoff, in order to prevent accidental "thundering herd"
// problems.
// Upto implements a backoff strategy that yields nil errors until the
// context is canceled or yield returns false. When the context is
// canceled, the context's error is yielded once and the sequence ends.
//
// The backoff grows polynomially (n^2 * 10ms) and is clamped to
// maxBackoff before jitter is applied; each delay is then randomized
// between 0.5x and 1.5x, so an individual sleep may exceed maxBackoff
// by up to 50%. The jitter prevents accidental "thundering herd"
// problems.
func Upto(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
	var n int
	return func(yield func(int, error) bool) {
		for {
			// Surface cancellation as the final yielded value.
			if ctx.Err() != nil {
				yield(n, ctx.Err())
				return
			}
			n++
			// n^2 backoff timer is a little smoother than the
			// common choice of 2^n.
			d := time.Duration(n*n) * 10 * time.Millisecond
			if d > maxBackoff {
				d = maxBackoff
			}
			// Randomize the delay between 0.5-1.5 x msec, in order
			// to prevent accidental "thundering herd" problems.
			d = time.Duration(float64(d) * (rand.Float64() + 0.5))
			t := time.NewTimer(d)
			select {
			case <-ctx.Done():
				t.Stop()
			case <-t.C:
				if !yield(n, nil) {
					return
				}
			}
		}
	}
}