openai chat proxy endpoint

parent f9b7d65e2b
commit 931dc1b36f

@@ -49,6 +49,8 @@ type Model struct {
 	Options map[string]interface{}
 }
 
+var errInvalidRole = fmt.Errorf("invalid role")
+
 type PromptVars struct {
 	System string
 	Prompt string
@@ -119,7 +121,7 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
 				return "", err
 			}
 		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return "", fmt.Errorf("%w: %s, role must be one of [system, user, assistant]", errInvalidRole, msg.Role)
 		}
 	}

server/openai.go (new file, 304 lines)
@@ -0,0 +1,304 @@

package server

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/jmorganca/ollama/api"
)

type OpenAIError struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

type OpenAIErrorResponse struct {
	Error OpenAIError `json:"error"`
}

type OpenAIChatCompletionRequest struct {
	Model    string
	Messages []OpenAIMessage `json:"messages"`
	Stream   bool            `json:"stream"`
}

type OpenAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

func (m *OpenAIMessage) toMessage() api.Message {
	return api.Message{
		Role:    m.Role,
		Content: m.Content,
	}
}

// non-streaming response

type OpenAIChatCompletionResponseChoice struct {
	Index        int           `json:"index"`
	Message      OpenAIMessage `json:"message"`
	FinishReason *string       `json:"finish_reason"`
}

type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type OpenAIChatCompletionResponse struct {
	ID                string                               `json:"id"`
	Object            string                               `json:"object"`
	Created           int64                                `json:"created"`
	Model             string                               `json:"model"`
	SystemFingerprint string                               `json:"system_fingerprint"`
	Choices           []OpenAIChatCompletionResponseChoice `json:"choices"`
	Usage             OpenAIUsage                          `json:"usage,omitempty"`
}

// streaming response

type OpenAIChatCompletionResponseChoiceStream struct {
	Index        int           `json:"index"`
	Delta        OpenAIMessage `json:"delta"`
	FinishReason *string       `json:"finish_reason"`
}

type OpenAIChatCompletionResponseStream struct {
	ID                string                                     `json:"id"`
	Object            string                                     `json:"object"`
	Created           int64                                      `json:"created"`
	Model             string                                     `json:"model"`
	SystemFingerprint string                                     `json:"system_fingerprint"`
	Choices           []OpenAIChatCompletionResponseChoiceStream `json:"choices"`
}

type StreamCompletionMarker struct{} // signals to send [DONE] on the event-stream

func ChatCompletions(c *gin.Context) {
	var req OpenAIChatCompletionRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
			OpenAIError{
				Message: "missing request body",
				Type:    "invalid_request_error",
			},
		})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
			OpenAIError{
				Message: err.Error(),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	// Call generate and receive the channel with the responses
	chatReq := api.ChatRequest{
		Model:  req.Model,
		Stream: &req.Stream,
	}
	for _, m := range req.Messages {
		chatReq.Messages = append(chatReq.Messages, m.toMessage())
	}
	ch, err := chat(c, chatReq, time.Now())
	if err != nil {
		var pErr *fs.PathError
		switch {
		case errors.As(err, &pErr):
			c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
				OpenAIError{
					Message: fmt.Sprintf("model '%s' not found, try pulling it first", req.Model),
					Type:    "invalid_request_error",
				},
			})
		case errors.Is(err, api.ErrInvalidOpts), errors.Is(err, errInvalidRole):
			c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
				OpenAIError{
					Message: err.Error(),
					Type:    "invalid_request_error",
				},
			})
		default:
			c.AbortWithStatusJSON(http.StatusInternalServerError, OpenAIErrorResponse{
				OpenAIError{
					Message: err.Error(),
					Type:    "internal_server_error",
				},
			})
		}
		return
	}

	if !req.Stream {
		// Wait for the channel to close
		var chatResponse api.ChatResponse
		var sb strings.Builder

		for val := range ch {
			var ok bool
			chatResponse, ok = val.(api.ChatResponse)
			if !ok {
				c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
					OpenAIError{
						// err is nil on this path, so report a fixed message rather than err.Error()
						Message: "failed to parse chat response",
						Type:    "internal_server_error",
					},
				})
				return
			}
			if chatResponse.Message != nil {
				sb.WriteString(chatResponse.Message.Content)
			}

			if chatResponse.Done {
				chatResponse.Message = &api.Message{Role: "assistant", Content: sb.String()}
			}
		}
		// Send a single response with accumulated content
		id := fmt.Sprintf("chatcmpl-%d", rand.Intn(999))
		chatCompletionResponse := OpenAIChatCompletionResponse{
			ID:      id,
			Object:  "chat.completion",
			Created: chatResponse.CreatedAt.Unix(),
			Model:   req.Model,
			Choices: []OpenAIChatCompletionResponseChoice{
				{
					Index: 0,
					Message: OpenAIMessage{
						Role:    "assistant",
						Content: chatResponse.Message.Content,
					},
					FinishReason: func(done bool) *string {
						if done {
							reason := "stop"
							return &reason
						}
						return nil
					}(chatResponse.Done),
				},
			},
		}
		c.JSON(http.StatusOK, chatCompletionResponse)
		return
	}

	// Now, create the intermediate channel and transformation goroutine
	transformedCh := make(chan any)

	go func() {
		defer close(transformedCh)
		id := fmt.Sprintf("chatcmpl-%d", rand.Intn(999)) // TODO: validate that this does not change with each chunk
		predefinedResponse := OpenAIChatCompletionResponseStream{
			ID:      id,
			Object:  "chat.completion.chunk",
			Created: time.Now().Unix(),
			Model:   req.Model,
			Choices: []OpenAIChatCompletionResponseChoiceStream{
				{
					Index: 0,
					Delta: OpenAIMessage{
						Role: "assistant",
					},
				},
			},
		}
		transformedCh <- predefinedResponse
		for val := range ch {
			resp, ok := val.(api.ChatResponse)
			if !ok {
				// If val is not of type ChatResponse, send an error down the channel and exit
				transformedCh <- OpenAIErrorResponse{
					OpenAIError{
						Message: "failed to parse chat response",
						Type:    "internal_server_error",
					},
				}
				return
			}

			// Transform the ChatResponse into OpenAIChatCompletionResponse
			chatCompletionResponse := OpenAIChatCompletionResponseStream{
				ID:      id,
				Object:  "chat.completion.chunk",
				Created: resp.CreatedAt.Unix(),
				Model:   resp.Model,
				Choices: []OpenAIChatCompletionResponseChoiceStream{
					{
						Index: 0,
						FinishReason: func(done bool) *string {
							if done {
								reason := "stop"
								return &reason
							}
							return nil
						}(resp.Done),
					},
				},
			}
			if resp.Message != nil {
				chatCompletionResponse.Choices[0].Delta = OpenAIMessage{
					Content: resp.Message.Content,
				}
			}
			transformedCh <- chatCompletionResponse
			if resp.Done {
				transformedCh <- StreamCompletionMarker{}
			}
		}
	}()

	// Pass the transformed channel to streamResponse
	streamOpenAIResponse(c, transformedCh)
}

func streamOpenAIResponse(c *gin.Context, ch chan any) {
	c.Header("Content-Type", "text/event-stream")
	c.Stream(func(w io.Writer) bool {
		val, ok := <-ch
		if !ok {
			return false
		}

		// Check if the message is a StreamCompletionMarker to close the event stream
		if _, isCompletionMarker := val.(StreamCompletionMarker); isCompletionMarker {
			if _, err := w.Write([]byte("data: [DONE]\n")); err != nil {
				log.Printf("streamOpenAIResponse: w.Write failed with %s", err)
				return false
			}
			return false // Stop streaming after sending [DONE]
		}

		bts, err := json.Marshal(val)
		if err != nil {
			log.Printf("streamOpenAIResponse: json.Marshal failed with %s", err)
			return false
		}

		formattedResponse := fmt.Sprintf("data: %s\n", bts)

		if _, err := w.Write([]byte(formattedResponse)); err != nil {
			log.Printf("streamOpenAIResponse: w.Write failed with %s", err)
			return false
		}

		return true
	})
}
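
For reference, a minimal client sketch of the streaming path, separate from the change itself: it assumes the server listens on Ollama's default localhost:11434, that the /openai/chat/completions route registered in the hunk below is in place, and that "llama2" stands in for a model that has actually been pulled. It reads the text/event-stream written by streamOpenAIResponse, trims the "data: " prefix from each line, and stops at [DONE].

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

// chunk mirrors only the fields of the chat.completion.chunk objects this sketch needs.
type chunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

func main() {
	// "llama2" is a placeholder model name, not something this commit guarantees to exist.
	body := []byte(`{"model": "llama2", "stream": true, "messages": [{"role": "user", "content": "hello"}]}`)
	resp, err := http.Post("http://localhost:11434/openai/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		// Each event is written as a single "data: <json>" line; [DONE] terminates the stream.
		line := strings.TrimPrefix(scanner.Text(), "data: ")
		if line == "" {
			continue
		}
		if line == "[DONE]" {
			break
		}
		var ck chunk
		if err := json.Unmarshal([]byte(line), &ck); err != nil {
			panic(err)
		}
		if len(ck.Choices) > 0 {
			fmt.Print(ck.Choices[0].Delta.Content)
		}
	}
	fmt.Println()
}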

@@ -821,6 +821,9 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 	r.POST("/api/blobs/:digest", CreateBlobHandler)
 	r.HEAD("/api/blobs/:digest", HeadBlobHandler)
 
+	// openai compatible endpoints
+	r.POST("/openai/chat/completions", ChatCompletions)
+
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
 			c.String(http.StatusOK, "Ollama is running")
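
A corresponding non-streaming call against the route registered above, again only a sketch under the same assumptions (default localhost:11434 address, placeholder model name); it decodes just the message content from the chat.completion object.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Stream is set to false so the handler returns a single accumulated response.
	body := []byte(`{"model": "llama2", "stream": false, "messages": [{"role": "user", "content": "hello"}]}`)
	resp, err := http.Post("http://localhost:11434/openai/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode only the fields this example needs from the chat.completion object.
	var out struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	if len(out.Choices) > 0 {
		fmt.Println(out.Choices[0].Message.Content)
	}
}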

@@ -936,14 +939,13 @@ func ChatHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	model, err := load(c, req.Model, req.Options, sessionDuration)
+	ch, err := chat(c, req, checkpointStart)
 	if err != nil {
 		var pErr *fs.PathError
 		switch {
 		case errors.As(err, &pErr):
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
+		case errors.Is(err, api.ErrInvalidOpts), errors.Is(err, errInvalidRole):
 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		default:
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

@@ -951,18 +953,41 @@ func ChatHandler(c *gin.Context) {
 		return
 	}
 
-	// an empty request loads the model
-	if len(req.Messages) == 0 {
-		c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
+	if req.Stream != nil && !*req.Stream {
+		// Wait for the channel to close
+		var r api.ChatResponse
+		var sb strings.Builder
+		for resp := range ch {
+			var ok bool
+			if r, ok = resp.(api.ChatResponse); !ok {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+			if r.Message != nil {
+				sb.WriteString(r.Message.Content)
+			}
+		}
+		r.Message = &api.Message{Role: "assistant", Content: sb.String()}
+		c.JSON(http.StatusOK, r)
 		return
 	}
 
+	streamResponse(c, ch)
+}
+
+func chat(c *gin.Context, req api.ChatRequest, checkpointStart time.Time) (chan any, error) {
+	sessionDuration := defaultSessionDuration
+	model, err := load(c, req.Model, req.Options, sessionDuration)
+	if err != nil {
+		return nil, err
+	}
 
 	checkpointLoaded := time.Now()
 
 	prompt, err := model.ChatPrompt(req.Messages)
 	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
+		return nil, err
 	}
 
 	ch := make(chan any)

@@ -971,6 +996,11 @@ func ChatHandler(c *gin.Context) {
 		defer close(ch)
 
 		fn := func(r llm.PredictResult) {
+			// an empty request loads the model
+			if len(req.Messages) == 0 {
+				ch <- api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
+				return
+			}
 			// Update model expiration
 			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)

@@ -1009,24 +1039,5 @@ func ChatHandler(c *gin.Context) {
 		}
 	}()
 
-	if req.Stream != nil && !*req.Stream {
-		// Wait for the channel to close
-		var r api.ChatResponse
-		var sb strings.Builder
-		for resp := range ch {
-			var ok bool
-			if r, ok = resp.(api.ChatResponse); !ok {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-				return
-			}
-			if r.Message != nil {
-				sb.WriteString(r.Message.Content)
-			}
-		}
-		r.Message = &api.Message{Role: "assistant", Content: sb.String()}
-		c.JSON(http.StatusOK, r)
-		return
-	}
-
-	streamResponse(c, ch)
+	return ch, nil
 }