openai chat proxy endpoint

Bruce MacDonald 2023-12-01 18:13:52 -05:00
parent f9b7d65e2b
commit 931dc1b36f
3 changed files with 345 additions and 28 deletions

server/images.go

@@ -49,6 +49,8 @@ type Model struct {
Options map[string]interface{}
}
+var errInvalidRole = fmt.Errorf("invalid role")
type PromptVars struct {
System string
Prompt string
@@ -119,7 +121,7 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
return "", err
}
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
return "", fmt.Errorf("%w: %s, role must be one of [system, user, assistant]", errInvalidRole, msg.Role)
}
}

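The %w change above turns the role check into a wrapped sentinel error, so callers (including the new OpenAI handler below) can match it with errors.Is rather than by string comparison. A minimal standalone sketch of the pattern (illustrative, not part of the commit):

package main

import (
	"errors"
	"fmt"
)

var errInvalidRole = fmt.Errorf("invalid role")

func checkRole(role string) error {
	switch role {
	case "system", "user", "assistant":
		return nil
	default:
		return fmt.Errorf("%w: %s, role must be one of [system, user, assistant]", errInvalidRole, role)
	}
}

func main() {
	err := checkRole("tool")
	fmt.Println(errors.Is(err, errInvalidRole)) // true, even through wrapping
}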
server/openai.go (new file, +304)

@@ -0,0 +1,304 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"math/rand"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/jmorganca/ollama/api"
)
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param interface{} `json:"param"`
Code *string `json:"code"`
}
type OpenAIErrorResponse struct {
Error OpenAIError `json:"error"`
}
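// Illustrative only (not in the original diff): a payload of this shape
// serializes to the OpenAI-style error envelope, e.g.
//
//	{"error":{"message":"missing request body","type":"invalid_request_error","param":null,"code":null}}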
type OpenAIChatCompletionRequest struct {
Model string `json:"model"`
Messages []OpenAIMessage `json:"messages"`
Stream bool `json:"stream"`
}
type OpenAIMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
func (m *OpenAIMessage) toMessage() api.Message {
return api.Message{
Role: m.Role,
Content: m.Content,
}
}
// non-streaming response
type OpenAIChatCompletionResponseChoice struct {
Index int `json:"index"`
Message OpenAIMessage `json:"message"`
FinishReason *string `json:"finish_reason"`
}
type OpenAIUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []OpenAIChatCompletionResponseChoice `json:"choices"`
Usage OpenAIUsage `json:"usage,omitempty"`
}
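// Illustrative only (values made up): this struct serializes to
//
//	{
//	  "id": "chatcmpl-481",
//	  "object": "chat.completion",
//	  "created": 1701475200,
//	  "model": "llama2",
//	  "system_fingerprint": "",
//	  "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
//	  "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
//	}
//
// Note that `omitempty` has no effect on a non-pointer struct field, so
// "usage" is always emitted.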
// streaming response
type OpenAIChatCompletionResponseChoiceStream struct {
Index int `json:"index"`
Delta OpenAIMessage `json:"delta"`
FinishReason *string `json:"finish_reason"`
}
type OpenAIChatCompletionResponseStream struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []OpenAIChatCompletionResponseChoiceStream `json:"choices"`
}
type StreamCompletionMarker struct{} // signals to send [DONE] on the event-stream
func ChatCompletions(c *gin.Context) {
var req OpenAIChatCompletionRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
OpenAIError{
Message: "missing request body",
Type: "invalid_request_error",
},
})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
OpenAIError{
Message: err.Error(),
Type: "invalid_request_error",
},
})
return
}
// Call generate and receive the channel with the responses
chatReq := api.ChatRequest{
Model: req.Model,
Stream: &req.Stream,
}
for _, m := range req.Messages {
chatReq.Messages = append(chatReq.Messages, m.toMessage())
}
ch, err := chat(c, chatReq, time.Now())
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
OpenAIError{
Message: fmt.Sprintf("model '%s' not found, try pulling it first", req.Model),
Type: "invalid_request_error",
},
})
case errors.Is(err, api.ErrInvalidOpts), errors.Is(err, errInvalidRole):
c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
OpenAIError{
Message: err.Error(),
Type: "invalid_request_error",
},
})
default:
c.AbortWithStatusJSON(http.StatusInternalServerError, OpenAIErrorResponse{
OpenAIError{
Message: err.Error(),
Type: "internal_server_error",
},
})
}
return
}
if !req.Stream {
// Wait for the channel to close
var chatResponse api.ChatResponse
var sb strings.Builder
for val := range ch {
var ok bool
chatResponse, ok = val.(api.ChatResponse)
if !ok {
// err is nil at this point; report a static parse failure instead
c.AbortWithStatusJSON(http.StatusInternalServerError, OpenAIErrorResponse{
OpenAIError{
Message: "failed to parse chat response",
Type: "internal_server_error",
},
})
return
}
if chatResponse.Message != nil {
sb.WriteString(chatResponse.Message.Content)
}
if chatResponse.Done {
chatResponse.Message = &api.Message{Role: "assistant", Content: sb.String()}
}
}
// Send a single response with accumulated content
id := fmt.Sprintf("chatcmpl-%d", rand.Intn(999))
chatCompletionResponse := OpenAIChatCompletionResponse{
ID: id,
Object: "chat.completion",
Created: chatResponse.CreatedAt.Unix(),
Model: req.Model,
Choices: []OpenAIChatCompletionResponseChoice{
{
Index: 0,
Message: OpenAIMessage{
Role: "assistant",
Content: chatResponse.Message.Content,
},
FinishReason: func(done bool) *string {
if done {
reason := "stop"
return &reason
}
return nil
}(chatResponse.Done),
},
},
}
c.JSON(http.StatusOK, chatCompletionResponse)
return
}
// Now, create the intermediate channel and transformation goroutine
transformedCh := make(chan any)
go func() {
defer close(transformedCh)
id := fmt.Sprintf("chatcmpl-%d", rand.Intn(999)) // TODO: validate that this does not change with each chunk
predefinedResponse := OpenAIChatCompletionResponseStream{
ID: id,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: req.Model,
Choices: []OpenAIChatCompletionResponseChoiceStream{
{
Index: 0,
Delta: OpenAIMessage{
Role: "assistant",
},
},
},
}
transformedCh <- predefinedResponse
for val := range ch {
resp, ok := val.(api.ChatResponse)
if !ok {
// If val is not of type ChatResponse, send an error down the channel and exit
transformedCh <- OpenAIErrorResponse{
OpenAIError{
Message: "failed to parse chat response",
Type: "internal_server_error",
},
}
return
}
// Transform the ChatResponse into OpenAIChatCompletionResponse
chatCompletionResponse := OpenAIChatCompletionResponseStream{
ID: id,
Object: "chat.completion.chunk",
Created: resp.CreatedAt.Unix(),
Model: resp.Model,
Choices: []OpenAIChatCompletionResponseChoiceStream{
{
Index: 0,
FinishReason: func(done bool) *string {
if done {
reason := "stop"
return &reason
}
return nil
}(resp.Done),
},
},
}
if resp.Message != nil {
chatCompletionResponse.Choices[0].Delta = OpenAIMessage{
Content: resp.Message.Content,
}
}
transformedCh <- chatCompletionResponse
if resp.Done {
transformedCh <- StreamCompletionMarker{}
}
}
}()
// Pass the transformed channel to streamResponse
streamOpenAIResponse(c, transformedCh)
}
func streamOpenAIResponse(c *gin.Context, ch chan any) {
c.Header("Content-Type", "text/event-stream")
c.Stream(func(w io.Writer) bool {
val, ok := <-ch
if !ok {
return false
}
// Check if the message is a StreamCompletionMarker to close the event stream
if _, isCompletionMarker := val.(StreamCompletionMarker); isCompletionMarker {
if _, err := w.Write([]byte("data: [DONE]\n\n")); err != nil { // SSE events end with a blank line
log.Printf("streamOpenAIResponse: w.Write failed with %s", err)
return false
}
return false // Stop streaming after sending [DONE]
}
bts, err := json.Marshal(val)
if err != nil {
log.Printf("streamOpenAIResponse: json.Marshal failed with %s", err)
return false
}
formattedResponse := fmt.Sprintf("data: %s\n\n", bts) // SSE events end with a blank line
if _, err := w.Write([]byte(formattedResponse)); err != nil {
log.Printf("streamOpenAIResponse: w.Write failed with %s", err)
return false
}
return true
})
}
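For reference, an illustrative event stream from this handler (values made up; payloads follow the structs above):

data: {"id":"chatcmpl-481","object":"chat.completion.chunk","created":1701475200,"model":"llama2","system_fingerprint":"","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}

data: {"id":"chatcmpl-481","object":"chat.completion.chunk","created":1701475200,"model":"llama2","system_fingerprint":"","choices":[{"index":0,"delta":{"role":"","content":"Hello!"},"finish_reason":"stop"}]}

data: [DONE]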

server/routes.go

@@ -821,6 +821,9 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/blobs/:digest", CreateBlobHandler)
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
+// openai compatible endpoints
+r.POST("/openai/chat/completions", ChatCompletions)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
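With the route registered, the new endpoint can be exercised like any OpenAI-compatible API. A minimal client sketch (illustrative; assumes a server on Ollama's default port 11434 and uses "llama2" as a placeholder for any pulled model):

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	body := []byte(`{"model":"llama2","messages":[{"role":"user","content":"hello"}],"stream":false}`)
	resp, err := http.Post("http://localhost:11434/openai/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out)) // a chat.completion JSON object, or an error envelope
}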
@@ -936,14 +939,13 @@ func ChatHandler(c *gin.Context) {
return
}
-sessionDuration := defaultSessionDuration
-model, err := load(c, req.Model, req.Options, sessionDuration)
+ch, err := chat(c, req, checkpointStart)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-case errors.Is(err, api.ErrInvalidOpts):
+case errors.Is(err, api.ErrInvalidOpts), errors.Is(err, errInvalidRole):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -951,18 +953,41 @@ func ChatHandler(c *gin.Context) {
return
}
-// an empty request loads the model
-if len(req.Messages) == 0 {
-c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
-return
-}
+if req.Stream != nil && !*req.Stream {
+// Wait for the channel to close
+var r api.ChatResponse
+var sb strings.Builder
+for resp := range ch {
+var ok bool
+if r, ok = resp.(api.ChatResponse); !ok {
+c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to parse chat response"})
+return
+}
+if r.Message != nil {
+sb.WriteString(r.Message.Content)
+}
+}
+r.Message = &api.Message{Role: "assistant", Content: sb.String()}
+c.JSON(http.StatusOK, r)
+return
+}
+streamResponse(c, ch)
+}
+func chat(c *gin.Context, req api.ChatRequest, checkpointStart time.Time) (chan any, error) {
+sessionDuration := defaultSessionDuration
+model, err := load(c, req.Model, req.Options, sessionDuration)
+if err != nil {
+return nil, err
+}
checkpointLoaded := time.Now()
prompt, err := model.ChatPrompt(req.Messages)
if err != nil {
-c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-return
+return nil, err
}
ch := make(chan any)
@@ -971,6 +996,11 @@ func ChatHandler(c *gin.Context) {
defer close(ch)
fn := func(r llm.PredictResult) {
+// an empty request loads the model
+if len(req.Messages) == 0 {
+ch <- api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
+return
+}
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
@@ -1009,24 +1039,5 @@ func ChatHandler(c *gin.Context) {
}
}()
-if req.Stream != nil && !*req.Stream {
-// Wait for the channel to close
-var r api.ChatResponse
-var sb strings.Builder
-for resp := range ch {
-var ok bool
-if r, ok = resp.(api.ChatResponse); !ok {
-c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-return
-}
-if r.Message != nil {
-sb.WriteString(r.Message.Content)
-}
-}
-r.Message = &api.Message{Role: "assistant", Content: sb.String()}
-c.JSON(http.StatusOK, r)
-return
-}
-streamResponse(c, ch)
+return ch, nil
}