Add done_reason

jmorganca 2024-04-22 09:30:19 -04:00
parent 62be2050dd
commit e117483ef6
5 changed files with 70 additions and 36 deletions

View File

@@ -98,7 +98,8 @@ type ChatResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Message Message `json:"message"`
-	Done bool `json:"done"`
+	Done       bool   `json:"done"`
+	DoneReason string `json:"done_reason,omitempty"`
 	Metrics
 }
@@ -265,8 +266,9 @@ type GenerateResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Response string `json:"response"`
-	Done bool `json:"done"`
-	Context []int `json:"context,omitempty"`
+	Done       bool   `json:"done"`
+	DoneReason string `json:"done_reason,omitempty"`
+	Context    []int  `json:"context,omitempty"`
 	Metrics
 }
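Taken together, both API response types now report why generation finished, not just that it finished. For context, a minimal sketch (not part of this commit) of how a client might read the new field from a non-streaming generate response; the generateResponse type and the sample JSON below are illustrative and only mirror the fields added above:

package main

import (
	"encoding/json"
	"fmt"
)

// illustrative client-side mirror of the fields added above
type generateResponse struct {
	Response   string `json:"response"`
	Done       bool   `json:"done"`
	DoneReason string `json:"done_reason,omitempty"`
}

func main() {
	body := []byte(`{"response":"hello","done":true,"done_reason":"stop"}`)

	var r generateResponse
	if err := json.Unmarshal(body, &r); err != nil {
		panic(err)
	}
	// done_reason is "stop" for a normal finish, "limit" when the token
	// budget ran out, or "load" for an empty request that only loads a model
	fmt.Println(r.Done, r.DoneReason)
}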

View File

@@ -509,10 +509,13 @@ type ImageData struct {
 }
 type completion struct {
-	Content string `json:"content"`
-	Model string `json:"model"`
-	Prompt string `json:"prompt"`
-	Stop bool `json:"stop"`
+	Content      string `json:"content"`
+	Model        string `json:"model"`
+	Prompt       string `json:"prompt"`
+	Stop         bool   `json:"stop"`
+	StoppedEos   bool   `json:"stopped_eos"`
+	StoppedWord  bool   `json:"stopped_word"`
+	StoppedLimit bool   `json:"stopped_limit"`
 	Timings struct {
 		PredictedN int `json:"predicted_n"`
@@ -532,6 +535,7 @@ type CompletionRequest struct {
 type CompletionResponse struct {
 	Content string
 	Done bool
+	DoneReason string
 	PromptEvalCount int
 	PromptEvalDuration time.Duration
 	EvalCount int
@@ -648,6 +652,8 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 			return fmt.Errorf("error parsing llm response stream: %s", line)
 		}
+		fmt.Println("c", string(evt))
 		var c completion
 		if err := json.Unmarshal(evt, &c); err != nil {
 			return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
@@ -674,8 +680,18 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 		}
 		if c.Stop {
+			var doneReason string
+			switch {
+			case c.StoppedEos:
+				doneReason = "stop"
+			case c.StoppedWord:
+				doneReason = "stop"
+			case c.StoppedLimit:
+				doneReason = "limit"
+			}
 			fn(CompletionResponse{
 				Done: true,
+				DoneReason: doneReason,
 				PromptEvalCount: c.Timings.PromptN,
 				PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 				EvalCount: c.Timings.PredictedN,
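The switch above folds llama.cpp's three stop flags into a single reason string. A condensed, self-contained sketch of that mapping (the trimmed completion type and the helper name doneReason are illustrative, not part of this commit):

package main

import "fmt"

// trimmed-down copy of the completion fields used above
type completion struct {
	StoppedEos   bool // hit the end-of-sequence token
	StoppedWord  bool // hit a user-supplied stop sequence
	StoppedLimit bool // hit the prediction limit
}

// doneReason mirrors the switch in Completion: eos and stop-word both
// report "stop"; running out of token budget reports "limit"
func doneReason(c completion) string {
	switch {
	case c.StoppedEos, c.StoppedWord:
		return "stop"
	case c.StoppedLimit:
		return "limit"
	default:
		return ""
	}
}

func main() {
	fmt.Println(doneReason(completion{StoppedEos: true}))   // stop
	fmt.Println(doneReason(completion{StoppedLimit: true})) // limit
}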

View File

@@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
 }
 // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, int, error) {
 	type prompt struct {
 		System string
 		Prompt string
@@ -138,7 +138,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 			p.Response = msg.Content
 		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return "", 0, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
 		}
 	}
@@ -151,7 +151,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 	for i, p := range prompts {
 		tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 		prompts[i].tokens = tokens + len(prompts[i].images)*768
@@ -160,15 +160,17 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 	// truncate images and prompts starting from the beginning of the list
 	// until either one prompt remains or the total tokens fits the context window
 	// TODO (jmorganca): this doesn't account for the context window room required for the response
+	var required int
 	for {
-		var required int
+		required = 0
 		for _, p := range prompts {
 			required += p.tokens
 		}
 		required += 1 // for bos token
-		if required <= window {
+		// leave ~1024 tokens for generation
+		if required <= max(1024, window/2) {
 			slog.Debug("prompt now fits in context window", "required", required, "window", window)
 			break
 		}
@@ -194,7 +196,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 		tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 		prompts[0].tokens = tokens + len(prompts[0].images)*768
@@ -212,10 +214,10 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 		// last prompt should leave the response unrendered (for completion)
 		rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 		sb.WriteString(rendered)
 	}
-	return sb.String(), nil
+	return sb.String(), required, nil
 }
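ChatPrompt now also returns how many tokens the rendered prompt uses, and the truncation loop stops once the prompt fits within max(1024, window/2) rather than the full window. A small arithmetic sketch of that budget (numbers are illustrative; the built-in max needs Go 1.21+):

package main

import "fmt"

// fits mirrors the new stopping condition: keep dropping old prompts until
// the total is within half the window (but never less than 1024 tokens),
// leaving the rest of the window for the response
func fits(required, window int) bool {
	return required <= max(1024, window/2)
}

func main() {
	fmt.Println(fits(3000, 4096)) // false: 3000 > 2048, keep truncating
	fmt.Println(fits(1800, 4096)) // true:  1800 <= 2048
	fmt.Println(fits(1100, 1500)) // false: the floor is 1024 and 1100 > 1024
}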

View File

@@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
+			got, _, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
 			if err != nil {
 				t.Errorf("error = %v", err)
 			}

View File

@@ -234,9 +234,10 @@ func GenerateHandler(c *gin.Context) {
 	// of `raw` mode so we need to check for it too
 	if req.Prompt == "" && req.Template == "" && req.System == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt: time.Now().UTC(),
-			Model: req.Model,
-			Done: true,
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
 		})
 		return
 	}
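With this change, an empty generate request (no prompt, template, or system) still just loads the model, but the response now says why it finished immediately. Roughly, assuming the usual /api/generate endpoint; the model name and timestamp are illustrative:

{
  "model": "llama3",
  "created_at": "2024-04-22T13:30:19Z",
  "response": "",
  "done": true,
  "done_reason": "load"
}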
@@ -289,6 +290,14 @@ func GenerateHandler(c *gin.Context) {
 		prompt = sb.String()
 	}
+	tokens, err := loaded.llama.Tokenize(c.Request.Context(), prompt)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	opts.NumPredict = max(opts.NumCtx-len(tokens), 0)
 	slog.Debug("generate handler", "prompt", prompt)
 	ch := make(chan any)
@@ -307,10 +316,11 @@ func GenerateHandler(c *gin.Context) {
 			}
 			resp := api.GenerateResponse{
-				Model: req.Model,
-				CreatedAt: time.Now().UTC(),
-				Done: r.Done,
-				Response: r.Content,
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Done:       r.Done,
+				DoneReason: r.DoneReason,
+				Response:   r.Content,
 				Metrics: api.Metrics{
 					PromptEvalCount: r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
@@ -1219,17 +1229,17 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, int, error) {
 	encode := func(s string) ([]int, error) {
 		return loaded.llama.Tokenize(ctx, s)
 	}
-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
+	prompt, tokens, err := ChatPrompt(template, messages, numCtx, encode)
 	if err != nil {
-		return "", err
+		return "", 0, err
 	}
-	return prompt, nil
+	return prompt, tokens, nil
 }
 func ChatHandler(c *gin.Context) {
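Both handlers now use the prompt's token count to cap how much the model may generate: num_predict becomes whatever is left of the context window, floored at zero. A back-of-the-envelope sketch (the helper name and the numbers are illustrative; the built-in max needs Go 1.21+):

package main

import "fmt"

// numPredict mirrors opts.NumPredict = max(opts.NumCtx-tokens, 0)
func numPredict(numCtx, promptTokens int) int {
	return max(numCtx-promptTokens, 0)
}

func main() {
	fmt.Println(numPredict(2048, 700))  // 1348 tokens left for the reply
	fmt.Println(numPredict(2048, 2300)) // 0: the prompt already fills the window
}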
@@ -1309,19 +1319,22 @@ func ChatHandler(c *gin.Context) {
 		}, req.Messages...)
 	}
-	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
+	prompt, tokens, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
+	opts.NumPredict = max(opts.NumCtx-tokens, 0)
 	// an empty request loads the model
 	if len(req.Messages) == 0 || prompt == "" {
 		resp := api.ChatResponse{
-			CreatedAt: time.Now().UTC(),
-			Model: req.Model,
-			Done: true,
-			Message: api.Message{Role: "assistant"},
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
+			Message:    api.Message{Role: "assistant"},
 		}
 		c.JSON(http.StatusOK, resp)
 		return
@@ -1356,10 +1369,11 @@ func ChatHandler(c *gin.Context) {
 			loaded.expireTimer.Reset(sessionDuration)
 			resp := api.ChatResponse{
-				Model: req.Model,
-				CreatedAt: time.Now().UTC(),
-				Message: api.Message{Role: "assistant", Content: r.Content},
-				Done: r.Done,
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Message:    api.Message{Role: "assistant", Content: r.Content},
+				Done:       r.Done,
+				DoneReason: r.DoneReason,
 				Metrics: api.Metrics{
 					PromptEvalCount: r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,