Merge 79379daab4075289e441961ee912d5393c290ce1 into 67691e410db7a50b07a64858820b14de9aa91314
This commit is contained in:
commit
543e239a0c
@ -15,19 +15,19 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
messages := []api.Message{
|
messages := []api.Message{
|
||||||
api.Message{
|
{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
Content: "Provide very brief, concise responses",
|
Content: "Provide very brief, concise responses",
|
||||||
},
|
},
|
||||||
api.Message{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: "Name some unusual animals",
|
Content: "Name some unusual animals",
|
||||||
},
|
},
|
||||||
api.Message{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: "Monotreme, platypus, echidna",
|
Content: "Monotreme, platypus, echidna",
|
||||||
},
|
},
|
||||||
api.Message{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: "which of these is the most dangerous?",
|
Content: "which of these is the most dangerous?",
|
||||||
},
|
},
|
||||||
|
@ -32,8 +32,8 @@ type ErrorResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role,omitempty"`
|
||||||
Content any `json:"content"`
|
Content any `json:"content,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ type Choice struct {
|
|||||||
|
|
||||||
type ChunkChoice struct {
|
type ChunkChoice struct {
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
Delta Message `json:"delta"`
|
Delta Message `json:"delta,omitempty"`
|
||||||
FinishReason *string `json:"finish_reason"`
|
FinishReason *string `json:"finish_reason"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,6 +139,8 @@ type CompletionChunk struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
type ToolCall struct {
|
||||||
|
// Index is non-nil only in chat completion chunk objects.
|
||||||
|
Index *int `json:"index,omitempty"`
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Function struct {
|
Function struct {
|
||||||
@ -244,6 +246,28 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
||||||
|
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
|
||||||
|
for i, tc := range r.Message.ToolCalls {
|
||||||
|
idx := i
|
||||||
|
toolCalls[i].Index = &idx
|
||||||
|
toolCalls[i].ID = toolCallId()
|
||||||
|
toolCalls[i].Type = "function"
|
||||||
|
toolCalls[i].Function.Name = tc.Function.Name
|
||||||
|
|
||||||
|
args, err := json.Marshal(tc.Function.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("could not marshall function arguments to json", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls[i].Function.Arguments = string(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
message := Message{Role: "assistant", Content: r.Message.Content}
|
||||||
|
hasToolCalls := len(toolCalls) > 0
|
||||||
|
if hasToolCalls {
|
||||||
|
message = Message{ToolCalls: toolCalls}
|
||||||
|
}
|
||||||
return ChatCompletionChunk{
|
return ChatCompletionChunk{
|
||||||
Id: id,
|
Id: id,
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
@ -252,8 +276,12 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
|||||||
SystemFingerprint: "fp_ollama",
|
SystemFingerprint: "fp_ollama",
|
||||||
Choices: []ChunkChoice{{
|
Choices: []ChunkChoice{{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Delta: Message{Role: "assistant", Content: r.Message.Content},
|
// Delta: Message{Role: "assistant", Content: r.Message.Content},
|
||||||
|
Delta: message,
|
||||||
FinishReason: func(reason string) *string {
|
FinishReason: func(reason string) *string {
|
||||||
|
// if hasToolCalls {
|
||||||
|
// reason = "tool_calls"
|
||||||
|
// }
|
||||||
if len(reason) > 0 {
|
if len(reason) > 0 {
|
||||||
return &reason
|
return &reason
|
||||||
}
|
}
|
||||||
@ -610,6 +638,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|||||||
if chatResponse.Done {
|
if chatResponse.Done {
|
||||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
slog.Error("writeResponse done", "err", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1456,9 +1456,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
|
|
||||||
slog.Debug("chat request", "images", len(images), "prompt", prompt)
|
slog.Debug("chat request", "images", len(images), "prompt", prompt)
|
||||||
|
|
||||||
|
toolCallsCh := make(chan []api.ToolCall, 1)
|
||||||
|
contentCh := make(chan string, 1)
|
||||||
|
|
||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
|
var sb strings.Builder
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
|
defer close(toolCallsCh)
|
||||||
|
defer close(contentCh)
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
@ -1478,8 +1484,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: r.EvalDuration,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
sb.WriteString(r.Content)
|
||||||
|
|
||||||
if r.Done {
|
if r.Done {
|
||||||
|
content := sb.String()
|
||||||
|
contentCh <- content
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
if toolCalls, ok := m.parseToolCalls(content); ok {
|
||||||
|
toolCallsCh <- toolCalls
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
res.TotalDuration = time.Since(checkpointStart)
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
}
|
}
|
||||||
@ -1490,13 +1505,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
toolsRequired := len(req.Tools) > 0
|
||||||
|
// no stream response
|
||||||
if req.Stream != nil && !*req.Stream {
|
if req.Stream != nil && !*req.Stream {
|
||||||
var resp api.ChatResponse
|
var resp api.ChatResponse
|
||||||
var sb strings.Builder
|
|
||||||
for rr := range ch {
|
for rr := range ch {
|
||||||
switch t := rr.(type) {
|
switch t := rr.(type) {
|
||||||
case api.ChatResponse:
|
case api.ChatResponse:
|
||||||
sb.WriteString(t.Message.Content)
|
|
||||||
resp = t
|
resp = t
|
||||||
case gin.H:
|
case gin.H:
|
||||||
msg, ok := t["error"].(string)
|
msg, ok := t["error"].(string)
|
||||||
@ -1512,10 +1527,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Message.Content = sb.String()
|
content := <-contentCh
|
||||||
|
resp.Message.Content = content
|
||||||
if len(req.Tools) > 0 {
|
if toolsRequired {
|
||||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
toolCalls := <-toolCallsCh
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
resp.Message.ToolCalls = toolCalls
|
resp.Message.ToolCalls = toolCalls
|
||||||
resp.Message.Content = ""
|
resp.Message.Content = ""
|
||||||
}
|
}
|
||||||
@ -1525,7 +1541,59 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
streamResponse(c, ch)
|
// stream response
|
||||||
|
streamCh := make(chan any)
|
||||||
|
for rr := range ch {
|
||||||
|
switch t := rr.(type) {
|
||||||
|
case api.ChatResponse:
|
||||||
|
go func() {
|
||||||
|
// slog.Warn("reassign chat response", "content", t.Message.Content)
|
||||||
|
streamCh <- t
|
||||||
|
if t.Done {
|
||||||
|
// slog.Warn("close stream channel")
|
||||||
|
close(streamCh)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
case gin.H:
|
||||||
|
msg, ok := t["error"].(string)
|
||||||
|
if !ok {
|
||||||
|
msg = "unexpected error format in response"
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if tools are required
|
||||||
|
if toolsRequired {
|
||||||
|
toolCalls := <-toolCallsCh
|
||||||
|
// if tool calls are present, use a different response channel
|
||||||
|
hasToolCalls := len(toolCalls) > 0
|
||||||
|
if hasToolCalls {
|
||||||
|
// shadow toolCallsCh with a new buffered channel that carries the single tool-call response
|
||||||
|
toolCallsCh := make(chan any, 1)
|
||||||
|
res := api.ChatResponse{
|
||||||
|
Model: req.Model,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
Message: api.Message{Role: "assistant", ToolCalls: toolCalls},
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "tool_calls",
|
||||||
|
}
|
||||||
|
toolCallsCh <- res
|
||||||
|
close(toolCallsCh)
|
||||||
|
slog.Info("[tools] stream response")
|
||||||
|
streamResponse(c, toolCallsCh)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
slog.Info("[tools] no call")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("stream response")
|
||||||
|
streamResponse(c, streamCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleScheduleError(c *gin.Context, name string, err error) {
|
func handleScheduleError(c *gin.Context, name string, err error) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user