diff --git a/server/images.go b/server/images.go
index 68b2770a..ead2c014 100644
--- a/server/images.go
+++ b/server/images.go
@@ -49,6 +49,8 @@ type Model struct {
 	Options map[string]interface{}
 }
 
+var errInvalidRole = fmt.Errorf("invalid role")
+
 type PromptVars struct {
 	System string
 	Prompt string
@@ -119,7 +121,7 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
 				return "", err
 			}
 		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return "", fmt.Errorf("%w: %s, role must be one of [system, user, assistant]", errInvalidRole, msg.Role)
 		}
 	}
 
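Aside: wrapping with %w above is what lets callers classify this failure with errors.Is instead of matching on the message string. A minimal standalone sketch of the behavior (the role value "moderator" is made up for illustration, not part of the patch):

```go
package main

import (
	"errors"
	"fmt"
)

var errInvalidRole = fmt.Errorf("invalid role")

func main() {
	// fmt.Errorf with %w preserves the sentinel, so errors.Is still matches
	// after the role name and hint text are appended.
	err := fmt.Errorf("%w: %s, role must be one of [system, user, assistant]", errInvalidRole, "moderator")
	fmt.Println(errors.Is(err, errInvalidRole)) // true
	fmt.Println(err)
}
```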
diff --git a/server/openai.go b/server/openai.go
new file mode 100644
index 00000000..4fd8ca4b
--- /dev/null
+++ b/server/openai.go
@@ -0,0 +1,304 @@
+package server
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"io/fs"
+	"log"
+	"math/rand"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/jmorganca/ollama/api"
+)
+
+type OpenAIError struct {
+	Message string      `json:"message"`
+	Type    string      `json:"type"`
+	Param   interface{} `json:"param"`
+	Code    *string     `json:"code"`
+}
+
+type OpenAIErrorResponse struct {
+	Error OpenAIError `json:"error"`
+}
+
+type OpenAIChatCompletionRequest struct {
+	Model    string          `json:"model"`
+	Messages []OpenAIMessage `json:"messages"`
+	Stream   bool            `json:"stream"`
+}
+
+type OpenAIMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+func (m *OpenAIMessage) toMessage() api.Message {
+	return api.Message{
+		Role:    m.Role,
+		Content: m.Content,
+	}
+}
+
+// non-streaming response
+
+type OpenAIChatCompletionResponseChoice struct {
+	Index        int           `json:"index"`
+	Message      OpenAIMessage `json:"message"`
+	FinishReason *string       `json:"finish_reason"`
+}
+
+type OpenAIUsage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+type OpenAIChatCompletionResponse struct {
+	ID                string                               `json:"id"`
+	Object            string                               `json:"object"`
+	Created           int64                                `json:"created"`
+	Model             string                               `json:"model"`
+	SystemFingerprint string                               `json:"system_fingerprint"`
+	Choices           []OpenAIChatCompletionResponseChoice `json:"choices"`
+	Usage             OpenAIUsage                          `json:"usage,omitempty"`
+}
+
+// streaming response
+
+type OpenAIChatCompletionResponseChoiceStream struct {
+	Index        int           `json:"index"`
+	Delta        OpenAIMessage `json:"delta"`
+	FinishReason *string       `json:"finish_reason"`
+}
+
+type OpenAIChatCompletionResponseStream struct {
+	ID                string                                     `json:"id"`
+	Object            string                                     `json:"object"`
+	Created           int64                                      `json:"created"`
+	Model             string                                     `json:"model"`
+	SystemFingerprint string                                     `json:"system_fingerprint"`
+	Choices           []OpenAIChatCompletionResponseChoiceStream `json:"choices"`
+}
+
+type StreamCompletionMarker struct{} // signals to send [DONE] on the event-stream
+
+func ChatCompletions(c *gin.Context) {
+	var req OpenAIChatCompletionRequest
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
+			OpenAIError{
+				Message: "missing request body",
+				Type:    "invalid_request_error",
+			},
+		})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
+			OpenAIError{
+				Message: err.Error(),
+				Type:    "invalid_request_error",
+			},
+		})
+		return
+	}
+
+	// Translate the OpenAI request into an ollama chat request and get the response channel
+	chatReq := api.ChatRequest{
+		Model:  req.Model,
+		Stream: &req.Stream,
+	}
+	for _, m := range req.Messages {
+		chatReq.Messages = append(chatReq.Messages, m.toMessage())
+	}
+	ch, err := chat(c, chatReq, time.Now())
+	if err != nil {
+		var pErr *fs.PathError
+		switch {
+		case errors.As(err, &pErr):
+			c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
+				OpenAIError{
+					Message: fmt.Sprintf("model '%s' not found, try pulling it first", req.Model),
+					Type:    "invalid_request_error",
+				},
+			})
+		case errors.Is(err, api.ErrInvalidOpts), errors.Is(err, errInvalidRole):
+			c.AbortWithStatusJSON(http.StatusBadRequest, OpenAIErrorResponse{
+				OpenAIError{
+					Message: err.Error(),
+					Type:    "invalid_request_error",
+				},
+			})
+		default:
+			c.AbortWithStatusJSON(http.StatusInternalServerError, OpenAIErrorResponse{
+				OpenAIError{
+					Message: err.Error(),
+					Type:    "internal_server_error",
+				},
+			})
+		}
+		return
+	}
+
+	if !req.Stream {
+		// Wait for the channel to close, accumulating the full completion
+		var chatResponse api.ChatResponse
+		var sb strings.Builder
+
+		for val := range ch {
+			var ok bool
+			chatResponse, ok = val.(api.ChatResponse)
+			if !ok {
+				// err is nil at this point; report the type assertion failure itself
+				c.AbortWithStatusJSON(http.StatusInternalServerError, OpenAIErrorResponse{
+					OpenAIError{
+						Message: "failed to parse chat response",
+						Type:    "internal_server_error",
+					},
+				})
+				return
+			}
+			if chatResponse.Message != nil {
+				sb.WriteString(chatResponse.Message.Content)
+			}
+		}
+
+		// Send a single response with the accumulated content
+		id := fmt.Sprintf("chatcmpl-%d", rand.Intn(999))
+		chatCompletionResponse := OpenAIChatCompletionResponse{
+			ID:      id,
+			Object:  "chat.completion",
+			Created: chatResponse.CreatedAt.Unix(),
+			Model:   req.Model,
+			Choices: []OpenAIChatCompletionResponseChoice{
+				{
+					Index: 0,
+					Message: OpenAIMessage{
+						Role:    "assistant",
+						Content: sb.String(),
+					},
+					FinishReason: func(done bool) *string {
+						if done {
+							reason := "stop"
+							return &reason
+						}
+						return nil
+					}(chatResponse.Done),
+				},
+			},
+		}
+		c.JSON(http.StatusOK, chatCompletionResponse)
+		return
+	}
+
+	// Create the intermediate channel and transformation goroutine
+	transformedCh := make(chan any)
+
+	go func() {
+		defer close(transformedCh)
+		// the id is generated once and reused for every chunk in the stream
+		id := fmt.Sprintf("chatcmpl-%d", rand.Intn(999))
+		// the first chunk carries the assistant role, per the OpenAI stream format
+		predefinedResponse := OpenAIChatCompletionResponseStream{
+			ID:      id,
+			Object:  "chat.completion.chunk",
+			Created: time.Now().Unix(),
+			Model:   req.Model,
+			Choices: []OpenAIChatCompletionResponseChoiceStream{
+				{
+					Index: 0,
+					Delta: OpenAIMessage{
+						Role: "assistant",
+					},
+				},
+			},
+		}
+		transformedCh <- predefinedResponse
+		for val := range ch {
+			resp, ok := val.(api.ChatResponse)
+			if !ok {
+				// If val is not of type ChatResponse, send an error down the channel and exit
+				transformedCh <- OpenAIErrorResponse{
+					OpenAIError{
+						Message: "failed to parse chat response",
+						Type:    "internal_server_error",
+					},
+				}
+				return
+			}
+
+			// Transform the ChatResponse into an OpenAI chat completion chunk
+			chatCompletionResponse := OpenAIChatCompletionResponseStream{
+				ID:      id,
+				Object:  "chat.completion.chunk",
+				Created: resp.CreatedAt.Unix(),
+				Model:   resp.Model,
+				Choices: []OpenAIChatCompletionResponseChoiceStream{
+					{
+						Index: 0,
+						FinishReason: func(done bool) *string {
+							if done {
+								reason := "stop"
+								return &reason
+							}
+							return nil
+						}(resp.Done),
+					},
+				},
+			}
+			if resp.Message != nil {
+				chatCompletionResponse.Choices[0].Delta = OpenAIMessage{
+					Content: resp.Message.Content,
+				}
+			}
+			transformedCh <- chatCompletionResponse
+			if resp.Done {
+				transformedCh <- StreamCompletionMarker{}
+			}
+		}
+	}()
+
+	// Pass the transformed channel to streamOpenAIResponse
+	streamOpenAIResponse(c, transformedCh)
+}
+
+func streamOpenAIResponse(c *gin.Context, ch chan any) {
+	c.Header("Content-Type", "text/event-stream")
+	c.Stream(func(w io.Writer) bool {
+		val, ok := <-ch
+		if !ok {
+			return false
+		}
+
+		// A StreamCompletionMarker closes the event stream with [DONE];
+		// server-sent events are delimited by a blank line, hence the "\n\n"
+		if _, isCompletionMarker := val.(StreamCompletionMarker); isCompletionMarker {
+			if _, err := w.Write([]byte("data: [DONE]\n\n")); err != nil {
+				log.Printf("streamOpenAIResponse: w.Write failed with %s", err)
+			}
+			return false // Stop streaming after sending [DONE]
+		}
+
+		bts, err := json.Marshal(val)
+		if err != nil {
+			log.Printf("streamOpenAIResponse: json.Marshal failed with %s", err)
+			return false
+		}
+
+		formattedResponse := fmt.Sprintf("data: %s\n\n", bts)
+
+		if _, err := w.Write([]byte(formattedResponse)); err != nil {
+			log.Printf("streamOpenAIResponse: w.Write failed with %s", err)
+			return false
+		}
+
+		return true
+	})
+}
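A sketch of exercising the streaming path of the endpoint added above, assuming an ollama server on its default 127.0.0.1:11434 and an already-pulled model named "llama2" (both assumptions, not part of the patch):

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// stream=true exercises the SSE branch of ChatCompletions
	body := []byte(`{"model": "llama2", "stream": true,
		"messages": [{"role": "user", "content": "Why is the sky blue?"}]}`)

	resp, err := http.Post("http://127.0.0.1:11434/openai/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Each event arrives as a "data: {json}" line followed by a blank line;
	// the stream is terminated by "data: [DONE]".
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		if line := scanner.Text(); line != "" {
			fmt.Println(line)
		}
	}
}
```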
diff --git a/server/routes.go b/server/routes.go
index ea3c0692..5cea1b88 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -821,6 +821,9 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 	r.POST("/api/blobs/:digest", CreateBlobHandler)
 	r.HEAD("/api/blobs/:digest", HeadBlobHandler)
 
+	// openai compatible endpoints
+	r.POST("/openai/chat/completions", ChatCompletions)
+
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
 			c.String(http.StatusOK, "Ollama is running")
@@ -936,14 +939,13 @@ func ChatHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	model, err := load(c, req.Model, req.Options, sessionDuration)
+	ch, err := chat(c, req, checkpointStart)
 	if err != nil {
 		var pErr *fs.PathError
 		switch {
 		case errors.As(err, &pErr):
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
+		case errors.Is(err, api.ErrInvalidOpts), errors.Is(err, errInvalidRole):
 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		default:
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -951,18 +953,42 @@ func ChatHandler(c *gin.Context) {
 		return
 	}
 
-	// an empty request loads the model
-	if len(req.Messages) == 0 {
-		c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
+	if req.Stream != nil && !*req.Stream {
+		// Wait for the channel to close
+		var r api.ChatResponse
+		var sb strings.Builder
+		for resp := range ch {
+			var ok bool
+			if r, ok = resp.(api.ChatResponse); !ok {
+				// err is nil at this point; report the assertion failure itself
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to parse chat response"})
+				return
+			}
+			if r.Message != nil {
+				sb.WriteString(r.Message.Content)
+			}
+		}
+		r.Message = &api.Message{Role: "assistant", Content: sb.String()}
+		c.JSON(http.StatusOK, r)
 		return
 	}
 
+	streamResponse(c, ch)
+}
+
+func chat(c *gin.Context, req api.ChatRequest, checkpointStart time.Time) (chan any, error) {
+	sessionDuration := defaultSessionDuration
+	model, err := load(c, req.Model, req.Options, sessionDuration)
+	if err != nil {
+		return nil, err
+	}
+
 	checkpointLoaded := time.Now()
 
 	prompt, err := model.ChatPrompt(req.Messages)
 	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
+		// let each caller render its own error response
+		return nil, err
 	}
 
 	ch := make(chan any)
@@ -971,6 +997,11 @@ func ChatHandler(c *gin.Context) {
 		defer close(ch)
 
+		// an empty request loads the model, so respond without predicting
+		if len(req.Messages) == 0 {
+			ch <- api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
+			return
+		}
 		fn := func(r llm.PredictResult) {
 			// Update model expiration
 			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)
@@ -1009,24 +1040,5 @@ func ChatHandler(c *gin.Context) {
 		}
 	}()
 
-	if req.Stream != nil && !*req.Stream {
-		// Wait for the channel to close
-		var r api.ChatResponse
-		var sb strings.Builder
-		for resp := range ch {
-			var ok bool
-			if r, ok = resp.(api.ChatResponse); !ok {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-				return
-			}
-			if r.Message != nil {
-				sb.WriteString(r.Message.Content)
-			}
-		}
-		r.Message = &api.Message{Role: "assistant", Content: sb.String()}
-		c.JSON(http.StatusOK, r)
-		return
-	}
-
-	streamResponse(c, ch)
+	return ch, nil
 }
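For the non-streaming path, the handler returns a single chat.completion object; a sketch of decoding it into a trimmed-down shape (same server and model assumptions as the streaming example above):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// chatCompletion mirrors just the fields of OpenAIChatCompletionResponse read below
type chatCompletion struct {
	ID      string `json:"id"`
	Choices []struct {
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason *string `json:"finish_reason"`
	} `json:"choices"`
}

func main() {
	body := []byte(`{"model": "llama2", "stream": false,
		"messages": [{"role": "user", "content": "Why is the sky blue?"}]}`)

	resp, err := http.Post("http://127.0.0.1:11434/openai/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var completion chatCompletion
	if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil {
		panic(err)
	}
	if len(completion.Choices) > 0 {
		fmt.Println(completion.Choices[0].Message.Content)
	}
}
```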