diff --git a/openai/openai.go b/openai/openai.go
index 2bf9b9f9..aff8a04e 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -107,7 +107,7 @@ type ChatCompletionChunk struct {
 // TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
 type CompletionRequest struct {
 	Model            string   `json:"model"`
-	Prompt           string   `json:"prompt"`
+	Prompt           any      `json:"prompt"`
 	FrequencyPenalty float32  `json:"frequency_penalty"`
 	MaxTokens        *int     `json:"max_tokens"`
 	PresencePenalty  float32  `json:"presence_penalty"`
@@ -508,6 +508,27 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 		options["stop"] = stops
 	}
 
+	// The OpenAI completions API accepts a string or an array for "prompt";
+	// Ollama currently supports only a single string prompt, so an array is
+	// accepted only when it holds exactly one string.
+	var prompts []string
+	switch prompt := r.Prompt.(type) {
+	case string:
+		prompts = []string{prompt}
+	case []any:
+		for _, p := range prompt {
+			if str, ok := p.(string); ok {
+				prompts = append(prompts, str)
+			} else {
+				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'prompt' field")
+			}
+		}
+
+		if len(prompts) != 1 {
+			return api.GenerateRequest{}, fmt.Errorf("invalid size of 'prompt' field: must be 1")
+		}
+	default:
+		// Reject nil, numeric, and object prompts; without this default the
+		// prompts[0] access below panics with index out of range.
+		return api.GenerateRequest{}, fmt.Errorf("invalid type for 'prompt' field")
+	}
+
 	if r.MaxTokens != nil {
 		options["num_predict"] = *r.MaxTokens
 	}
@@ -534,7 +555,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 
 	return api.GenerateRequest{
 		Model:   r.Model,
-		Prompt:  r.Prompt,
+		Prompt:  prompts[0],
 		Options: options,
 		Stream:  &r.Stream,
 		Suffix:  r.Suffix,
diff --git a/openai/openai_test.go b/openai/openai_test.go
index eabf5b66..45679f1c 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -299,6 +299,95 @@ func TestCompletionsMiddleware(t *testing.T) {
 				},
 			},
 		},
+		{
+			Name: "completions handler single prompt in list",
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      []string{"Hello"},
+					Temperature: &temp,
+					Stop:        []string{"\n", "stop"},
+					Suffix:      "suffix",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusOK {
+					t.Fatalf("expected 200, got %d", resp.Code)
+				}
+
+				if req.Prompt != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Prompt)
+				}
+			},
+		},
+
+		{
+			Name: "completions handler empty prompt list error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      []string{},
+					Temperature: &temp,
+					Stop:        []string{"\n", "stop"},
+					Suffix:      "suffix",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid size of 'prompt' field: must be 1") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+		{
+			Name: "completions handler prompt list element type error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      []int{1},
+					Temperature: &temp,
+					Stop:        []string{"\n", "stop"},
+					Suffix:      "suffix",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid type for 'prompt' field") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+		{
+			Name: "completions handler scalar prompt type error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := CompletionRequest{
+					Model:  "test-model",
+					Prompt: 42,
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid type for 'prompt' field") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
 	}
 
 	endpoint := func(c *gin.Context) {