Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Yang
5c17743b2c remove json stream errors 2024-01-22 11:08:41 -08:00
Michael Yang
0bd5245acf fix: status on errors
When a request sets stream to false, the HTTP status on errors is always
500 Internal Server Error because individual errors are not mapped to a status.

The most common errors are:

- pull, push: pulling or pushing a model that doesn't exist should
  return 404 Not Found
- push: pushing a model into a place the user is not authorized to push to
  should return 401 Unauthorized
2024-01-22 11:07:53 -08:00
3 changed files with 57 additions and 20 deletions

View File

@ -280,7 +280,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
bts, err := os.ReadFile(fp)
if err != nil {
return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
return nil, "", fmt.Errorf("couldn't open file '%s': %w", fp, err)
}
shaSum := sha256.Sum256(bts)
@ -971,7 +971,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
if errors.Is(err, errUnauthorized) {
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
return fmt.Errorf("%w: unable to push %s, make sure this namespace exists and you are authorized to push to it", err, ParseModelPath(name).GetNamespaceRepository())
}
return err
}
@ -1031,7 +1031,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
manifest, err = pullModelManifest(ctx, mp, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
return fmt.Errorf("pull model manifest: %w", err)
}
var layers []*Layer
@ -1169,7 +1169,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
case resp.StatusCode >= http.StatusBadRequest:
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
return nil, fmt.Errorf("%d: %w", resp.StatusCode, err)
}
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
default:

View File

@ -86,7 +86,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
err = fmt.Errorf("%w: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
return err
@ -254,7 +254,7 @@ func GenerateHandler(c *gin.Context) {
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
ch <- err
return
}
@ -280,12 +280,12 @@ func GenerateHandler(c *gin.Context) {
promptVars.Response = generated.String()
result, err := model.PostResponseTemplate(promptVars)
if err != nil {
ch <- gin.H{"error": err.Error()}
ch <- err
return
}
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
if err != nil {
ch <- gin.H{"error": err.Error()}
ch <- err
return
}
resp.Context = embd
@ -303,7 +303,7 @@ func GenerateHandler(c *gin.Context) {
Options: opts,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
ch <- err
}
}()
@ -439,7 +439,7 @@ func PullModelHandler(c *gin.Context) {
defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
ch <- err
}
}()
@ -488,7 +488,7 @@ func PushModelHandler(c *gin.Context) {
defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
ch <- err
}
}()
@ -561,7 +561,7 @@ func CreateModelHandler(c *gin.Context) {
defer cancel()
if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
ch <- gin.H{"error": err.Error()}
ch <- err
}
}()
@ -979,14 +979,17 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
c.JSON(http.StatusOK, r)
return
}
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
return
case error:
status := http.StatusInternalServerError
switch {
case errors.Is(r, os.ErrNotExist):
status = http.StatusNotFound
case errors.Is(r, errUnauthorized):
status = http.StatusUnauthorized
}
c.JSON(status, gin.H{"error": r.Error()})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
return
@ -1003,6 +1006,10 @@ func streamResponse(c *gin.Context, ch chan any) {
return false
}
if err, ok := val.(error); ok {
val = gin.H{"error": err.Error()}
}
bts, err := json.Marshal(val)
if err != nil {
slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
@ -1126,7 +1133,7 @@ func ChatHandler(c *gin.Context) {
Options: opts,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
ch <- err
}
}()

View File

@ -204,6 +204,36 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, expectedParams, params)
},
},
{
Name: "Pull Model Handler - 404",
Method: http.MethodPost,
Path: "/api/pull",
Setup: func(t *testing.T, req *http.Request) {
var b bytes.Buffer
stream := false
err := json.NewEncoder(&b).Encode(api.PullRequest{Name: "not-a-model", Stream: &stream})
assert.Nil(t, err)
req.Body = io.NopCloser(&b)
},
Expected: func(t *testing.T, resp *http.Response) {
assert.Equal(t, resp.StatusCode, 404)
},
},
// Pushing a model name that does not exist locally, with streaming
// disabled, is expected to produce 404 Not Found. The request must target
// the push route; posting to /api/pull would only repeat the pull case.
{
	Name:   "Push Model Handler - 404",
	Method: http.MethodPost,
	Path:   "/api/push",
	Setup: func(t *testing.T, req *http.Request) {
		// stream=false makes the handler reply with a single JSON body
		// and a real HTTP status instead of a streamed response.
		stream := false
		var body bytes.Buffer
		err := json.NewEncoder(&body).Encode(api.PushRequest{Name: "not-a-model", Stream: &stream})
		assert.Nil(t, err)
		req.Body = io.NopCloser(&body)
	},
	Expected: func(t *testing.T, resp *http.Response) {
		// testify's Equal takes (expected, actual); keeping that order
		// makes the failure message label the values correctly.
		assert.Equal(t, http.StatusNotFound, resp.StatusCode)
	},
},
}
s, err := setupServer(t)