Compare commits

...

2 Commits

Author SHA1 Message Date
Bruce MacDonald
4ad24c6ca6 Update dyn_ext_server.go 2024-03-15 16:31:49 +00:00
Bruce MacDonald
de0f833ce3 feat: timeout between token generation
- if 30 seconds pass since the last token was generated, abort the request
- stop the llama thread to flush any accumulated context
2024-03-15 16:19:54 +00:00
2 changed files with 71 additions and 43 deletions

View File

@@ -164,6 +164,8 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
return llm, nil return llm, nil
} }
// ErrPredictTimeout is returned by Predict when no new token arrives from the
// llama server within the per-token timeout window (30s per the select below),
// so callers (e.g. ChatHandler) can detect an unresponsive runner and tear it down.
// NOTE(review): prefer errors.New over fmt.Errorf for a constant message
// (staticcheck S1039) — requires importing "errors"; import block not visible here.
var ErrPredictTimeout = fmt.Errorf("timed out waiting for next token")
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error { func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
resp := newExtServerResp(128) resp := newExtServerResp(128)
defer freeExtServerResp(resp) defer freeExtServerResp(resp)
@@ -237,56 +239,73 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
case <-ctx.Done(): case <-ctx.Done():
return cancelCompletion(llm, resp) return cancelCompletion(llm, resp)
default: default:
var result C.ext_server_task_result_t // this channel is used to communicate the result of each call, while allowing for a timeout
C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result) resultChan := make(chan C.ext_server_task_result_t)
json_resp := C.GoString(result.json_resp) // timeout waiting for a token from this specific call
C.dyn_llama_server_release_task_result(llm.s, &result) timeout := time.After(30 * time.Second)
var p prediction go func() {
if err := json.Unmarshal([]byte(json_resp), &p); err != nil { var result C.ext_server_task_result_t
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp) C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
if resp.id < 0 { resultChan <- result
return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg)) }()
} else {
return fmt.Errorf("error unmarshaling llm prediction response: %w", err) select {
case result := <-resultChan:
json_resp := C.GoString(result.json_resp)
C.dyn_llama_server_release_task_result(llm.s, &result)
var p prediction
if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
if resp.id < 0 {
return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
} else {
return fmt.Errorf("error unmarshaling llm prediction response: %w", err)
}
} }
}
if bool(result.error) && strings.Contains(json_resp, "slot unavailable") { if bool(result.error) && strings.Contains(json_resp, "slot unavailable") {
retryNeeded = true retryNeeded = true
// task will already be canceled // task will already be canceled
break out break out
} }
switch { switch {
case strings.TrimSpace(p.Content) == lastToken: case strings.TrimSpace(p.Content) == lastToken:
tokenRepeat++ tokenRepeat++
default: default:
lastToken = strings.TrimSpace(p.Content) lastToken = strings.TrimSpace(p.Content)
tokenRepeat = 0 tokenRepeat = 0
} }
// 30 picked as an arbitrary max token repeat limit, modify as needed // 30 picked as an arbitrary max token repeat limit, modify as needed
if tokenRepeat > 30 { if tokenRepeat > 30 {
slog.Debug("prediction aborted, token repeat limit reached") slog.Debug("prediction aborted, token repeat limit reached")
return cancelCompletion(llm, resp) return cancelCompletion(llm, resp)
} }
if p.Content != "" { if p.Content != "" {
fn(PredictResult{ fn(PredictResult{
Content: p.Content, Content: p.Content,
}) })
} }
if p.Stop || bool(result.stop) { if p.Stop || bool(result.stop) {
fn(PredictResult{ fn(PredictResult{
Done: true, Done: true,
PromptEvalCount: p.Timings.PromptN, PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS), PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN, EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS), EvalDuration: parseDurationMs(p.Timings.PredictedMS),
}) })
return nil return nil
}
case <-timeout:
if err := cancelCompletion(llm, resp); err != nil {
slog.Error("failed to cancel completion on predict timeout: ", err)
}
return ErrPredictTimeout
} }
} }
} }

View File

@@ -1362,6 +1362,15 @@ func ChatHandler(c *gin.Context) {
Options: opts, Options: opts,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
if errors.Is(err, llm.ErrPredictTimeout) {
// the loaded runner may be unresponsive, stop it now
if loaded.runner != nil {
loaded.runner.Close()
}
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
}
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()