From ae03496a488a15c199c25c349020298efdf5adc7 Mon Sep 17 00:00:00 2001 From: William Guss Date: Tue, 24 Sep 2024 23:47:40 -0700 Subject: [PATCH 1/3] Update openai.go --- openai/openai.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index 2bf9b9f9..f74cabe0 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -874,19 +874,24 @@ func EmbeddingsMiddleware() gin.HandlerFunc { func ChatMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var req ChatCompletionRequest - err := c.ShouldBindJSON(&req) - if err != nil { + + // Create a new decoder and disallow unknown fields + decoder := json.NewDecoder(c.Request.Body) + decoder.DisallowUnknownFields() + + if err := decoder.Decode(&req); err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) return } - if len(req.Messages) == 0 { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'")) + // Optionally check for more tokens to ensure the entire body has been read + if decoder.More() { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "Unexpected additional fields")) return } + // Encode back to buffer if needed (as in original middleware) var b bytes.Buffer - chatReq, err := fromChatRequest(req) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) From 85f635877de38a83e89e0dae183b8a6bb559d8a4 Mon Sep 17 00:00:00 2001 From: William Guss Date: Tue, 24 Sep 2024 23:52:26 -0700 Subject: [PATCH 2/3] Update openai.go --- openai/openai.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index f74cabe0..289ce10c 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -884,13 +884,19 @@ func ChatMiddleware() gin.HandlerFunc { return } - // Optionally check for more tokens to ensure the entire body has been read + // Check for unexpected additional fields if decoder.More() { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "Unexpected additional fields")) return } - // Encode back to buffer if needed (as in original middleware) + // Validate that the 'messages' field is not empty + if len(req.Messages) == 0 { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'")) + return + } + + // Encode the validated request back to a buffer var b bytes.Buffer chatReq, err := fromChatRequest(req) if err != nil { @@ -903,8 +909,10 @@ func ChatMiddleware() gin.HandlerFunc { return } + // Replace the request body with the new buffer c.Request.Body = io.NopCloser(&b) + // Initialize the custom ResponseWriter w := &ChatWriter{ BaseWriter: BaseWriter{ResponseWriter: c.Writer}, stream: req.Stream, @@ -913,6 +921,7 @@ func ChatMiddleware() gin.HandlerFunc { c.Writer = w + // Proceed to the next handler c.Next() } } From bc0a60f617be02e8bfe45d6ed5bf254aa7f74fee Mon Sep 17 00:00:00 2001 From: William Guss Date: Wed, 25 Sep 2024 13:57:16 -0700 Subject: [PATCH 3/3] Adding testsa nd more descriptive erorr messages for validation --- openai/openai.go | 22 ++++++++- openai/openai_test.go | 102 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 119 insertions(+), 5 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index 289ce10c..ecc03ed5 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -875,12 +875,30 @@ func ChatMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var req ChatCompletionRequest - // Create a new decoder and disallow unknown fields + // Create a new decoder and use DisallowUnknownFields to catch unknown fields decoder := json.NewDecoder(c.Request.Body) decoder.DisallowUnknownFields() if err := decoder.Decode(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + var unmarshalTypeError *json.UnmarshalTypeError + var jsonSyntaxError *json.SyntaxError + + switch { + case errors.As(err, &unmarshalTypeError): + errMsg := fmt.Sprintf("Invalid type for field %s. Expected %s, got %s", unmarshalTypeError.Field, unmarshalTypeError.Type, unmarshalTypeError.Value) + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, errMsg)) + case errors.As(err, &jsonSyntaxError): + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "Invalid JSON syntax")) + default: + // Check if it's an unknown field error + if strings.HasPrefix(err.Error(), "json: unknown field ") { + fieldName := strings.TrimPrefix(err.Error(), "json: unknown field ") + errMsg := fmt.Sprintf("Unsupported parameter: %s", fieldName) + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, errMsg)) + } else { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + } + } return } diff --git a/openai/openai_test.go b/openai/openai_test.go index eabf5b66..b0a75d63 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -211,6 +211,102 @@ func TestChatMiddleware(t *testing.T) { }, }, }, + { + name: "chat handler with unsupported parameter", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "n": 2 + }`, + err: ErrorResponse{ + Error: Error{ + Message: "Unsupported parameter: \"n\"", + Type: "invalid_request_error", + }, + }, + }, + { + name: "chat handler with multiple unsupported parameters", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "n": 2, + "best_of": 3 + }`, + err: ErrorResponse{ + Error: Error{ + Message: "Unsupported parameter: \"n\"", + Type: "invalid_request_error", + }, + }, + }, + { + name: "chat handler with unsupported nested parameter", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "options": { + "unsupported_option": true + } + }`, + err: ErrorResponse{ + Error: Error{ + Message: "Unsupported parameter: \"options\"", + Type: "invalid_request_error", + }, + }, + }, + { + name: "chat handler with empty messages array", + body: `{ + "model": "test-model", + "messages": [] + }`, + err: ErrorResponse{ + Error: Error{ + Message: "[] is too short - 'messages'", + Type: "invalid_request_error", + }, + }, + }, + { + name: "chat handler with invalid JSON", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "temperature": 0.7, + }`, + err: ErrorResponse{ + Error: Error{ + Message: "Invalid JSON syntax", + Type: "invalid_request_error", + }, + }, + }, + { + name: "chat handler with invalid type for known field", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "temperature": "not a number" + }`, + err: ErrorResponse{ + Error: Error{ + Message: "Invalid type for field temperature. Expected float64, got string", + Type: "invalid_request_error", + }, + }, + }, } endpoint := func(c *gin.Context) { @@ -235,15 +331,15 @@ func TestChatMiddleware(t *testing.T) { var errResp ErrorResponse if resp.Code != http.StatusOK { if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { - t.Fatal(err) + t.Fatalf("Failed to unmarshal error response: %v", err) } } if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { - t.Fatal("requests did not match") + t.Fatalf("Requests did not match.\nExpected: %+v\nGot: %+v", tc.req, *capturedRequest) } if !reflect.DeepEqual(tc.err, errResp) { - t.Fatal("errors did not match") + t.Fatalf("Errors did not match.\nExpected: %+v\nGot: %+v\nResponse body: %s", tc.err, errResp, resp.Body.String()) } }) }