305 lines
8.0 KiB
Go
305 lines
8.0 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"log"
|
|
"math/rand"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/jmorganca/ollama/api"
|
|
)
|
|
|
|
// OpenAIError is the error payload in OpenAI's error envelope
// (https-style `{"error": {...}}` responses).
type OpenAIError struct {
	// Message is a human-readable description of the failure.
	Message string `json:"message"`
	// Type is the OpenAI error category, e.g. "invalid_request_error"
	// or "internal_server_error".
	Type string `json:"type"`
	// Param names the offending request parameter, when applicable;
	// nil serializes as JSON null (never set by the handlers in this file).
	Param interface{} `json:"param"`
	// Code is an optional machine-readable error code; a nil pointer
	// serializes as JSON null.
	Code *string `json:"code"`
}

// OpenAIErrorResponse wraps an OpenAIError under the "error" key, matching
// the envelope OpenAI clients expect on non-2xx responses.
type OpenAIErrorResponse struct {
	Error OpenAIError `json:"error"`
}
|
|
|
|
// OpenAIChatCompletionRequest is the request body accepted by the
// OpenAI-compatible chat completions endpoint.
type OpenAIChatCompletionRequest struct {
	// Model names the model to chat with. The json tag was previously
	// missing; binding of the lowercase "model" key only worked through
	// encoding/json's case-insensitive field matching. Adding the tag makes
	// the mapping explicit and consistent with the other fields.
	Model string `json:"model"`
	// Messages is the conversation history, oldest first.
	Messages []OpenAIMessage `json:"messages"`
	// Stream requests a server-sent-event chunked response when true.
	Stream bool `json:"stream"`
}

// OpenAIMessage is a single chat turn: an OpenAI-style role (e.g. "system",
// "user", "assistant") and its text content.
type OpenAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
|
|
|
func (m *OpenAIMessage) toMessage() api.Message {
|
|
return api.Message{
|
|
Role: m.Role,
|
|
Content: m.Content,
|
|
}
|
|
}
|
|
|
|
// non-streaming response

// OpenAIChatCompletionResponseChoice is one completion choice in a
// non-streaming response. This server always returns a single choice at
// index 0.
type OpenAIChatCompletionResponseChoice struct {
	Index int `json:"index"`
	// Message holds the full assistant reply.
	Message OpenAIMessage `json:"message"`
	// FinishReason is "stop" when generation completed; nil serializes as
	// JSON null.
	FinishReason *string `json:"finish_reason"`
}
|
|
|
|
// OpenAIUsage mirrors OpenAI's token accounting object.
// NOTE(review): nothing in this file populates these counts — verify whether
// callers expect real usage numbers or tolerate zeros.
type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
|
|
|
// OpenAIChatCompletionResponse is the body of a non-streaming
// "chat.completion" response.
type OpenAIChatCompletionResponse struct {
	// ID is a synthetic "chatcmpl-<n>" identifier.
	ID string `json:"id"`
	// Object is always "chat.completion" for this type.
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	// SystemFingerprint is never set by the handlers in this file.
	SystemFingerprint string                               `json:"system_fingerprint"`
	Choices           []OpenAIChatCompletionResponseChoice `json:"choices"`
	// NOTE(review): omitempty has no effect on a non-pointer struct field,
	// so a zero-valued usage object is always emitted.
	Usage OpenAIUsage `json:"usage,omitempty"`
}
|
|
|
|
// streaming response

