Merge bc0a60f617be02e8bfe45d6ed5bf254aa7f74fee into d7eb05b9361febead29a74e71ddffc2ebeff5302

This commit is contained in:
William Guss 2024-11-14 13:55:06 +08:00 committed by GitHub
commit 8908038f58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 135 additions and 7 deletions

@ -874,19 +874,48 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
func ChatMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req ChatCompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
// Create a new decoder and use DisallowUnknownFields to catch unknown fields
decoder := json.NewDecoder(c.Request.Body)
decoder.DisallowUnknownFields()
if err := decoder.Decode(&req); err != nil {
var unmarshalTypeError *json.UnmarshalTypeError
var jsonSyntaxError *json.SyntaxError
switch {
case errors.As(err, &unmarshalTypeError):
errMsg := fmt.Sprintf("Invalid type for field %s. Expected %s, got %s", unmarshalTypeError.Field, unmarshalTypeError.Type, unmarshalTypeError.Value)
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, errMsg))
case errors.As(err, &jsonSyntaxError):
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "Invalid JSON syntax"))
default:
// Check if it's an unknown field error
if strings.HasPrefix(err.Error(), "json: unknown field ") {
fieldName := strings.TrimPrefix(err.Error(), "json: unknown field ")
errMsg := fmt.Sprintf("Unsupported parameter: %s", fieldName)
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, errMsg))
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
}
}
return
}
// Check for unexpected additional fields
if decoder.More() {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "Unexpected additional fields"))
return
}
// Validate that the 'messages' field is not empty
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
return
}
// Encode the validated request back to a buffer
var b bytes.Buffer
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
@ -898,8 +927,10 @@ func ChatMiddleware() gin.HandlerFunc {
return
}
// Replace the request body with the new buffer
c.Request.Body = io.NopCloser(&b)
// Initialize the custom ResponseWriter
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
@ -908,6 +939,7 @@ func ChatMiddleware() gin.HandlerFunc {
c.Writer = w
// Proceed to the next handler
c.Next()
}
}

@ -211,6 +211,102 @@ func TestChatMiddleware(t *testing.T) {
},
},
},
{
name: "chat handler with unsupported parameter",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"n": 2
}`,
err: ErrorResponse{
Error: Error{
Message: "Unsupported parameter: \"n\"",
Type: "invalid_request_error",
},
},
},
{
name: "chat handler with multiple unsupported parameters",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"n": 2,
"best_of": 3
}`,
err: ErrorResponse{
Error: Error{
Message: "Unsupported parameter: \"n\"",
Type: "invalid_request_error",
},
},
},
{
name: "chat handler with unsupported nested parameter",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"options": {
"unsupported_option": true
}
}`,
err: ErrorResponse{
Error: Error{
Message: "Unsupported parameter: \"options\"",
Type: "invalid_request_error",
},
},
},
{
name: "chat handler with empty messages array",
body: `{
"model": "test-model",
"messages": []
}`,
err: ErrorResponse{
Error: Error{
Message: "[] is too short - 'messages'",
Type: "invalid_request_error",
},
},
},
{
name: "chat handler with invalid JSON",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"temperature": 0.7,
}`,
err: ErrorResponse{
Error: Error{
Message: "Invalid JSON syntax",
Type: "invalid_request_error",
},
},
},
{
name: "chat handler with invalid type for known field",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"temperature": "not a number"
}`,
err: ErrorResponse{
Error: Error{
Message: "Invalid type for field temperature. Expected float64, got string",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
@ -235,15 +331,15 @@ func TestChatMiddleware(t *testing.T) {
var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
t.Fatalf("Failed to unmarshal error response: %v", err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
t.Fatalf("Requests did not match.\nExpected: %+v\nGot: %+v", tc.req, *capturedRequest)
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
t.Fatalf("Errors did not match.\nExpected: %+v\nGot: %+v\nResponse body: %s", tc.err, errResp, resp.Body.String())
}
})
}