From 1477d629e932ab0629a9051294cbc8ce8d1af2c3 Mon Sep 17 00:00:00 2001
From: Igor Drozdov <idrozdov@gitlab.com>
Date: Mon, 5 Aug 2024 23:37:58 +0200
Subject: [PATCH] Allow singular array for CompletionRequest prompt field

---
 openai/openai.go      | 22 ++++++++++++--
 openai/openai_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/openai/openai.go b/openai/openai.go
index bda42b4d..6b9afb66 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -107,7 +107,7 @@ type ChatCompletionChunk struct {
 // TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
 type CompletionRequest struct {
 	Model            string   `json:"model"`
-	Prompt           string   `json:"prompt"`
+	Prompt           any      `json:"prompt"`
 	FrequencyPenalty float32  `json:"frequency_penalty"`
 	MaxTokens        *int     `json:"max_tokens"`
 	PresencePenalty  float32  `json:"presence_penalty"`
@@ -508,6 +508,24 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 		options["stop"] = stops
 	}
 
+	var prompts []string
+	switch prompt := r.Prompt.(type) {
+	case string:
+		prompts = []string{prompt}
+	case []any:
+		for _, p := range prompt {
+			if str, ok := p.(string); ok {
+				prompts = append(prompts, str)
+			} else {
+				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'prompt' field")
+			}
+		}
+
+		if len(prompts) != 1 {
+			return api.GenerateRequest{}, fmt.Errorf("invalid size of 'prompt' field: must be 1")
+		}
+	}
+
 	if r.MaxTokens != nil {
 		options["num_predict"] = *r.MaxTokens
 	}
@@ -534,7 +552,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 
 	return api.GenerateRequest{
 		Model:   r.Model,
-		Prompt:  r.Prompt,
+		Prompt:  prompts[0],
 		Options: options,
 		Stream:  &r.Stream,
 		Suffix:  r.Suffix,
diff --git a/openai/openai_test.go b/openai/openai_test.go
index e08a96c9..47f67889 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -269,6 +269,76 @@ func TestCompletionsMiddleware(t *testing.T) {
 				}
 			},
 		},
+		{
+			Name: "completions handler",
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      []string{"Hello"},
+					Temperature: &temp,
+					Stop:        []string{"\n", "stop"},
+					Suffix:      "suffix",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusOK {
+					t.Fatalf("expected 200, got %d", resp.Code)
+				}
+
+				if req.Prompt != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Prompt)
+				}
+			},
+		},
+
+		{
+			Name: "completions handler prompt error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      []string{},
+					Temperature: &temp,
+					Stop:        []string{"\n", "stop"},
+					Suffix:      "suffix",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid size of 'prompt' field: must be 1") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+		{
+			Name: "completions handler prompt error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      []int{1},
+					Temperature: &temp,
+					Stop:        []string{"\n", "stop"},
+					Suffix:      "suffix",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid type for 'prompt' field") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
 	}
 
 	endpoint := func(c *gin.Context) {