diff --git a/api/client.go b/api/client.go index ccbcbf6b..29ab2698 100644 --- a/api/client.go +++ b/api/client.go @@ -10,6 +10,20 @@ import ( "net/url" ) +type StatusError struct { + StatusCode int + Status string + Message string +} + +func (e StatusError) Error() string { + if e.Message != "" { + return fmt.Sprintf("%s: %s", e.Status, e.Message) + } + + return e.Status +} + type Client struct { base url.URL } @@ -25,7 +39,7 @@ func NewClient(hosts ...string) *Client { } } -func (c *Client) stream(ctx context.Context, method, path string, data any, callback func([]byte) error) error { +func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { var buf *bytes.Buffer if data != nil { bts, err := json.Marshal(data) @@ -53,7 +67,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call scanner := bufio.NewScanner(response.Body) for scanner.Scan() { var errorResponse struct { - Error string `json:"error"` + Error string `json:"error,omitempty"` } bts := scanner.Bytes() @@ -61,11 +75,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call return fmt.Errorf("unmarshal: %w", err) } - if len(errorResponse.Error) > 0 { - return fmt.Errorf("stream: %s", errorResponse.Error) + if response.StatusCode >= 400 { + return StatusError{ + StatusCode: response.StatusCode, + Status: response.Status, + Message: errorResponse.Error, + } } - if err := callback(bts); err != nil { + if err := fn(bts); err != nil { return err } } diff --git a/cmd/cmd.go b/cmd/cmd.go index 8421b8f5..ca924ae9 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "net/http" "os" "path" "strings" @@ -34,7 +35,14 @@ func RunRun(cmd *cobra.Command, args []string) error { switch { case errors.Is(err, os.ErrNotExist): if err := pull(args[0]); err != nil { - return err + var apiStatusError api.StatusError + if !errors.As(err, &apiStatusError) { + return err + } + + if apiStatusError.StatusCode != http.StatusBadGateway { + return err + } } case err != nil: return err @@ -50,11 +58,12 @@ func pull(model string) error { context.Background(), &api.PullRequest{Model: model}, func(progress api.PullProgress) error { - if bar == nil && progress.Percent == 100 { - // already downloaded - return nil - } if bar == nil { + if progress.Percent == 100 { + // already downloaded + return nil + } + bar = progressbar.DefaultBytes(progress.Total) } diff --git a/server/routes.go b/server/routes.go index 94894fdb..1478f9ae 100644 --- a/server/routes.go +++ b/server/routes.go @@ -108,7 +108,7 @@ func pull(c *gin.Context) { remote, err := getRemote(req.Model) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) return }