From f7cfe946dcca4e1924be585dcf089e36ad1b1332 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Wed, 3 Apr 2024 16:37:27 -0700 Subject: [PATCH] x/registry: fixing tests wip --- x/registry/client.go | 39 ++++++++++---- x/registry/server.go | 2 - x/registry/server_test.go | 104 ++++++++++++++------------------------ 3 files changed, 68 insertions(+), 77 deletions(-) diff --git a/x/registry/client.go b/x/registry/client.go index 747dde57..54d90bb5 100644 --- a/x/registry/client.go +++ b/x/registry/client.go @@ -4,9 +4,11 @@ import ( "cmp" "context" "encoding/xml" + "errors" "fmt" "io" "net/http" + "strings" "bllamo.com/client/ollama" "bllamo.com/registry/apitype" @@ -40,23 +42,42 @@ func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushP return v.Requirements, nil } -func PushLayer(ctx context.Context, dstURL string, off, size int64, file io.ReaderAt) (etag string, err error) { - sr := io.NewSectionReader(file, off, size) - req, err := http.NewRequestWithContext(ctx, "PUT", dstURL, sr) - if err != nil { - return "", err +func PushLayer(ctx context.Context, body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) { + var zero apitype.CompletePart + if off < 0 { + return zero, errors.New("off must be >0") + } + + file := io.NewSectionReader(body, off, n) + req, err := http.NewRequest("PUT", url, file) + if err != nil { + return zero, err + } + req.ContentLength = n + + // TODO(bmizerany): take content type param + req.Header.Set("Content-Type", "text/plain") + + if n >= 0 { + req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1)) } - req.ContentLength = size res, err := http.DefaultClient.Do(req) if err != nil { - return "", err + return zero, err } defer res.Body.Close() if res.StatusCode != 200 { - return "", parseS3Error(res) + e := parseS3Error(res) + return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e) } - return res.Header.Get("ETag"), nil + etag := strings.Trim(res.Header.Get("ETag"), `"`) + cp := apitype.CompletePart{ + URL: url, + ETag: etag, + // TODO(bmizerany): checksum + } + return cp, nil } type s3Error struct { diff --git a/x/registry/server.go b/x/registry/server.go index e6baf94b..884778ef 100644 --- a/x/registry/server.go +++ b/x/registry/server.go @@ -6,7 +6,6 @@ import ( "cmp" "context" "errors" - "fmt" "log" "net/http" "net/url" @@ -131,7 +130,6 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { PartNumber: partNumber, ETag: etag, }) - fmt.Println("uploadID", uploadID, "partNumber", partNumber, "etag", etag) completePartsByUploadID[uploadID] = cp } diff --git a/x/registry/server_test.go b/x/registry/server_test.go index cd8463ac..3dbae41f 100644 --- a/x/registry/server_test.go +++ b/x/registry/server_test.go @@ -11,13 +11,11 @@ import ( "fmt" "io" "net" - "net/http" "net/http/httptest" "net/url" "os" "os/exec" "strconv" - "strings" "syscall" "testing" "time" @@ -30,8 +28,6 @@ import ( "kr.dev/diff" ) -const abc = "abcdefghijklmnopqrstuvwxyz" - func testPush(t *testing.T, chunkSize int64) { t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) { mc := startMinio(t, true) @@ -71,15 +67,11 @@ func testPush(t *testing.T, chunkSize int64) { for i, r := range requirements { t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size) - body := strings.NewReader(abc) - etag, err := PushLayer(context.Background(), r.URL, r.Offset, r.Size, body) + cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size) if err != nil { t.Fatal(err) } - uploaded = append(uploaded, apitype.CompletePart{ - URL: r.URL, - ETag: etag, - }) + uploaded = append(uploaded, cp) } requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{ @@ -142,15 +134,8 @@ func testPush(t *testing.T, chunkSize int64) { } t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size) - data, err := io.ReadAll(obj) - if err != nil { - t.Fatal(err) - } - - got := string(data) - want := abc[:l.Size] - if got != want { - t.Errorf("[%d] got layer data = %q; want %q", i, got, want) + if msg := checkABCs(obj, int(l.Size)); msg != "" { + t.Errorf("[%d] %s", i, msg) } } }) @@ -161,44 +146,6 @@ func TestPush(t *testing.T) { testPush(t, 1) } -func pushLayer(body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) { - var zero apitype.CompletePart - if off < 0 { - return zero, errors.New("off must be >0") - } - - file := io.NewSectionReader(body, off, n) - req, err := http.NewRequest("PUT", url, file) - if err != nil { - return zero, err - } - req.ContentLength = n - - // TODO(bmizerany): take content type param - req.Header.Set("Content-Type", "text/plain") - - if n >= 0 { - req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1)) - } - - res, err := http.DefaultClient.Do(req) - if err != nil { - return zero, err - } - defer res.Body.Close() - if res.StatusCode != 200 { - e := parseS3Error(res) - return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e) - } - etag := strings.Trim(res.Header.Get("ETag"), `"`) - cp := apitype.CompletePart{ - URL: url, - ETag: etag, - // TODO(bmizerany): checksum - } - return cp, nil -} - // TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of // presigning a multipart upload, uploading the parts, and completing the // upload. It is for future reference and should not be deleted. This flow @@ -230,7 +177,7 @@ func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) { t.Logf("[partNumber=%d]: %v", partNumber, u) var body abcReader - cp, err := pushLayer(&body, u.String(), c.Offset, c.N) + cp, err := PushLayer(context.Background(), &body, u.String(), c.Offset, c.N) if err != nil { t.Fatalf("[partNumber=%d]: %v", partNumber, err) } @@ -306,7 +253,7 @@ func startMinio(t *testing.T, trace bool) *minio.Client { // explicitly setting trace to true. trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "") - dir := t.TempDir() + "-keep" // prevent tempdir from auto delete + dir := t.TempDir() t.Cleanup(func() { // TODO(bmizerany): trim temp dir based on dates so that @@ -317,19 +264,18 @@ func startMinio(t *testing.T, trace bool) *minio.Client { if err := cmd.Wait(); err != nil { var e *exec.ExitError if errors.As(err, &e) { - if !e.Exited() { - // died due to our signal + if e.Exited() { return } - t.Errorf("startMinio: %s stderr: %s", cmd.Path, e.Stderr) - t.Errorf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode()) - t.Errorf("startMinio: %s exited: %v", cmd.Path, e.Exited()) - t.Errorf("startMinio: %s stderr: %s", cmd.Path, e.Stderr) + t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr) + t.Logf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode()) + t.Logf("startMinio: %s exited: %v", cmd.Path, e.Exited()) + t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr) } else { if errors.Is(err, context.Canceled) { return } - t.Errorf("startMinio: %s exit error: %v", cmd.Path, err) + t.Logf("startMinio: %s exit error: %v", cmd.Path, err) } } } @@ -343,6 +289,7 @@ func startMinio(t *testing.T, trace bool) *minio.Client { } t.Logf(">> minio: minio server %s", dir) + addr := availableAddr() cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir) cmd.Env = os.Environ() @@ -463,3 +410,28 @@ func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) { } return len(p), nil } + +func checkABCs(r io.Reader, size int) (reason string) { + h := sha256.New() + n, err := io.CopyN(h, &abcReader{}, int64(size)) + if err != nil { + return err.Error() + } + if n != int64(size) { + panic("short read; should not happen") + } + want := h.Sum(nil) + h = sha256.New() + n, err = io.Copy(h, r) + if err != nil { + return err.Error() + } + if n != int64(size) { + return fmt.Sprintf("got len(r) = %d; want %d", n, size) + } + got := h.Sum(nil) + if !bytes.Equal(got, want) { + return fmt.Sprintf("got sum = %x; want %x", got, want) + } + return "" +}