Add done_reason
parent 62be2050dd
commit e117483ef6
@@ -98,7 +98,8 @@ type ChatResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Message   Message   `json:"message"`

-	Done bool `json:"done"`
+	Done       bool   `json:"done"`
+	DoneReason string `json:"done_reason,omitempty"`

 	Metrics
 }
@@ -265,8 +266,9 @@ type GenerateResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response"`

-	Done    bool  `json:"done"`
-	Context []int `json:"context,omitempty"`
+	Done       bool   `json:"done"`
+	DoneReason string `json:"done_reason,omitempty"`
+	Context    []int  `json:"context,omitempty"`

 	Metrics
 }
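[Note, not part of the diff] Both response types serialize the new field with the `done_reason,omitempty` tag, so intermediate stream chunks, where it is empty, omit it, and only the final chunk carries it next to `"done": true`. A minimal, self-contained sketch of reading it on the client side; `doneChunk` below is a throwaway type that merely mirrors the two fields added above, not the real api struct:

package main

import (
	"encoding/json"
	"fmt"
)

// doneChunk mirrors only the fields this commit touches; illustrative, not the api package type.
type doneChunk struct {
	Done       bool   `json:"done"`
	DoneReason string `json:"done_reason,omitempty"`
}

func main() {
	var c doneChunk
	if err := json.Unmarshal([]byte(`{"done":true,"done_reason":"stop"}`), &c); err != nil {
		panic(err)
	}
	fmt.Println(c.Done, c.DoneReason) // prints: true stop
}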
@@ -509,10 +509,13 @@ type ImageData struct {
 }

 type completion struct {
-	Content string `json:"content"`
-	Model   string `json:"model"`
-	Prompt  string `json:"prompt"`
-	Stop    bool   `json:"stop"`
+	Content      string `json:"content"`
+	Model        string `json:"model"`
+	Prompt       string `json:"prompt"`
+	Stop         bool   `json:"stop"`
+	StoppedEos   bool   `json:"stopped_eos"`
+	StoppedWord  bool   `json:"stopped_word"`
+	StoppedLimit bool   `json:"stopped_limit"`

 	Timings struct {
 		PredictedN int `json:"predicted_n"`
@@ -532,6 +535,7 @@ type CompletionRequest struct {
 type CompletionResponse struct {
 	Content            string
 	Done               bool
+	DoneReason         string
 	PromptEvalCount    int
 	PromptEvalDuration time.Duration
 	EvalCount          int
@@ -648,6 +652,8 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
			return fmt.Errorf("error parsing llm response stream: %s", line)
		}

+		fmt.Println("c", string(evt))
+
		var c completion
		if err := json.Unmarshal(evt, &c); err != nil {
			return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
@@ -674,8 +680,18 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
		}

		if c.Stop {
+			var doneReason string
+			switch {
+			case c.StoppedEos:
+				doneReason = "stop"
+			case c.StoppedWord:
+				doneReason = "stop"
+			case c.StoppedLimit:
+				doneReason = "limit"
+			}
			fn(CompletionResponse{
				Done:               true,
+				DoneReason:         doneReason,
				PromptEvalCount:    c.Timings.PromptN,
				PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
				EvalCount:          c.Timings.PredictedN,
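[Note, not part of the diff] The switch above collapses the completion response's stop flags into two reasons: both stopped_eos (the model emitted an end-of-sequence token) and stopped_word (a stop sequence was hit) map to "stop", while stopped_limit (the prediction limit was reached) maps to "limit".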
@@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
 }

 // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, int, error) {
 	type prompt struct {
 		System string
 		Prompt string
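[Note, not part of the diff] Besides the rendered prompt, ChatPrompt now also returns the token count it settled on after truncation; the remaining hunks in this function thread the extra return value through, and the chat handler further down uses it to cap NumPredict.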
@@ -138,7 +138,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str

			p.Response = msg.Content
		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return "", 0, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
		}
	}

@@ -151,7 +151,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
	for i, p := range prompts {
		tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
		if err != nil {
-			return "", err
+			return "", 0, err
		}

		prompts[i].tokens = tokens + len(prompts[i].images)*768
@@ -160,15 +160,17 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
	// truncate images and prompts starting from the beginning of the list
	// until either one prompt remains or the total tokens fits the context window
	// TODO (jmorganca): this doesn't account for the context window room required for the response
+	var required int
	for {
-		var required int
+		required = 0
		for _, p := range prompts {
			required += p.tokens
		}

		required += 1 // for bos token

-		if required <= window {
+		// leave ~1024 tokens for generation
+		if required <= max(1024, window/2) {
			slog.Debug("prompt now fits in context window", "required", required, "window", window)
			break
		}
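[Note, not part of the diff] The loop previously truncated until the whole prompt fit the window; it now truncates until the prompt fits within max(1024, window/2), reserving the remainder for generation. A small standalone sketch of that budget (max is the Go 1.21 builtin the change relies on):

package main

import "fmt"

func main() {
	// Prompt budget after this change, for a few context-window sizes.
	for _, window := range []int{2048, 4096, 8192} {
		fmt.Printf("window=%d -> prompt budget=%d tokens\n", window, max(1024, window/2))
	}
}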
@@ -194,7 +196,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str

		tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
		if err != nil {
-			return "", err
+			return "", 0, err
		}

		prompts[0].tokens = tokens + len(prompts[0].images)*768
@@ -212,10 +214,10 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
		// last prompt should leave the response unrendered (for completion)
		rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
		if err != nil {
-			return "", err
+			return "", 0, err
		}
		sb.WriteString(rendered)
	}

-	return sb.String(), nil
+	return sb.String(), required, nil
 }
@@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
-			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
+			got, _, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
			if err != nil {
				t.Errorf("error = %v", err)
			}
@@ -234,9 +234,10 @@ func GenerateHandler(c *gin.Context) {
	// of `raw` mode so we need to check for it too
	if req.Prompt == "" && req.Template == "" && req.System == "" {
		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
		})
		return
	}
@@ -289,6 +290,14 @@ func GenerateHandler(c *gin.Context) {
		prompt = sb.String()
	}

+	tokens, err := loaded.llama.Tokenize(c.Request.Context(), prompt)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts.NumPredict = max(opts.NumCtx-len(tokens), 0)
+
	slog.Debug("generate handler", "prompt", prompt)

	ch := make(chan any)
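[Note, not part of the diff] The generate path now tokenizes the rendered prompt and clamps NumPredict to whatever context remains, never below zero; the chat path later in this commit does the same with the token count returned by chatPrompt.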
@@ -307,10 +316,11 @@ func GenerateHandler(c *gin.Context) {
		}

		resp := api.GenerateResponse{
-			Model:     req.Model,
-			CreatedAt: time.Now().UTC(),
-			Done:      r.Done,
-			Response:  r.Content,
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Done:       r.Done,
+			DoneReason: r.DoneReason,
+			Response:   r.Content,
			Metrics: api.Metrics{
				PromptEvalCount:    r.PromptEvalCount,
				PromptEvalDuration: r.PromptEvalDuration,
@@ -1219,17 +1229,17 @@ func streamResponse(c *gin.Context, ch chan any) {
 }

 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, int, error) {
	encode := func(s string) ([]int, error) {
		return loaded.llama.Tokenize(ctx, s)
	}

-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
+	prompt, tokens, err := ChatPrompt(template, messages, numCtx, encode)
	if err != nil {
-		return "", err
+		return "", 0, err
	}

-	return prompt, nil
+	return prompt, tokens, nil
 }

 func ChatHandler(c *gin.Context) {
@@ -1309,19 +1319,22 @@ func ChatHandler(c *gin.Context) {
		}, req.Messages...)
	}

-	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
+	prompt, tokens, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

+	opts.NumPredict = max(opts.NumCtx-tokens, 0)
+
	// an empty request loads the model
	if len(req.Messages) == 0 || prompt == "" {
		resp := api.ChatResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
-			Message:   api.Message{Role: "assistant"},
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
+			Message:    api.Message{Role: "assistant"},
		}
		c.JSON(http.StatusOK, resp)
		return
@@ -1356,10 +1369,11 @@ func ChatHandler(c *gin.Context) {
		loaded.expireTimer.Reset(sessionDuration)

		resp := api.ChatResponse{
-			Model:     req.Model,
-			CreatedAt: time.Now().UTC(),
-			Message:   api.Message{Role: "assistant", Content: r.Content},
-			Done:      r.Done,
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant", Content: r.Content},
+			Done:       r.Done,
+			DoneReason: r.DoneReason,
			Metrics: api.Metrics{
				PromptEvalCount:    r.PromptEvalCount,
				PromptEvalDuration: r.PromptEvalDuration,