From eb2c442a015741fc37af6d22109dd3adea594105 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 31 Mar 2024 11:36:51 -0700 Subject: [PATCH] oweb: make DecodeUserJSON take a field name This allows for better error messages when decoding fails. For example, instead of: {"code":"invalid_json","message":"unexpected end of JSON input"} We now get: {"code":"invalid_json","field":"manifest","message":"unexpected end of JSON input"} --- api/api.go | 4 ++++ oweb/oweb.go | 15 ++++++++------- registry/apitypes.go | 5 ++++- registry/client.go | 15 ++++++++++----- registry/server.go | 14 ++++++++++---- 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/api/api.go b/api/api.go index fd525937..d8fad5dd 100644 --- a/api/api.go +++ b/api/api.go @@ -104,3 +104,7 @@ func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error { return err } + +func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error { + return oweb.ErrNotFound +} diff --git a/oweb/oweb.go b/oweb/oweb.go index 352ba0c6..6c586c14 100644 --- a/oweb/oweb.go +++ b/oweb/oweb.go @@ -63,20 +63,21 @@ func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) { } } -func DecodeUserJSON[T any](r io.Reader) (*T, error) { +func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) { v, err := DecodeJSON[T](r) + if err == nil { + return v, nil + } + var msg string var e *json.SyntaxError if errors.As(err, &e) { - return nil, &ollama.Error{Code: "invalid_json", Message: e.Error()} + msg = e.Error() } var se *json.UnmarshalTypeError if errors.As(err, &se) { - return nil, &ollama.Error{ - Code: "invalid_json", - Message: fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type), - } + msg = fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type) } - return v, err + return nil, Mistake("invalid_json", field, msg) } func DecodeJSON[T any](r io.Reader) (*T, error) { diff --git a/registry/apitypes.go b/registry/apitypes.go index dc3e22d9..599bfe9a 100644 --- a/registry/apitypes.go +++ b/registry/apitypes.go @@ -1,5 +1,7 @@ package registry +import "encoding/json" + type Manifest struct { Layers []Layer `json:"layers"` } @@ -11,7 +13,8 @@ type Layer struct { } type PushRequest struct { - Manifest Manifest `json:"manifest"` + Ref string `json:"ref"` + Manifest json.RawMessage } type Requirement struct { diff --git a/registry/client.go b/registry/client.go index b26be554..1e3e9c88 100644 --- a/registry/client.go +++ b/registry/client.go @@ -2,7 +2,6 @@ package registry import ( "context" - "encoding/json" "io" "net/http" @@ -10,15 +9,21 @@ import ( ) type Client struct { - BaseURL string + BaseURL string + HTTPClient *http.Client +} + +func (c *Client) oclient() *ollama.Client { + return (*ollama.Client)(c) } // Push pushes a manifest to the server. func (c *Client) Push(ctx context.Context, ref string, manifest []byte) ([]Requirement, error) { // TODO(bmizerany): backoff - v, err := ollama.Do[PushResponse](ctx, "POST", c.BaseURL+"/v1/push/"+ref, struct { - Manifest json.RawMessage `json:"manifest"` - }{manifest}) + v, err := ollama.Do[PushResponse](ctx, c.oclient(), "POST", "/v1/push", &PushRequest{ + Ref: ref, + Manifest: manifest, + }) if err != nil { return nil, err } diff --git a/registry/server.go b/registry/server.go index 991a1240..f1d03cc5 100644 --- a/registry/server.go +++ b/registry/server.go @@ -2,6 +2,7 @@ package registry import ( + "bytes" "cmp" "context" "errors" @@ -34,16 +35,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { switch { - case strings.HasPrefix(r.URL.Path, "/v1/push/"): + case strings.HasPrefix(r.URL.Path, "/v1/push"): return s.handlePush(w, r) - case strings.HasPrefix(r.URL.Path, "/v1/pull/"): + case strings.HasPrefix(r.URL.Path, "/v1/pull"): return s.handlePull(w, r) } return oweb.ErrNotFound } func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { - pr, err := oweb.DecodeUserJSON[PushRequest](r.Body) + pr, err := oweb.DecodeUserJSON[PushRequest]("", r.Body) if err != nil { return err } @@ -53,9 +54,14 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error { Secure: false, }) + m, err := oweb.DecodeUserJSON[Manifest]("manifest", bytes.NewReader(pr.Manifest)) + if err != nil { + return err + } + // TODO(bmizerany): parallelize var requirements []Requirement - for _, l := range pr.Manifest.Layers { + for _, l := range m.Layers { if l.Size == 0 { continue }