Allow a single-element array for the CompletionRequest prompt field
This commit is contained in:
parent
39f2bc6bfc
commit
1477d629e9
openai
@ -107,7 +107,7 @@ type ChatCompletionChunk struct {
|
||||
// TODO (https://github.com/ollama/ollama/issues/5259): support []string with more than one element, []int and [][]int (a single-element []string is already accepted)
|
||||
type CompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Prompt any `json:"prompt"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
@ -508,6 +508,24 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
options["stop"] = stops
|
||||
}
|
||||
|
||||
var prompts []string
|
||||
switch prompt := r.Prompt.(type) {
|
||||
case string:
|
||||
prompts = []string{prompt}
|
||||
case []any:
|
||||
for _, p := range prompt {
|
||||
if str, ok := p.(string); ok {
|
||||
prompts = append(prompts, str)
|
||||
} else {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'prompt' field")
|
||||
}
|
||||
}
|
||||
|
||||
if len(prompts) != 1 {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid size of 'prompt' field: must be 1")
|
||||
}
|
||||
}
|
||||
|
||||
if r.MaxTokens != nil {
|
||||
options["num_predict"] = *r.MaxTokens
|
||||
}
|
||||
@ -534,7 +552,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
|
||||
return api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
Prompt: prompts[0],
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Suffix: r.Suffix,
|
||||
|
@ -269,6 +269,76 @@ func TestCompletionsMiddleware(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "completions handler",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
temp := float32(0.8)
|
||||
body := CompletionRequest{
|
||||
Model: "test-model",
|
||||
Prompt: []string{"Hello"},
|
||||
Temperature: &temp,
|
||||
Stop: []string{"\n", "stop"},
|
||||
Suffix: "suffix",
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.Code)
|
||||
}
|
||||
|
||||
if req.Prompt != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %s", req.Prompt)
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Name: "completions handler prompt error forwarding",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
temp := float32(0.8)
|
||||
body := CompletionRequest{
|
||||
Model: "test-model",
|
||||
Prompt: []string{},
|
||||
Temperature: &temp,
|
||||
Stop: []string{"\n", "stop"},
|
||||
Suffix: "suffix",
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.Code)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.Body.String(), "invalid size of 'prompt' field: must be 1") {
|
||||
t.Fatalf("error was not forwarded")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "completions handler prompt error forwarding",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
temp := float32(0.8)
|
||||
body := CompletionRequest{
|
||||
Model: "test-model",
|
||||
Prompt: []int{1},
|
||||
Temperature: &temp,
|
||||
Stop: []string{"\n", "stop"},
|
||||
Suffix: "suffix",
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.Code)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.Body.String(), "invalid type for 'prompt' field") {
|
||||
t.Fatalf("error was not forwarded")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user