Merge 79379daab4075289e441961ee912d5393c290ce1 into 67691e410db7a50b07a64858820b14de9aa91314

venjiang 2024-11-14 16:58:33 +08:00 committed by GitHub
commit 543e239a0c
3 changed files with 112 additions and 15 deletions

View File

@@ -15,19 +15,19 @@ func main() {
 	}

 	messages := []api.Message{
-		api.Message{
+		{
 			Role:    "system",
 			Content: "Provide very brief, concise responses",
 		},
-		api.Message{
+		{
 			Role:    "user",
 			Content: "Name some unusual animals",
 		},
-		api.Message{
+		{
 			Role:    "assistant",
 			Content: "Monotreme, platypus, echidna",
 		},
-		api.Message{
+		{
 			Role:    "user",
 			Content: "which of these is the most dangerous?",
 		},
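A note on the hunk above: dropping the repeated api.Message element type inside the []api.Message literal is the composite-literal simplification that gofmt -s applies; both spellings build the same slice. A minimal, self-contained sketch, with a local Message type standing in for api.Message:

package main

import "fmt"

// Message is a stand-in for api.Message, for illustration only.
type Message struct {
	Role    string
	Content string
}

func main() {
	// In a []Message composite literal the element type may be elided;
	// both forms below produce identical values.
	explicit := []Message{Message{Role: "user", Content: "hi"}}
	elided := []Message{{Role: "user", Content: "hi"}}
	fmt.Println(explicit[0] == elided[0]) // true
}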

View File

@@ -32,8 +32,8 @@ type ErrorResponse struct {
 }

 type Message struct {
-	Role      string     `json:"role"`
-	Content   any        `json:"content"`
+	Role      string     `json:"role,omitempty"`
+	Content   any        `json:"content,omitempty"`
 	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }
@@ -45,7 +45,7 @@ type Choice struct {
 type ChunkChoice struct {
 	Index        int     `json:"index"`
-	Delta        Message `json:"delta"`
+	Delta        Message `json:"delta,omitempty"`
 	FinishReason *string `json:"finish_reason"`
 }
@@ -139,6 +139,8 @@ type CompletionChunk struct {
 }

 type ToolCall struct {
+	// Index is not nil only in chat completion chunk object
+	Index    *int   `json:"index,omitempty"`
 	ID       string `json:"id"`
 	Type     string `json:"type"`
 	Function struct {
@@ -244,6 +246,28 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 }

 func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
+	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
+	for i, tc := range r.Message.ToolCalls {
+		idx := i
+		toolCalls[i].Index = &idx
+		toolCalls[i].ID = toolCallId()
+		toolCalls[i].Type = "function"
+		toolCalls[i].Function.Name = tc.Function.Name
+
+		args, err := json.Marshal(tc.Function.Arguments)
+		if err != nil {
+			slog.Error("could not marshall function arguments to json", "error", err)
+			continue
+		}
+		toolCalls[i].Function.Arguments = string(args)
+	}
+
+	message := Message{Role: "assistant", Content: r.Message.Content}
+	hasToolCalls := len(toolCalls) > 0
+	if hasToolCalls {
+		message = Message{ToolCalls: toolCalls}
+	}
+
 	return ChatCompletionChunk{
 		Id:                id,
 		Object:            "chat.completion.chunk",
@@ -252,8 +276,12 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 		SystemFingerprint: "fp_ollama",
 		Choices: []ChunkChoice{{
 			Index: 0,
-			Delta: Message{Role: "assistant", Content: r.Message.Content},
+			// Delta: Message{Role: "assistant", Content: r.Message.Content},
+			Delta: message,
 			FinishReason: func(reason string) *string {
+				// if hasToolCalls {
+				// 	reason = "tool_calls"
+				// }
 				if len(reason) > 0 {
 					return &reason
 				}
@@ -610,6 +638,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 	if chatResponse.Done {
 		_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 		if err != nil {
+			slog.Error("writeResponse done", "err", err)
 			return 0, err
 		}
 	}
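To illustrate what the openai.go changes above buy in the streaming path, here is a small, self-contained sketch using trimmed-down copies of the structs from this diff (the call ID and function name are invented for the example): with omitempty, a chunk delta that carries only tool calls no longer serializes "role":"" and "content":null, and the new Index field gives clients a stable position to accumulate argument fragments against.

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed-down copies of the structs from the diff above, for illustration only.
type ToolCall struct {
	// Index is only set for chat completion chunk objects.
	Index    *int   `json:"index,omitempty"`
	ID       string `json:"id"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

type Message struct {
	Role      string     `json:"role,omitempty"`
	Content   any        `json:"content,omitempty"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

func main() {
	idx := 0
	tc := ToolCall{Index: &idx, ID: "call_abc123", Type: "function"}
	tc.Function.Name = "get_weather"
	tc.Function.Arguments = `{"city":"Paris"}`

	// A delta carrying only tool calls: the empty role and nil content are
	// omitted rather than being emitted as "" and null.
	b, _ := json.Marshal(Message{ToolCalls: []ToolCall{tc}})
	fmt.Println(string(b))
}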

View File

@@ -1456,9 +1456,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	slog.Debug("chat request", "images", len(images), "prompt", prompt)

+	toolCallsCh := make(chan []api.ToolCall, 1)
+	contentCh := make(chan string, 1)
+
 	ch := make(chan any)
 	go func() {
+		var sb strings.Builder
 		defer close(ch)
+		defer close(toolCallsCh)
+		defer close(contentCh)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
@@ -1478,8 +1484,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					EvalDuration:       r.EvalDuration,
 				},
 			}

+			sb.WriteString(r.Content)
 			if r.Done {
+				content := sb.String()
+				contentCh <- content
+				if len(req.Tools) > 0 {
+					if toolCalls, ok := m.parseToolCalls(content); ok {
+						toolCallsCh <- toolCalls
+					}
+				}
+
 				res.TotalDuration = time.Since(checkpointStart)
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
@@ -1490,13 +1505,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 	}()

+	toolsRequired := len(req.Tools) > 0
+	// no stream response
 	if req.Stream != nil && !*req.Stream {
 		var resp api.ChatResponse
-		var sb strings.Builder
 		for rr := range ch {
 			switch t := rr.(type) {
 			case api.ChatResponse:
-				sb.WriteString(t.Message.Content)
 				resp = t
 			case gin.H:
 				msg, ok := t["error"].(string)
@@ -1512,10 +1527,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 		}

-		resp.Message.Content = sb.String()
+		content := <-contentCh
+		resp.Message.Content = content

-		if len(req.Tools) > 0 {
-			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+		if toolsRequired {
+			toolCalls := <-toolCallsCh
+			if len(toolCalls) > 0 {
 				resp.Message.ToolCalls = toolCalls
 				resp.Message.Content = ""
 			}
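One detail worth calling out in the non-stream path above: when no tool calls were parsed, nothing is ever sent on toolCallsCh, but the producer goroutine closes it via defer before ch is closed, so toolCalls := <-toolCallsCh does not block and simply yields nil, which leaves len(toolCalls) > 0 false and the text content in place. A tiny sketch of the channel property this relies on (the channel name here is illustrative):

package main

import "fmt"

func main() {
	// Receiving from a closed channel never blocks; it returns the element
	// type's zero value, with ok reporting false.
	toolCallsCh := make(chan []string, 1)
	close(toolCallsCh)

	toolCalls, ok := <-toolCallsCh
	fmt.Println(toolCalls == nil, ok) // true false
}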
@@ -1525,7 +1541,59 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}

-	streamResponse(c, ch)
+	// stream response
+	streamCh := make(chan any)
+	for rr := range ch {
+		switch t := rr.(type) {
+		case api.ChatResponse:
+			go func() {
+				// slog.Warn("reassign chat response", "content", t.Message.Content)
+				streamCh <- t
+				if t.Done {
+					// slog.Warn("close stream channel")
+					close(streamCh)
+				}
+			}()
+		case gin.H:
+			msg, ok := t["error"].(string)
+			if !ok {
+				msg = "unexpected error format in response"
+			}
+			c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+			return
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
+			return
+		}
+	}
+
+	// if tools are required
+	if toolsRequired {
+		toolCalls := <-toolCallsCh
+		// if tool calls are present, use a different channel response
+		hasToolCalls := len(toolCalls) > 0
+		if hasToolCalls {
+			// reset the channel
+			toolCallsCh := make(chan any, 1)
+			res := api.ChatResponse{
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Message:    api.Message{Role: "assistant", ToolCalls: toolCalls},
+				Done:       true,
+				DoneReason: "tool_calls",
+			}
+			toolCallsCh <- res
+			close(toolCallsCh)
+			slog.Info("[tools] stream response")
+			streamResponse(c, toolCallsCh)
+			return
+		} else {
+			slog.Info("[tools] no call")
+		}
+	}
+
+	slog.Info("stream response")
+	streamResponse(c, streamCh)
 }

 func handleScheduleError(c *gin.Context, name string, err error) {
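For readers following the channel choreography in ChatHandler, here is a minimal, self-contained sketch of the handoff pattern the diff relies on (names and tokens are illustrative, not the server's actual types): the producer streams partial content on an unbuffered channel while publishing the aggregated result on a size-1 buffered channel, so its final send cannot block even though the consumer only reads it after draining the stream.

package main

import (
	"fmt"
	"strings"
)

func main() {
	ch := make(chan string)           // per-chunk stream, like ch in ChatHandler
	contentCh := make(chan string, 1) // aggregated result; buffered so the final send never blocks

	go func() {
		defer close(ch)
		defer close(contentCh)
		var sb strings.Builder
		for _, tok := range []string{"plat", "ypus"} {
			sb.WriteString(tok)
			ch <- tok
		}
		contentCh <- sb.String()
	}()

	// Drain the stream first, then pick up the aggregated result.
	for tok := range ch {
		fmt.Println("chunk:", tok)
	}
	fmt.Println("final:", <-contentCh) // final: platypus
}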