Compare commits

...

4 Commits

Author SHA1 Message Date
Roy Han
f16b3db70c oai compat 2024-07-30 11:29:44 -07:00
Roy Han
23ff673bdc correct output 2024-07-29 17:12:39 -07:00
Roy Han
7950053972 rm comments 2024-07-29 17:02:03 -07:00
Roy Han
d2b25c1bfb draft 2024-07-29 16:59:02 -07:00
2 changed files with 32 additions and 8 deletions

View File

@ -192,9 +192,9 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b))
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
for i, tc := range r.Message.ToolCalls {
func parseToolCalls(respToolCalls []api.ToolCall) []ToolCall {
toolCalls := make([]ToolCall, len(respToolCalls))
for i, tc := range respToolCalls {
toolCalls[i].ID = toolCallId()
toolCalls[i].Type = "function"
toolCalls[i].Function.Name = tc.Function.Name
@ -207,6 +207,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls[i].Function.Arguments = string(args)
}
return toolCalls
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := parseToolCalls(r.Message.ToolCalls)
return ChatCompletion{
Id: id,
@ -218,9 +223,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
Index: 0,
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
FinishReason: func(reason string) *string {
if len(toolCalls) > 0 {
reason = "tool_calls"
}
if len(reason) > 0 {
return &reason
}
@ -236,6 +238,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
toolCalls := parseToolCalls(r.Message.ToolCalls)
return ChatCompletionChunk{
Id: id,
Object: "chat.completion.chunk",
@ -244,7 +248,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
SystemFingerprint: "fp_ollama",
Choices: []ChunkChoice{{
Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content},
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
FinishReason: func(reason string) *string {
if len(reason) > 0 {
return &reason

View File

@ -1369,7 +1369,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}()
if req.Stream != nil && !*req.Stream {
if (req.Stream != nil && !*req.Stream) || ((req.Stream == nil || *req.Stream) && len(req.Tools) > 0) {
var resp api.ChatResponse
var sb strings.Builder
for rr := range ch {
@ -1400,6 +1400,26 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}
if (req.Stream == nil || *req.Stream) && len(resp.Message.ToolCalls) > 0 {
toolCh := make(chan any)
go func() {
defer close(toolCh)
toolCalls := resp.Message.ToolCalls
for _, toolCall := range toolCalls {
toolCh <- api.ChatResponse{
Model: resp.Model,
CreatedAt: resp.CreatedAt,
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{toolCall}},
}
}
resp.Message.ToolCalls = nil
resp.DoneReason = "tool_calls"
toolCh <- resp
}()
streamResponse(c, toolCh)
return
}
c.JSON(http.StatusOK, resp)
return
}