diff --git a/x/api/api.go b/x/api/api.go index 57ad9fd9..428df45d 100644 --- a/x/api/api.go +++ b/x/api/api.go @@ -15,8 +15,8 @@ import ( // Common API Errors var ( - errUnqualifiedRef = oweb.Mistake("invalid", "name", "must be fully qualified") - errRefNotFound = oweb.Mistake("not_found", "name", "no such model") + errUnqualifiedRef = oweb.Invalid("invalid", "name", "must be fully qualified") + errRefNotFound = oweb.Invalid("not_found", "name", "no such model") ) type Server struct { diff --git a/x/build/blob/ref.go b/x/build/blob/ref.go index e7c75841..e105f5f6 100644 --- a/x/build/blob/ref.go +++ b/x/build/blob/ref.go @@ -253,6 +253,17 @@ func ParseRef(s string) Ref { return r } +// Complete is the same as ParseRef(s).Complete(). +// +// Future versions may be faster than calling ParseRef(s).Complete(), so if +// need to know if a ref is complete and don't need the ref, use this +// function. +func Complete(s string) bool { + // TODO(bmizerany): fast-path this with a quick scan withput + // allocating strings + return ParseRef(s).Complete() +} + func (r Ref) Valid() bool { // Name is required if !isValidPart(r.name) { diff --git a/x/client/ollama/ollama.go b/x/client/ollama/ollama.go index c87a5656..37b117f0 100644 --- a/x/client/ollama/ollama.go +++ b/x/client/ollama/ollama.go @@ -92,12 +92,23 @@ type Error struct { // Field is the field in the request that caused the error, if any. Field string `json:"field,omitempty"` + + // Value is the value of the field that caused the error, if any. + Value string `json:"value,omitempty"` } func (e *Error) Error() string { var b strings.Builder b.WriteString("ollama: ") b.WriteString(e.Code) + if e.Field != "" { + b.WriteString(" ") + b.WriteString(e.Field) + } + if e.Value != "" { + b.WriteString(": ") + b.WriteString(e.Value) + } if e.Message != "" { b.WriteString(": ") b.WriteString(e.Message) diff --git a/x/oweb/oweb.go b/x/oweb/oweb.go index 21da688f..0e8c3907 100644 --- a/x/oweb/oweb.go +++ b/x/oweb/oweb.go @@ -21,12 +21,13 @@ func Missing(field string) error { } } -func Mistake(code, field, message string) error { +func Invalid(field, value, format string, args ...any) error { return &ollama.Error{ Status: 400, - Code: code, + Code: "invalid", Field: field, - Message: fmt.Sprintf("%s: %s", field, message), + Value: value, + Message: fmt.Sprintf(format, args...), } } @@ -69,7 +70,7 @@ func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) { if errors.As(err, &se) { msg = fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type) } - return nil, Mistake("invalid_json", field, msg) + return nil, Invalid("invalid_json", field, "", msg) } func DecodeJSON[T any](r io.Reader) (*T, error) { diff --git a/x/registry/server.go b/x/registry/server.go index 6c5d3b57..ed270702 100644 --- a/x/registry/server.go +++ b/x/registry/server.go @@ -84,7 +84,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { ref := blob.ParseRef(pr.Ref) if !ref.Complete() { - return oweb.Mistake("invalid", "name", "must be complete") + return oweb.Invalid("name", pr.Ref, "must be complete") } m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest)) @@ -107,24 +107,30 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + q := u.Query() + + // Check if this is a part upload, if not, skip uploadID := q.Get("uploadId") if uploadID == "" { // not a part upload continue } - partNumber, err := strconv.Atoi(q.Get("partNumber")) + + // PartNumber is required + queryPartNumber := q.Get("partNumber") + partNumber, err := strconv.Atoi(queryPartNumber) if err != nil { - return oweb.Mistake("invalid", "url", "invalid or missing PartNumber") + return oweb.Invalid("partNumber", queryPartNumber, "invalid or missing PartNumber") } + + // ETag is required if mcp.ETag == "" { - return oweb.Mistake("invalid", "etag", "missing") - } - cp, ok := completePartsByUploadID[uploadID] - if !ok { - cp = completeParts{key: u.Path} - completePartsByUploadID[uploadID] = cp + return oweb.Missing("etag") } + + cp := completePartsByUploadID[uploadID] + cp.key = u.Path cp.parts = append(cp.parts, minio.CompletePart{ PartNumber: partNumber, ETag: mcp.ETag, @@ -136,8 +142,11 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { var zeroOpts minio.PutObjectOptions _, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, cp.key, uploadID, cp.parts, zeroOpts) if err != nil { - // log and continue; put backpressure on the client - log.Printf("error completing upload: %v", err) + var e minio.ErrorResponse + if errors.As(err, &e) && e.Code == "NoSuchUpload" { + return oweb.Invalid("uploadId", uploadID, "unknown uploadId") + } + return err } } diff --git a/x/registry/server_test.go b/x/registry/server_test.go index faf50445..255e0632 100644 --- a/x/registry/server_test.go +++ b/x/registry/server_test.go @@ -28,6 +28,39 @@ import ( "kr.dev/diff" ) +// const ref = "registry.ollama.ai/x/y:latest+Z" +// const manifest = `{ +// "layers": [ +// {"digest": "sha256-1", "size": 1}, +// {"digest": "sha256-2", "size": 2}, +// {"digest": "sha256-3", "size": 3} +// ] +// }` + +// ts := newTestServer(t) +// ts.pushNotOK(ref, `{}`, &ollama.Error{ +// Status: 400, +// Code: "invalid", +// Message: "name must be fully qualified", +// }) + +// ts.push(ref, `{ +// "layers": [ +// {"digest": "sha256-1", "size": 1}, +// {"digest": "sha256-2", "size": 2}, +// {"digest": "sha256-3", "size": 3} +// ] +// }`) + +type tWriter struct { + t *testing.T +} + +func (w tWriter) Write(p []byte) (n int, err error) { + w.t.Logf("%s", p) + return len(p), nil +} + func TestPushBasic(t *testing.T) { const MB = 1024 * 1024 @@ -41,6 +74,8 @@ func TestPushBasic(t *testing.T) { } }() + const ref = "registry.ollama.ai/x/y:latest+Z" + // Upload two small layers and one large layer that will // trigger a multipart upload. manifest := []byte(`{ @@ -49,9 +84,7 @@ func TestPushBasic(t *testing.T) { {"digest": "sha256-2", "size": 2}, {"digest": "sha256-3", "size": 11000000} ] - }`) - - const ref = "registry.ollama.ai/x/y:latest+Z" + }`) hs := httptest.NewServer(&Server{ minioClient: mc,