Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Yang
5c17743b2c remove json stream errors 2024-01-22 11:08:41 -08:00
Michael Yang
0bd5245acf fix: status on errors
HTTP status on errors when stream:=false is always 500 Internal Server
Error because the individual errors are not handled.

The most common errors are:

- pull, push: pulling or pushing a model that doesn't exist should
  return 404 Not Found
- push: pushing a model into a place the user is authorized to should
  return 401 Unauthorized
2024-01-22 11:07:53 -08:00
3 changed files with 57 additions and 20 deletions

View File

@ -280,7 +280,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
bts, err := os.ReadFile(fp) bts, err := os.ReadFile(fp)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("couldn't open file '%s'", fp) return nil, "", fmt.Errorf("couldn't open file '%s': %w", fp, err)
} }
shaSum := sha256.Sum256(bts) shaSum := sha256.Sum256(bts)
@ -971,7 +971,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err)) slog.Info(fmt.Sprintf("error uploading blob: %v", err))
if errors.Is(err, errUnauthorized) { if errors.Is(err, errUnauthorized) {
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository()) return fmt.Errorf("%w: unable to push %s, make sure this namespace exists and you are authorized to push to it", err, ParseModelPath(name).GetNamespaceRepository())
} }
return err return err
} }
@ -1031,7 +1031,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
manifest, err = pullModelManifest(ctx, mp, regOpts) manifest, err = pullModelManifest(ctx, mp, regOpts)
if err != nil { if err != nil {
return fmt.Errorf("pull model manifest: %s", err) return fmt.Errorf("pull model manifest: %w", err)
} }
var layers []*Layer var layers []*Layer
@ -1169,7 +1169,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
case resp.StatusCode >= http.StatusBadRequest: case resp.StatusCode >= http.StatusBadRequest:
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err) return nil, fmt.Errorf("%d: %w", resp.StatusCode, err)
} }
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody) return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
default: default:

View File

@ -86,7 +86,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
// show a generalized compatibility error until there is a better way to // show a generalized compatibility error until there is a better way to
// check for model compatibility // check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") { if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName) err = fmt.Errorf("%w: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
} }
return err return err
@ -254,7 +254,7 @@ func GenerateHandler(c *gin.Context) {
// Build up the full response // Build up the full response
if _, err := generated.WriteString(r.Content); err != nil { if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()} ch <- err
return return
} }
@ -280,12 +280,12 @@ func GenerateHandler(c *gin.Context) {
promptVars.Response = generated.String() promptVars.Response = generated.String()
result, err := model.PostResponseTemplate(promptVars) result, err := model.PostResponseTemplate(promptVars)
if err != nil { if err != nil {
ch <- gin.H{"error": err.Error()} ch <- err
return return
} }
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result) embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
if err != nil { if err != nil {
ch <- gin.H{"error": err.Error()} ch <- err
return return
} }
resp.Context = embd resp.Context = embd
@ -303,7 +303,7 @@ func GenerateHandler(c *gin.Context) {
Options: opts, Options: opts,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- err
} }
}() }()
@ -439,7 +439,7 @@ func PullModelHandler(c *gin.Context) {
defer cancel() defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil { if err := PullModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- err
} }
}() }()
@ -488,7 +488,7 @@ func PushModelHandler(c *gin.Context) {
defer cancel() defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil { if err := PushModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- err
} }
}() }()
@ -561,7 +561,7 @@ func CreateModelHandler(c *gin.Context) {
defer cancel() defer cancel()
if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil { if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- err
} }
}() }()
@ -979,14 +979,17 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
c.JSON(http.StatusOK, r) c.JSON(http.StatusOK, r)
return return
} }
case gin.H: case error:
if errorMsg, ok := r["error"].(string); ok { status := http.StatusInternalServerError
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg}) switch {
return case errors.Is(r, os.ErrNotExist):
} else { status = http.StatusNotFound
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"}) case errors.Is(r, errUnauthorized):
return status = http.StatusUnauthorized
} }
c.JSON(status, gin.H{"error": r.Error()})
return
default: default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
return return
@ -1003,6 +1006,10 @@ func streamResponse(c *gin.Context, ch chan any) {
return false return false
} }
if err, ok := val.(error); ok {
val = gin.H{"error": err.Error()}
}
bts, err := json.Marshal(val) bts, err := json.Marshal(val)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err)) slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
@ -1126,7 +1133,7 @@ func ChatHandler(c *gin.Context) {
Options: opts, Options: opts,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- err
} }
}() }()

View File

@ -204,6 +204,36 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, expectedParams, params) assert.Equal(t, expectedParams, params)
}, },
}, },
{
Name: "Pull Model Handler - 404",
Method: http.MethodPost,
Path: "/api/pull",
Setup: func(t *testing.T, req *http.Request) {
var b bytes.Buffer
stream := false
err := json.NewEncoder(&b).Encode(api.PullRequest{Name: "not-a-model", Stream: &stream})
assert.Nil(t, err)
req.Body = io.NopCloser(&b)
},
Expected: func(t *testing.T, resp *http.Response) {
assert.Equal(t, resp.StatusCode, 404)
},
},
{
Name: "Push Model Handler - 404",
Method: http.MethodPost,
Path: "/api/pull",
Setup: func(t *testing.T, req *http.Request) {
var b bytes.Buffer
stream := false
err := json.NewEncoder(&b).Encode(api.PushRequest{Name: "not-a-model", Stream: &stream})
assert.Nil(t, err)
req.Body = io.NopCloser(&b)
},
Expected: func(t *testing.T, resp *http.Response) {
assert.Equal(t, resp.StatusCode, 404)
},
},
} }
s, err := setupServer(t) s, err := setupServer(t)