package server

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/jmorganca/ollama/api"
)
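
// OpenAIError mirrors the "error" object in OpenAI API error responses.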
type OpenAIError struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

type OpenAIErrorResponse struct {
	Error OpenAIError `json:"error"`
}
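
// OpenAIChatCompletionRequest is the subset of the OpenAI chat completion
// request body supported by this endpoint.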
type OpenAIChatCompletionRequest struct {
	Model    string          `json:"model"`
	Messages []OpenAIMessage `json:"messages"`
	Stream   bool            `json:"stream"`
}

type OpenAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

func (m *OpenAIMessage) toMessage() api.Message {
	return api.Message{
		Role:    m.Role,
		Content: m.Content,
	}
}

// non-streaming response

type OpenAIChatCompletionResponseChoice struct {
	Index        int           `json:"index"`
	Message      OpenAIMessage `json:"message"`
	FinishReason *string       `json:"finish_reason"`
}

type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type OpenAIChatCompletionResponse struct {
	ID                string                               `json:"id"`
	Object            string                               `json:"object"`
	Created           int64                                `json:"created"`
	Model             string                               `json:"model"`
	SystemFingerprint string                               `json:"system_fingerprint"`
	Choices           []OpenAIChatCompletionResponseChoice `json:"choices"`
	Usage             OpenAIUsage                          `json:"usage,omitempty"`
}

// streaming response

type OpenAIChatCompletionResponseChoiceStream struct {
	Index        int           `json:"index"`
	Delta        OpenAIMessage `json:"delta"`
	FinishReason *string       `json:"finish_reason"`
}

type OpenAIChatCompletionResponseStream struct {
	ID                string                                     `json:"id"`
	Object            string                                     `json:"object"`
	Created           int64                                      `json:"created"`
	Model             string                                     `json:"model"`
	SystemFingerprint string                                     `json:"system_fingerprint"`
	Choices           []OpenAIChatCompletionResponseChoiceStream `json:"choices"`
}

// StreamCompletionMarker signals to send [DONE] on the event stream.
type StreamCompletionMarker struct{}
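
// ChatCompletions handles an OpenAI-style chat completion request by
// translating it onto the internal chat pipeline and answering in the OpenAI
// response format. An illustrative request, assuming this handler is
// registered at POST /v1/chat/completions (routing lives elsewhere, so the
// path and model name here are examples only):
//
//	curl http://localhost:11434/v1/chat/completions -d '{
//	  "model": "llama2",
//	  "messages": [{"role": "user", "content": "Hello!"}],
//	  "stream": false
//	}'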
func ChatCompletions(c *gin.Context) {
	var req OpenAIChatCompletionRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
			OpenAIError{
				Message: "missing request body",
				Type:    "invalid_request_error",
			},
		})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
			OpenAIError{
				Message: err.Error(),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	// Build the internal chat request; chat below returns a channel of responses.
	chatReq := api.ChatRequest{
		Model:  req.Model,
		Stream: &req.Stream,
	}
	for _, m := range req.Messages {
		chatReq.Messages = append(chatReq.Messages, m.toMessage())
	}

	ch, err := chat(c, chatReq, time.Now())
	if err != nil {
		var pErr *fs.PathError
		switch {
		case errors.As(err, &pErr):
			c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
				OpenAIError{
					Message: fmt.Sprintf("model '%s' not found, try pulling it first", req.Model),
					Type:    "invalid_request_error",
				},
			})
		case errors.Is(err, api.ErrInvalidOpts), errors.Is(err, errInvalidRole):
			c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
				OpenAIError{
					Message: err.Error(),
					Type:    "invalid_request_error",
				},
			})
		default:
			c.AbortWithStatusJSON(http.StatusInternalServerError, OpenAIErrorResponse{
				OpenAIError{
					Message: err.Error(),
					Type:    "internal_server_error",
				},
			})
		}
		return
	}

	if !req.Stream {
		// Drain the channel, accumulating the full completion before replying.
		var chatResponse api.ChatResponse
		var sb strings.Builder
		for val := range ch {
			var ok bool
			chatResponse, ok = val.(api.ChatResponse)
			if !ok {
				// err is nil here, so report the type-assertion failure directly.
				c.AbortWithStatusJSON(http.StatusInternalServerError, OpenAIErrorResponse{
					OpenAIError{
						Message: "failed to parse chat response",
						Type:    "internal_server_error",
					},
				})
				return
			}
			if chatResponse.Message != nil {
				sb.WriteString(chatResponse.Message.Content)
			}
		}

		// Send a single response with the accumulated content.
		id := fmt.Sprintf("chatcmpl-%d", rand.Intn(999))
		chatCompletionResponse := OpenAIChatCompletionResponse{
			ID:      id,
			Object:  "chat.completion",
			Created: chatResponse.CreatedAt.Unix(),
			Model:   req.Model,
			Choices: []OpenAIChatCompletionResponseChoice{
				{
					Index: 0,
					Message: OpenAIMessage{
						Role:    "assistant",
						Content: sb.String(),
					},
					FinishReason: func(done bool) *string {
						if done {
							reason := "stop"
							return &reason
						}
						return nil
					}(chatResponse.Done),
				},
			},
		}
		c.JSON(http.StatusOK, chatCompletionResponse)
		return
	}

	// Create the intermediate channel and the transformation goroutine.
	transformedCh := make(chan any)

	go func() {
		defer close(transformedCh)

		id := fmt.Sprintf("chatcmpl-%d", rand.Intn(999)) // TODO: validate that this does not change with each chunk

		// The first chunk carries only the assistant role in its delta.
		predefinedResponse := OpenAIChatCompletionResponseStream{
			ID:      id,
			Object:  "chat.completion.chunk",
			Created: time.Now().Unix(),
			Model:   req.Model,
			Choices: []OpenAIChatCompletionResponseChoiceStream{
				{
					Index: 0,
					Delta: OpenAIMessage{
						Role: "assistant",
					},
				},
			},
		}
		transformedCh <- predefinedResponse

		for val := range ch {
			resp, ok := val.(api.ChatResponse)
			if !ok {
				// If val is not an api.ChatResponse, send an error down the channel and exit.
				transformedCh <- OpenAIErrorResponse{
					OpenAIError{
						Message: "failed to parse chat response",
						Type:    "internal_server_error",
					},
				}
				return
			}

			// Transform the api.ChatResponse into an OpenAI-style stream chunk.
			chatCompletionResponse := OpenAIChatCompletionResponseStream{
				ID:      id,
				Object:  "chat.completion.chunk",
				Created: resp.CreatedAt.Unix(),
				Model:   resp.Model,
				Choices: []OpenAIChatCompletionResponseChoiceStream{
					{
						Index: 0,
						FinishReason: func(done bool) *string {
							if done {
								reason := "stop"
								return &reason
							}
							return nil
						}(resp.Done),
					},
				},
			}
			if resp.Message != nil {
				chatCompletionResponse.Choices[0].Delta = OpenAIMessage{
					Content: resp.Message.Content,
				}
			}
			transformedCh <- chatCompletionResponse

			if resp.Done {
				transformedCh <- StreamCompletionMarker{}
			}
		}
	}()

	// Pass the transformed channel to streamOpenAIResponse.
	streamOpenAIResponse(c, transformedCh)
}
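
// streamOpenAIResponse writes each value received on ch to the client as a
// server-sent event ("data: <json>\n\n"), stopping after a
// StreamCompletionMarker has been turned into the "data: [DONE]" sentinel.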
func streamOpenAIResponse(c *gin.Context, ch chan any) {
	c.Header("Content-Type", "text/event-stream")
	c.Stream(func(w io.Writer) bool {
		val, ok := <-ch
		if !ok {
			return false
		}

		// Check if the message is a StreamCompletionMarker to close the event stream.
		// Note the trailing blank line: SSE requires an empty line after each event.
		if _, isCompletionMarker := val.(StreamCompletionMarker); isCompletionMarker {
			if _, err := w.Write([]byte("data: [DONE]\n\n")); err != nil {
				log.Printf("streamOpenAIResponse: w.Write failed with %s", err)
				return false
			}
			return false // stop streaming after sending [DONE]
		}

		bts, err := json.Marshal(val)
		if err != nil {
			log.Printf("streamOpenAIResponse: json.Marshal failed with %s", err)
			return false
		}
		formattedResponse := fmt.Sprintf("data: %s\n\n", bts)
		if _, err := w.Write([]byte(formattedResponse)); err != nil {
			log.Printf("streamOpenAIResponse: w.Write failed with %s", err)
			return false
		}
		return true
	})
}