From e117483ef6f24e115d64a982bd339a1739d4eb06 Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Mon, 22 Apr 2024 09:30:19 -0400
Subject: [PATCH] Add `done_reason`

---
 api/types.go          |  8 ++++---
 llm/server.go         | 24 +++++++++++++++----
 server/prompt.go      | 18 ++++++++-------
 server/prompt_test.go |  2 +-
 server/routes.go      | 54 +++++++++++++++++++++++++++----------------
 5 files changed, 70 insertions(+), 36 deletions(-)

diff --git a/api/types.go b/api/types.go
index 7fe2b4e4..732bac06 100644
--- a/api/types.go
+++ b/api/types.go
@@ -98,7 +98,8 @@ type ChatResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Message   Message   `json:"message"`
 
-	Done bool `json:"done"`
+	Done       bool   `json:"done"`
+	DoneReason string `json:"done_reason,omitempty"`
 
 	Metrics
 }
@@ -265,8 +266,9 @@ type GenerateResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response"`
 
-	Done    bool  `json:"done"`
-	Context []int `json:"context,omitempty"`
+	Done       bool   `json:"done"`
+	DoneReason string `json:"done_reason,omitempty"`
+	Context    []int  `json:"context,omitempty"`
 
 	Metrics
 }
diff --git a/llm/server.go b/llm/server.go
index 3cab6f1d..ecc87860 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -509,10 +509,13 @@ type ImageData struct {
 }
 
 type completion struct {
-	Content string `json:"content"`
-	Model   string `json:"model"`
-	Prompt  string `json:"prompt"`
-	Stop    bool   `json:"stop"`
+	Content      string `json:"content"`
+	Model        string `json:"model"`
+	Prompt       string `json:"prompt"`
+	Stop         bool   `json:"stop"`
+	StoppedEos   bool   `json:"stopped_eos"`
+	StoppedWord  bool   `json:"stopped_word"`
+	StoppedLimit bool   `json:"stopped_limit"`
 
 	Timings struct {
 		PredictedN int `json:"predicted_n"`
@@ -532,6 +535,7 @@ type CompletionRequest struct {
 type CompletionResponse struct {
 	Content            string
 	Done               bool
+	DoneReason         string
 	PromptEvalCount    int
 	PromptEvalDuration time.Duration
 	EvalCount          int
@@ -648,6 +652,8 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 				return fmt.Errorf("error parsing llm response stream: %s", line)
 			}
+			fmt.Println("c", string(evt))
+
 			var c completion
 			if err := json.Unmarshal(evt, &c); err != nil {
 				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 			}
@@ -674,8 +680,18 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 			}
 
 			if c.Stop {
+				var doneReason string
+				switch {
+				case c.StoppedEos:
+					doneReason = "stop"
+				case c.StoppedWord:
+					doneReason = "stop"
+				case c.StoppedLimit:
+					doneReason = "limit"
+				}
 				fn(CompletionResponse{
 					Done:               true,
+					DoneReason:         doneReason,
 					PromptEvalCount:    c.Timings.PromptN,
 					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 					EvalCount:          c.Timings.PredictedN,
diff --git a/server/prompt.go b/server/prompt.go
index 604e6971..7d33418e 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
 }
 
 // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, int, error) {
 	type prompt struct {
 		System string
 		Prompt string
@@ -138,7 +138,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 			p.Response = msg.Content
 
 		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return "", 0, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
 		}
 	}
 
@@ -151,7 +151,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 	for i, p := range prompts {
 		tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 
 		prompts[i].tokens = tokens + len(prompts[i].images)*768
@@ -160,15 +160,17 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 	// truncate images and prompts starting from the beginning of the list
 	// until either one prompt remains or the total tokens fits the context window
 	// TODO (jmorganca): this doesn't account for the context window room required for the response
+	var required int
 	for {
-		var required int
+		required = 0
 		for _, p := range prompts {
 			required += p.tokens
 		}
 
 		required += 1 // for bos token
 
-		if required <= window {
+		// leave ~1024 tokens for generation
+		if required <= max(1024, window/2) {
 			slog.Debug("prompt now fits in context window", "required", required, "window", window)
 			break
 		}
@@ -194,7 +196,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 
 		tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 
 		prompts[0].tokens = tokens + len(prompts[0].images)*768
@@ -212,10 +214,10 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 		// last prompt should leave the response unrendered (for completion)
 		rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 		sb.WriteString(rendered)
 	}
 
-	return sb.String(), nil
+	return sb.String(), required, nil
 }
diff --git a/server/prompt_test.go b/server/prompt_test.go
index a7e18a70..7795954e 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
+			got, _, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
 			if err != nil {
 				t.Errorf("error = %v", err)
 			}
diff --git a/server/routes.go b/server/routes.go
index b0d36b14..955b7889 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -234,9 +234,10 @@ func GenerateHandler(c *gin.Context) {
 	// of `raw` mode so we need to check for it too
 	if req.Prompt == "" && req.Template == "" && req.System == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
 		})
 		return
 	}
@@ -289,6 +290,14 @@ func GenerateHandler(c *gin.Context) {
 		prompt = sb.String()
 	}
 
+	tokens, err := loaded.llama.Tokenize(c.Request.Context(), prompt)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts.NumPredict = max(opts.NumCtx-len(tokens), 0)
+
 	slog.Debug("generate handler", "prompt", prompt)
 
 	ch := make(chan any)
@@ -307,10 +316,11 @@ func GenerateHandler(c *gin.Context) {
 			}
 
 			resp := api.GenerateResponse{
-				Model:     req.Model,
-				CreatedAt: time.Now().UTC(),
-				Done:      r.Done,
-				Response:  r.Content,
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Done:       r.Done,
+				DoneReason: r.DoneReason,
+				Response:   r.Content,
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
@@ -1219,17 +1229,17 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, int, error) {
 	encode := func(s string) ([]int, error) {
 		return loaded.llama.Tokenize(ctx, s)
 	}
 
-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
+	prompt, tokens, err := ChatPrompt(template, messages, numCtx, encode)
 	if err != nil {
-		return "", err
+		return "", 0, err
 	}
 
-	return prompt, nil
+	return prompt, tokens, nil
 }
 
 func ChatHandler(c *gin.Context) {
@@ -1309,19 +1319,22 @@ func ChatHandler(c *gin.Context) {
 		}, req.Messages...)
 	}
 
-	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
+	prompt, tokens, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
+	opts.NumPredict = max(opts.NumCtx-tokens, 0)
+
 	// an empty request loads the model
 	if len(req.Messages) == 0 || prompt == "" {
 		resp := api.ChatResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
-			Message:   api.Message{Role: "assistant"},
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
+			Message:    api.Message{Role: "assistant"},
 		}
 		c.JSON(http.StatusOK, resp)
 		return
@@ -1356,10 +1369,11 @@ func ChatHandler(c *gin.Context) {
 			loaded.expireTimer.Reset(sessionDuration)
 
 			resp := api.ChatResponse{
-				Model:     req.Model,
-				CreatedAt: time.Now().UTC(),
-				Message:   api.Message{Role: "assistant", Content: r.Content},
-				Done:      r.Done,
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Message:    api.Message{Role: "assistant", Content: r.Content},
+				Done:       r.Done,
+				DoneReason: r.DoneReason,
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
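
Note on client usage: with this change, the final streamed object from /api/generate and /api/chat carries a done_reason string alongside done. The values assigned in this patch are "stop" (the runner reported stopped_eos or stopped_word), "limit" (stopped_limit, i.e. the NumPredict budget ran out), and "load" (an empty request that only loads the model). Since the field is tagged omitempty, clients that ignore it see no change. Below is a minimal client-side sketch of how a caller might read the new field; it assumes the api package's existing ClientFromEnvironment and Chat helpers (not touched by this patch) and an illustrative model name.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// ClientFromEnvironment and Chat already exist in the api package;
	// only the DoneReason field read below is new in this patch.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model:    "llama3", // illustrative model name
		Messages: []api.Message{{Role: "user", Content: "Why is the sky blue?"}},
	}

	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		if resp.Done {
			// "stop": EOS or a stop word; "limit": prediction limit hit;
			// "load": the request only loaded the model.
			fmt.Println("\ndone_reason:", resp.DoneReason)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}

The same field is available on api.GenerateResponse for callers of the generate endpoint.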