Merge bc0a60f617be02e8bfe45d6ed5bf254aa7f74fee into d7eb05b9361febead29a74e71ddffc2ebeff5302
This commit is contained in:
commit
8908038f58
openai
@ -874,19 +874,48 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
func ChatMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req ChatCompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
|
||||
// Create a new decoder and use DisallowUnknownFields to catch unknown fields
|
||||
decoder := json.NewDecoder(c.Request.Body)
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
var unmarshalTypeError *json.UnmarshalTypeError
|
||||
var jsonSyntaxError *json.SyntaxError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &unmarshalTypeError):
|
||||
errMsg := fmt.Sprintf("Invalid type for field %s. Expected %s, got %s", unmarshalTypeError.Field, unmarshalTypeError.Type, unmarshalTypeError.Value)
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, errMsg))
|
||||
case errors.As(err, &jsonSyntaxError):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "Invalid JSON syntax"))
|
||||
default:
|
||||
// Check if it's an unknown field error
|
||||
if strings.HasPrefix(err.Error(), "json: unknown field ") {
|
||||
fieldName := strings.TrimPrefix(err.Error(), "json: unknown field ")
|
||||
errMsg := fmt.Sprintf("Unsupported parameter: %s", fieldName)
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, errMsg))
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check for unexpected additional fields
|
||||
if decoder.More() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "Unexpected additional fields"))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate that the 'messages' field is not empty
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
||||
return
|
||||
}
|
||||
|
||||
// Encode the validated request back to a buffer
|
||||
var b bytes.Buffer
|
||||
|
||||
chatReq, err := fromChatRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
@ -898,8 +927,10 @@ func ChatMiddleware() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Replace the request body with the new buffer
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
// Initialize the custom ResponseWriter
|
||||
w := &ChatWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
@ -908,6 +939,7 @@ func ChatMiddleware() gin.HandlerFunc {
|
||||
|
||||
c.Writer = w
|
||||
|
||||
// Proceed to the next handler
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
@ -211,6 +211,102 @@ func TestChatMiddleware(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with unsupported parameter",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"n": 2
|
||||
}`,
|
||||
err: ErrorResponse{
|
||||
Error: Error{
|
||||
Message: "Unsupported parameter: \"n\"",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with multiple unsupported parameters",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"n": 2,
|
||||
"best_of": 3
|
||||
}`,
|
||||
err: ErrorResponse{
|
||||
Error: Error{
|
||||
Message: "Unsupported parameter: \"n\"",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with unsupported nested parameter",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"options": {
|
||||
"unsupported_option": true
|
||||
}
|
||||
}`,
|
||||
err: ErrorResponse{
|
||||
Error: Error{
|
||||
Message: "Unsupported parameter: \"options\"",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with empty messages array",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": []
|
||||
}`,
|
||||
err: ErrorResponse{
|
||||
Error: Error{
|
||||
Message: "[] is too short - 'messages'",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with invalid JSON",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
}`,
|
||||
err: ErrorResponse{
|
||||
Error: Error{
|
||||
Message: "Invalid JSON syntax",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with invalid type for known field",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"temperature": "not a number"
|
||||
}`,
|
||||
err: ErrorResponse{
|
||||
Error: Error{
|
||||
Message: "Invalid type for field temperature. Expected float64, got string",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
@ -235,15 +331,15 @@ func TestChatMiddleware(t *testing.T) {
|
||||
var errResp ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("Failed to unmarshal error response: %v", err)
|
||||
}
|
||||
}
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
t.Fatalf("Requests did not match.\nExpected: %+v\nGot: %+v", tc.req, *capturedRequest)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
t.Fatalf("Errors did not match.\nExpected: %+v\nGot: %+v\nResponse body: %s", tc.err, errResp, resp.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user