Allow a single-element array for the CompletionRequest prompt field

This commit is contained in:
Igor Drozdov 2024-08-05 23:37:58 +02:00
parent 39f2bc6bfc
commit 1477d629e9
No known key found for this signature in database
2 changed files with 90 additions and 2 deletions

@ -107,7 +107,7 @@ type ChatCompletionChunk struct {
// TODO (https://github.com/ollama/ollama/issues/5259): support multi-element []string, []int and [][]int
type CompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Prompt any `json:"prompt"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens *int `json:"max_tokens"`
PresencePenalty float32 `json:"presence_penalty"`
@ -508,6 +508,24 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
options["stop"] = stops
}
var prompts []string
switch prompt := r.Prompt.(type) {
case string:
prompts = []string{prompt}
case []any:
for _, p := range prompt {
if str, ok := p.(string); ok {
prompts = append(prompts, str)
} else {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'prompt' field")
}
}
if len(prompts) != 1 {
return api.GenerateRequest{}, fmt.Errorf("invalid size of 'prompt' field: must be 1")
}
}
if r.MaxTokens != nil {
options["num_predict"] = *r.MaxTokens
}
@ -534,7 +552,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
return api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
Prompt: prompts[0],
Options: options,
Stream: &r.Stream,
Suffix: r.Suffix,

@ -269,6 +269,76 @@ func TestCompletionsMiddleware(t *testing.T) {
}
},
},
{
Name: "completions handler",
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: []string{"Hello"},
Temperature: &temp,
Stop: []string{"\n", "stop"},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Prompt)
}
},
},
{
Name: "completions handler prompt error forwarding",
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: []string{},
Temperature: &temp,
Stop: []string{"\n", "stop"},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid size of 'prompt' field: must be 1") {
t.Fatalf("error was not forwarded")
}
},
},
{
Name: "completions handler prompt error forwarding",
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: []int{1},
Temperature: &temp,
Stop: []string{"\n", "stop"},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid type for 'prompt' field") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {