From a10a11b9d371f36b7c3510da32a1d70b74e27bd1 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Wed, 3 Apr 2024 10:39:30 -0700 Subject: [PATCH] registry: initial work on multipart pushes --- api/api.go | 18 ++- registry/apitype/apitype.go | 21 ++- registry/client.go | 51 +++++-- registry/server.go | 93 +++++++++---- registry/server_test.go | 260 +++++++++++++++++++++++++----------- utils/backoff/backoff.go | 58 ++++++++ 6 files changed, 378 insertions(+), 123 deletions(-) create mode 100644 utils/backoff/backoff.go diff --git a/api/api.go b/api/api.go index b1955a18..2d7800e4 100644 --- a/api/api.go +++ b/api/api.go @@ -10,6 +10,7 @@ import ( "bllamo.com/client/ollama/apitype" "bllamo.com/oweb" "bllamo.com/registry" + regtype "bllamo.com/registry/apitype" ) // Common API Errors @@ -64,11 +65,12 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { } c := registry.Client{BaseURL: registryURLTODO} - requirements, err := c.Push(r.Context(), params.Name, man) + requirements, err := c.Push(r.Context(), params.Name, man, nil) if err != nil { return err } + var uploads []regtype.CompletePart for _, rq := range requirements { l, err := s.Build.LayerFile(rq.Digest) if err != nil { @@ -80,7 +82,15 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { return err } defer f.Close() - return registry.PushLayer(r.Context(), rq.URL, rq.Size, f) + etag, err := registry.PushLayer(r.Context(), rq.URL, rq.Offset, rq.Size, f) + if err != nil { + return err + } + uploads = append(uploads, regtype.CompletePart{ + URL: rq.URL, + ETag: etag, + }) + return nil }() if err != nil { return err @@ -88,7 +98,9 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { } // commit the manifest to the registry - requirements, err = c.Push(r.Context(), params.Name, man) + requirements, err = c.Push(r.Context(), params.Name, man, ®istry.PushParams{ + Uploaded: uploads, + }) if err != nil { return err } diff --git 
a/registry/apitype/apitype.go b/registry/apitype/apitype.go index 19cda1ab..36f2a342 100644 --- a/registry/apitype/apitype.go +++ b/registry/apitype/apitype.go @@ -6,6 +6,11 @@ type Manifest struct { Layers []Layer `json:"layers"` } +type CompletePart struct { + URL string `json:"url"` // contains PartNumber and UploadID from server + ETag string `json:"etag"` +} + type Layer struct { Digest string `json:"digest"` MediaType string `json:"mediaType"` @@ -13,15 +18,25 @@ type Layer struct { } type PushRequest struct { - Ref string `json:"ref"` - Manifest json.RawMessage + Ref string `json:"ref"` + Manifest json.RawMessage `json:"manifest"` + + // Uploaded is a list of upload parts that the client uploaded in the previous + // push. + Uploaded []CompletePart `json:"part_uploads"` } type Requirement struct { Digest string `json:"digest"` Offset int64 `json:"offset"` Size int64 `json:"Size"` - URL string `json:"url"` + + // URL is the URL to PUT the layer to. + // + // Clients must include it as the URL, along with the ETag in the + // response headers from the PUT request, in the next push request + // in the Uploaded field. + URL string `json:"url"` } type PushResponse struct { diff --git a/registry/client.go b/registry/client.go index 82616380..747dde57 100644 --- a/registry/client.go +++ b/registry/client.go @@ -1,7 +1,10 @@ package registry import ( + "cmp" "context" + "encoding/xml" + "fmt" "io" "net/http" @@ -18,12 +21,18 @@ func (c *Client) oclient() *ollama.Client { return (*ollama.Client)(c) } +type PushParams struct { + Uploaded []apitype.CompletePart +} + // Push pushes a manifest to the server.
-func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]apitype.Requirement, error) { +func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushParams) ([]apitype.Requirement, error) { + p = cmp.Or(p, &PushParams{}) // TODO(bmizerany): backoff v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{ Ref: ref, Manifest: manifest, + Uploaded: p.Uploaded, }) if err != nil { return nil, err @@ -31,26 +40,42 @@ func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]apity return v.Requirements, nil } -func PushLayer(ctx context.Context, dstURL string, size int64, file io.Reader) error { - req, err := http.NewRequest("PUT", dstURL, file) +func PushLayer(ctx context.Context, dstURL string, off, size int64, file io.ReaderAt) (etag string, err error) { + sr := io.NewSectionReader(file, off, size) + req, err := http.NewRequestWithContext(ctx, "PUT", dstURL, sr) if err != nil { - return err + return "", err } req.ContentLength = size res, err := http.DefaultClient.Do(req) if err != nil { - return err + return "", err } defer res.Body.Close() if res.StatusCode != 200 { - e := &ollama.Error{Status: res.StatusCode} - msg, err := io.ReadAll(res.Body) - if err != nil { - return err - } - // TODO(bmizerany): format error message - e.Message = string(msg) + return "", parseS3Error(res) } - return nil + return res.Header.Get("ETag"), nil +} + +type s3Error struct { + XMLName xml.Name `xml:"Error"` + Code string `xml:"Code"` + Message string `xml:"Message"` + Resource string `xml:"Resource"` + RequestId string `xml:"RequestId"` +} + +func (e *s3Error) Error() string { + return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message) +} + +// parseS3Error parses an XML error response from S3. 
+func parseS3Error(res *http.Response) error { + var se *s3Error + if err := xml.NewDecoder(res.Body).Decode(&se); err != nil { + return err + } + return se } diff --git a/registry/server.go b/registry/server.go index 6d99669a..91659767 100644 --- a/registry/server.go +++ b/registry/server.go @@ -6,8 +6,8 @@ import ( "cmp" "context" "errors" + "fmt" "log" - "math/rand" "net/http" "net/url" "os" @@ -70,7 +70,13 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { } } +func (s *Server) uploadChunkSize() int64 { + return cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize) +} + func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { + const bucketTODO = "test" + pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body) if err != nil { return err @@ -78,7 +84,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { ref := blob.ParseRef(pr.Ref) if !ref.Complete() { - return oweb.Mistake("invalid", "name", "must be fully qualified") + return oweb.Mistake("invalid", "name", "must be complete") } m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest)) @@ -86,28 +92,80 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { return err } - // TODO(bmizerany): parallelize + mcc := &minio.Core{Client: s.mc()} + // TODO(bmizerany): complete uploads before stats for any with ETag + + type completeParts struct { + key string + parts []minio.CompletePart + } + + completePartsByUploadID := make(map[string]completeParts) + for _, pu := range pr.Uploaded { + // parse the URL + u, err := url.Parse(pu.URL) + if err != nil { + return err + } + q := u.Query() + uploadID := q.Get("UploadId") + if uploadID == "" { + return oweb.Mistake("invalid", "url", "missing UploadId") + } + partNumber, err := strconv.Atoi(q.Get("PartNumber")) + if err != nil { + return oweb.Mistake("invalid", "url", "invalid or missing PartNumber") + } + etag := pu.ETag + if etag == "" { + 
return oweb.Mistake("invalid", "etag", "missing") + } + cp, ok := completePartsByUploadID[uploadID] + if !ok { + cp = completeParts{key: u.Path} + completePartsByUploadID[uploadID] = cp + } + cp.parts = append(cp.parts, minio.CompletePart{ + PartNumber: partNumber, + ETag: etag, + }) + fmt.Println("uploadID", uploadID, "partNumber", partNumber, "etag", etag) + completePartsByUploadID[uploadID] = cp + } + + for uploadID, cp := range completePartsByUploadID { + var zeroOpts minio.PutObjectOptions + _, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, cp.key, uploadID, cp.parts, zeroOpts) + if err != nil { + // log and continue; put backpressure on the client + log.Printf("error completing upload: %v", err) + } + } + var requirements []apitype.Requirement for _, l := range m.Layers { + // TODO(bmizerany): do in parallel if l.Size == 0 { continue } // TODO(bmizerany): "global" throttle of rate of transfer - pushed, err := s.statObject(r.Context(), l.Digest) if err != nil { return err } if !pushed { - uploadID := generateUploadID() - for n, c := range upload.Chunks(l.Size, cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize)) { - const expires = 15 * time.Minute + key := path.Join("blobs", l.Digest) + uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{}) + if err != nil { + return err + } + for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) { + const timeToStartUpload = 15 * time.Minute - key := path.Join("blobs", l.Digest) - signedURL, err := s.mc().Presign(r.Context(), "PUT", "test", key, expires, url.Values{ + signedURL, err := s.mc().Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{ "UploadId": []string{uploadID}, - "PartNumber": []string{strconv.Itoa(n)}, + "PartNumber": []string{strconv.Itoa(partNumber)}, "ContentLength": []string{strconv.FormatInt(c.Size, 10)}, }) if err != nil { @@ -118,9 +176,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) 
error { Digest: l.Digest, Offset: c.Offset, Size: c.Size, - - // TODO(bmizerany): use signed+temp urls - URL: signedURL.String(), + URL: signedURL.String(), }) } } @@ -130,7 +186,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { // Commit the manifest body := bytes.NewReader(pr.Manifest) path := path.Join("manifests", path.Join(ref.Parts()...)) - _, err := s.mc().PutObject(r.Context(), "test", path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{}) + _, err := s.mc().PutObject(r.Context(), bucketTODO, path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{}) if err != nil { return err } @@ -175,12 +231,3 @@ func (s *Server) mc() *minio.Client { } return mc } - -func generateUploadID() string { - const hex = "0123456789abcdef" - b := make([]byte, 32) - for i := range b { - b[i] = hex[rand.Intn(len(hex))] - } - return string(b) -} diff --git a/registry/server_test.go b/registry/server_test.go index 466fd787..8cb1ecc1 100644 --- a/registry/server_test.go +++ b/registry/server_test.go @@ -1,118 +1,189 @@ package registry import ( + "bufio" "context" "encoding/json" "errors" + "fmt" "io" + "net" "net/http/httptest" + "os" "os/exec" "strings" "testing" "time" "bllamo.com/registry/apitype" - "github.com/kr/pretty" + "bllamo.com/utils/backoff" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "kr.dev/diff" ) -func TestPush(t *testing.T) { - startMinio(t) +const abc = "abcdefghijklmnopqrstuvwxyz" - s := &Server{} - hs := httptest.NewServer(s) - t.Cleanup(hs.Close) - c := &Client{BaseURL: hs.URL} +func testPush(t *testing.T, chunkSize int64) { + t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) { + mc := startMinio(t, false) - manifest := []byte(`{ - "layers": [ - {"digest": "sha256-1", "size": 1}, - {"digest": "sha256-2", "size": 2}, - {"digest": "sha256-3", "size": 3} - ] - }`) + manifest := []byte(`{ + "layers": [ + {"digest": "sha256-1", "size": 1}, + {"digest": "sha256-2", "size": 2}, 
+ {"digest": "sha256-3", "size": 3} + ] + }`) - const ref = "registry.ollama.ai/x/y:latest+Z" + const ref = "registry.ollama.ai/x/y:latest+Z" - got, err := c.Push(context.Background(), ref, manifest) - if err != nil { - t.Fatal(err) - } + hs := httptest.NewServer(&Server{ + minioClient: mc, + UploadChunkSize: chunkSize, + }) + t.Cleanup(hs.Close) + c := &Client{BaseURL: hs.URL} - diff.Test(t, t.Errorf, got, []apitype.Requirement{ - {Digest: "sha256-1", Size: 1}, - {Digest: "sha256-2", Size: 2}, - {Digest: "sha256-3", Size: 3}, - }, diff.ZeroFields[apitype.Requirement]("URL")) - - for _, r := range got { - body := io.Reader(strings.NewReader(strings.Repeat("x", int(r.Size)))) - if err := PushLayer(context.Background(), r.URL, r.Size, body); err != nil { + requirements, err := c.Push(context.Background(), ref, manifest, nil) + if err != nil { t.Fatal(err) } - } - got, err = c.Push(context.Background(), ref, manifest) - if err != nil { - t.Fatal(err) - } + if len(requirements) < 3 { + t.Fatalf("expected at least 3 requirements; got %d", len(requirements)) + t.Logf("requirements: %v", requirements) + } - if len(got) != 0 { - t.Fatalf("unexpected requirements: % #v", pretty.Formatter(got)) - } + var uploaded []apitype.CompletePart + for i, r := range requirements { + t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size) - mc, err := minio.New("localhost:9000", &minio.Options{ - Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), - Secure: false, - }) - if err != nil { - t.Fatal(err) - } + body := strings.NewReader(abc) + etag, err := PushLayer(context.Background(), r.URL, r.Offset, r.Size, body) + if err != nil { + t.Fatal(err) + } + uploaded = append(uploaded, apitype.CompletePart{ + URL: r.URL, + ETag: etag, + }) + } - var paths []string - keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{ - Recursive: true, - }) - for k := range keys { - paths = append(paths, k.Key) - } + requirements, err = 
c.Push(context.Background(), ref, manifest, &PushParams{ + Uploaded: uploaded, + }) + if err != nil { + t.Fatal(err) + } + if len(requirements) != 0 { + t.Fatalf("unexpected requirements: %v", requirements) + } - t.Logf("paths: %v", paths) + var paths []string + keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{ + Recursive: true, + }) + for k := range keys { + paths = append(paths, k.Key) + } - diff.Test(t, t.Errorf, paths, []string{ - "blobs/sha256-1", - "blobs/sha256-2", - "blobs/sha256-3", - "manifests/registry.ollama.ai/x/y/latest/Z", - }) + t.Logf("paths: %v", paths) - obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{}) - if err != nil { - t.Fatal(err) - } - defer obj.Close() + diff.Test(t, t.Errorf, paths, []string{ + "blobs/sha256-1", + "blobs/sha256-2", + "blobs/sha256-3", + "manifests/registry.ollama.ai/x/y/latest/Z", + }) - var gotM apitype.Manifest - if err := json.NewDecoder(obj).Decode(&gotM); err != nil { - t.Fatal(err) - } + obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{}) + if err != nil { + t.Fatal(err) + } + defer obj.Close() - diff.Test(t, t.Errorf, gotM, apitype.Manifest{ - Layers: []apitype.Layer{ - {Digest: "sha256-1", Size: 1}, - {Digest: "sha256-2", Size: 2}, - {Digest: "sha256-3", Size: 3}, - }, + var gotM apitype.Manifest + if err := json.NewDecoder(obj).Decode(&gotM); err != nil { + t.Fatal(err) + } + + diff.Test(t, t.Errorf, gotM, apitype.Manifest{ + Layers: []apitype.Layer{ + {Digest: "sha256-1", Size: 1}, + {Digest: "sha256-2", Size: 2}, + {Digest: "sha256-3", Size: 3}, + }, + }) + + // checksum the blobs + for i, l := range gotM.Layers { + obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{}) + if err != nil { + t.Fatal(err) + } + defer obj.Close() + + info, err := obj.Stat() + if err != nil { + t.Fatal(err) + } 
+ t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size) + + data, err := io.ReadAll(obj) + if err != nil { + t.Fatal(err) + } + + got := string(data) + want := abc[:l.Size] + if got != want { + t.Errorf("[%d] got layer data = %q; want %q", i, got, want) + } + } }) } -func startMinio(t *testing.T) { +func TestPush(t *testing.T) { + testPush(t, 0) + testPush(t, 1) +} + +func availableAddr() string { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + panic(err) + } + defer l.Close() + return l.Addr().String() +} + +func startMinio(t *testing.T, debug bool) *minio.Client { t.Helper() dir := t.TempDir() - cmd := exec.Command("minio", "server", "--address", "localhost:9000", dir) + t.Logf(">> minio data dir: %s", dir) + addr := availableAddr() + cmd := exec.Command("minio", "server", "--address", addr, dir) + cmd.Env = os.Environ() + + if debug { + stdout, err := cmd.StdoutPipe() + if err != nil { + t.Fatal(err) + } + doneLogging := make(chan struct{}) + t.Cleanup(func() { + <-doneLogging + }) + go func() { + defer close(doneLogging) + sc := bufio.NewScanner(stdout) + for sc.Scan() { + t.Logf("minio: %s", sc.Text()) + } + }() + } // TODO(bmizerany): wait delay etc... 
if err := cmd.Start(); err != nil { @@ -131,7 +202,7 @@ func startMinio(t *testing.T) { } }) - mc, err := minio.New("localhost:9000", &minio.Options{ + mc, err := minio.New(addr, &minio.Options{ Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""), Secure: false, }) @@ -139,17 +210,44 @@ func startMinio(t *testing.T) { t.Fatal(err) } - // wait for server to start - // TODO(bmizerany): use backoff - for { - _, err := mc.ListBuckets(context.Background()) - if err == nil { + ctx, cancel := context.WithCancel(context.Background()) + deadline, ok := t.Deadline() + if ok { + ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond)) + defer cancel() + } + + // wait for server to start with exponential backoff + for _, err := range backoff.Upto(ctx, 1*time.Second) { + if err != nil { + t.Fatal(err) + } + if mc.IsOnline() { break } - time.Sleep(100 * time.Millisecond) } if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil { t.Fatal(err) } + + return mc +} + +// contextForTest returns a context that is canceled when the test deadline, +// if any, is reached. The returned doneLogging function should be called +// after all Log/Error/Fatalf calls are done before the test returns. 
+func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) { + done := make(chan struct{}) + deadline, ok := t.Deadline() + if !ok { + return context.Background(), func() {} + } + + ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond)) + t.Cleanup(func() { + cancel() + <-done + }) + return ctx, func() { close(done) } } diff --git a/utils/backoff/backoff.go b/utils/backoff/backoff.go new file mode 100644 index 00000000..b77f8706 --- /dev/null +++ b/utils/backoff/backoff.go @@ -0,0 +1,58 @@ +package backoff + +import ( + "context" + "errors" + "iter" + "math/rand" + "time" +) + +// Errors +var ( + // ErrMaxAttempts is not used by backoff but is available for use by + // callers that want to signal that a maximum number of retries has + // been exceeded. This should eliminate the need for callers to invent + // their own error. + ErrMaxAttempts = errors.New("max retries exceeded") +) + +// Upto implements a backoff strategy that yields nil errors until the +// context is canceled or yield returns false. +// +// The backoff strategy is a simple exponential backoff with a maximum +// backoff of maxBackoff. The backoff is randomized between 0.5-1.5 times +// the current backoff, in order to prevent accidental "thundering herd" +// problems. +func Upto(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] { + var n int + return func(yield func(int, error) bool) { + for { + if ctx.Err() != nil { + yield(n, ctx.Err()) + return + } + + n++ + + // n^2 backoff timer is a little smoother than the + // common choice of 2^n. + d := time.Duration(n*n) * 10 * time.Millisecond + if d > maxBackoff { + d = maxBackoff + } + // Randomize the delay between 0.5-1.5 x d, in order + // to prevent accidental "thundering herd" problems.
+ d = time.Duration(float64(d) * (rand.Float64() + 0.5)) + t := time.NewTimer(d) + select { + case <-ctx.Done(): + t.Stop() + case <-t.C: + if !yield(n, nil) { + return + } + } + } + } +}