// OpenAIChatCompletionResponseChoiceStream is one choice inside a streaming
// chunk; Delta carries only the incremental content for this chunk.
type OpenAIChatCompletionResponseChoiceStream struct {
	Index int `json:"index"`
	// Delta is the partial message: the first chunk carries the role, later
	// chunks carry content fragments.
	Delta OpenAIMessage `json:"delta"`
	// FinishReason is "stop" on the final chunk; nil (JSON null) otherwise.
	FinishReason *string `json:"finish_reason"`
}
|
|
|
|
// OpenAIChatCompletionResponseStream is one "chat.completion.chunk" object
// sent as a server-sent event during streaming.
type OpenAIChatCompletionResponseStream struct {
	// ID is the same synthetic "chatcmpl-<n>" identifier for every chunk of
	// a single stream.
	ID string `json:"id"`
	// Object is always "chat.completion.chunk" for this type.
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	// SystemFingerprint is never set by the handlers in this file.
	SystemFingerprint string                                     `json:"system_fingerprint"`
	Choices           []OpenAIChatCompletionResponseChoiceStream `json:"choices"`
}
|
|
|
|
// StreamCompletionMarker, when received on the transformed channel, tells
// streamOpenAIResponse to emit the terminal "data: [DONE]" event and close
// the event stream.
type StreamCompletionMarker struct{}
|
|
|
|
func ChatCompletions(c *gin.Context) {
|
|
var req OpenAIChatCompletionRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
switch {
|
|
case errors.Is(err, io.EOF):
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
|
|
OpenAIError{
|
|
Message: "missing request body",
|
|
Type: "invalid_request_error",
|
|
},
|
|
})
|
|
return
|
|
case err != nil:
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
|
|
OpenAIError{
|
|
Message: err.Error(),
|
|
Type: "invalid_request_error",
|
|
},
|
|
})
|
|
return
|
|
}
|
|
|
|
// Call generate and receive the channel with the responses
|
|
chatReq := api.ChatRequest{
|
|
Model: req.Model,
|
|
Stream: &req.Stream,
|
|
}
|
|
for _, m := range req.Messages {
|
|
chatReq.Messages = append(chatReq.Messages, m.toMessage())
|
|
}
|
|
ch, err := chat(c, chatReq, time.Now())
|
|
if err != nil {
|
|
var pErr *fs.PathError
|
|
switch {
|
|
case errors.As(err, &pErr):
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
|
|
OpenAIError{
|
|
Message: fmt.Sprintf("model '%s' not found, try pulling it first", req.Model),
|
|
Type: "invalid_request_error",
|
|
},
|
|
})
|
|
case errors.Is(err, api.ErrInvalidOpts), errors.Is(err, errInvalidRole):
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
|
|
OpenAIError{
|
|
Message: err.Error(),
|
|
Type: "invalid_request_error",
|
|
},
|
|
})
|
|
default:
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, OpenAIErrorResponse{
|
|
OpenAIError{
|
|
Message: err.Error(),
|
|
Type: "internal_server_error",
|
|
},
|
|
})
|
|
}
|
|
return
|
|
}
|
|
|
|
if !req.Stream {
|
|
// Wait for the channel to close
|
|
var chatResponse api.ChatResponse
|
|
var sb strings.Builder
|
|
|
|
for val := range ch {
|
|
var ok bool
|
|
chatResponse, ok = val.(api.ChatResponse)
|
|
if !ok {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
|
|
OpenAIError{
|
|
Message: err.Error(),
|
|
Type: "internal_server_error",
|
|
},
|
|
})
|
|
return
|
|
}
|
|
if chatResponse.Message != nil {
|
|
sb.WriteString(chatResponse.Message.Content)
|
|
}
|
|
|
|
if chatResponse.Done {
|
|
chatResponse.Message = &api.Message{Role: "assistant", Content: sb.String()}
|
|
}
|
|
}
|
|
// Send a single response with accumulated content
|
|
id := fmt.Sprintf("chatcmpl-%d", rand.Intn(999))
|
|
chatCompletionResponse := OpenAIChatCompletionResponse{
|
|
ID: id,
|
|
Object: "chat.completion",
|
|
Created: chatResponse.CreatedAt.Unix(),
|
|
Model: req.Model,
|
|
Choices: []OpenAIChatCompletionResponseChoice{
|
|
{
|
|
Index: 0,
|
|
Message: OpenAIMessage{
|
|
Role: "assistant",
|
|
Content: chatResponse.Message.Content,
|
|
},
|
|
FinishReason: func(done bool) *string {
|
|
if done {
|
|
reason := "stop"
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(chatResponse.Done),
|
|
},
|
|
},
|
|
}
|
|
c.JSON(http.StatusOK, chatCompletionResponse)
|
|
return
|
|
}
|
|
|
|
// Now, create the intermediate channel and transformation goroutine
|
|
transformedCh := make(chan any)
|
|
|
|
go func() {
|
|
defer close(transformedCh)
|
|
id := fmt.Sprintf("chatcmpl-%d", rand.Intn(999)) // TODO: validate that this does not change with each chunk
|
|
predefinedResponse := OpenAIChatCompletionResponseStream{
|
|
ID: id,
|
|
Object: "chat.completion.chunk",
|
|
Created: time.Now().Unix(),
|
|
Model: req.Model,
|
|
Choices: []OpenAIChatCompletionResponseChoiceStream{
|
|
{
|
|
Index: 0,
|
|
Delta: OpenAIMessage{
|
|
Role: "assistant",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
transformedCh <- predefinedResponse
|
|
for val := range ch {
|
|
resp, ok := val.(api.ChatResponse)
|
|
if !ok {
|
|
// If val is not of type ChatResponse, send an error down the channel and exit
|
|
transformedCh <- OpenAIErrorResponse{
|
|
OpenAIError{
|
|
Message: "failed to parse chat response",
|
|
Type: "internal_server_error",
|
|
},
|
|
}
|
|
return
|
|
}
|
|
|
|
// Transform the ChatResponse into OpenAIChatCompletionResponse
|
|
chatCompletionResponse := OpenAIChatCompletionResponseStream{
|
|
ID: id,
|
|
Object: "chat.completion.chunk",
|
|
Created: resp.CreatedAt.Unix(),
|
|
Model: resp.Model,
|
|
Choices: []OpenAIChatCompletionResponseChoiceStream{
|
|
{
|
|
Index: 0,
|
|
FinishReason: func(done bool) *string {
|
|
if done {
|
|
reason := "stop"
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(resp.Done),
|
|
},
|
|
},
|
|
}
|
|
if resp.Message != nil {
|
|
chatCompletionResponse.Choices[0].Delta = OpenAIMessage{
|
|
Content: resp.Message.Content,
|
|
}
|
|
}
|
|
transformedCh <- chatCompletionResponse
|
|
if resp.Done {
|
|
transformedCh <- StreamCompletionMarker{}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Pass the transformed channel to streamResponse
|
|
streamOpenAIResponse(c, transformedCh)
|
|
}
|
|
|
|
func streamOpenAIResponse(c *gin.Context, ch chan any) {
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Stream(func(w io.Writer) bool {
|
|
val, ok := <-ch
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
// Check if the message is a StreamCompletionMarker to close the event stream
|
|
if _, isCompletionMarker := val.(StreamCompletionMarker); isCompletionMarker {
|
|
if _, err := w.Write([]byte("data: [DONE]\n")); err != nil {
|
|
log.Printf("streamOpenAIResponse: w.Write failed with %s", err)
|
|
return false
|
|
}
|
|
return false // Stop streaming after sending [DONE]
|
|
}
|
|
|
|
bts, err := json.Marshal(val)
|
|
if err != nil {
|
|
log.Printf("streamOpenAIResponse: json.Marshal failed with %s", err)
|
|
return false
|
|
}
|
|
|
|
formattedResponse := fmt.Sprintf("data: %s\n", bts)
|
|
|
|
if _, err := w.Write([]byte(formattedResponse)); err != nil {
|
|
log.Printf("streamOpenAIResponse: w.Write failed with %s", err)
|
|
return false
|
|
}
|
|
|
|
return true
|
|
})
|
|
}
|