Compare commits

2 commits: main...brucemacd/

Commits:
- 4ad24c6ca6
- de0f833ce3
@@ -164,6 +164,8 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
 	return llm, nil
 }
 
+var ErrPredictTimeout = fmt.Errorf("timed out waiting for next token")
+
 func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
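
The added ErrPredictTimeout is a package-level sentinel error: fmt.Errorf without a %w verb yields one fixed value, which callers can match with errors.Is (as ChatHandler does in the last hunk below). A minimal standalone sketch of the pattern; predict here is a hypothetical stand-in, not code from the diff:

package main

import (
	"errors"
	"fmt"
)

// Sentinel error: one fixed, package-level value callers compare against.
var ErrPredictTimeout = fmt.Errorf("timed out waiting for next token")

// predict is a hypothetical stand-in that always fails with the sentinel.
func predict() error {
	return ErrPredictTimeout
}

func main() {
	if err := predict(); errors.Is(err, ErrPredictTimeout) {
		fmt.Println("matched sentinel:", err)
	}
}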
@@ -237,56 +239,73 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 		case <-ctx.Done():
 			return cancelCompletion(llm, resp)
 		default:
-			var result C.ext_server_task_result_t
-			C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
-			json_resp := C.GoString(result.json_resp)
-			C.dyn_llama_server_release_task_result(llm.s, &result)
-
-			var p prediction
-			if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
-				C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
-				if resp.id < 0 {
-					return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
-				} else {
-					return fmt.Errorf("error unmarshaling llm prediction response: %w", err)
-				}
-			}
-
-			if bool(result.error) && strings.Contains(json_resp, "slot unavailable") {
-				retryNeeded = true
-				// task will already be canceled
-				break out
-			}
-
-			switch {
-			case strings.TrimSpace(p.Content) == lastToken:
-				tokenRepeat++
-			default:
-				lastToken = strings.TrimSpace(p.Content)
-				tokenRepeat = 0
-			}
-
-			// 30 picked as an arbitrary max token repeat limit, modify as needed
-			if tokenRepeat > 30 {
-				slog.Debug("prediction aborted, token repeat limit reached")
-				return cancelCompletion(llm, resp)
-			}
-
-			if p.Content != "" {
-				fn(PredictResult{
-					Content: p.Content,
-				})
-			}
-
-			if p.Stop || bool(result.stop) {
-				fn(PredictResult{
-					Done:               true,
-					PromptEvalCount:    p.Timings.PromptN,
-					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
-					EvalCount:          p.Timings.PredictedN,
-					EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
-				})
-				return nil
+			// this channel is used to communicate the result of each call, while allowing for a timeout
+			resultChan := make(chan C.ext_server_task_result_t)
+			// timeout waiting for a token from this specific call
+			timeout := time.After(30 * time.Second)
+
+			go func() {
+				var result C.ext_server_task_result_t
+				C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
+				resultChan <- result
+			}()
+
+			select {
+			case result := <-resultChan:
+				json_resp := C.GoString(result.json_resp)
+				C.dyn_llama_server_release_task_result(llm.s, &result)
+
+				var p prediction
+				if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
+					C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
+					if resp.id < 0 {
+						return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
+					} else {
+						return fmt.Errorf("error unmarshaling llm prediction response: %w", err)
+					}
+				}
+
+				if bool(result.error) && strings.Contains(json_resp, "slot unavailable") {
+					retryNeeded = true
+					// task will already be canceled
+					break out
+				}
+
+				switch {
+				case strings.TrimSpace(p.Content) == lastToken:
+					tokenRepeat++
+				default:
+					lastToken = strings.TrimSpace(p.Content)
+					tokenRepeat = 0
+				}
+
+				// 30 picked as an arbitrary max token repeat limit, modify as needed
+				if tokenRepeat > 30 {
+					slog.Debug("prediction aborted, token repeat limit reached")
+					return cancelCompletion(llm, resp)
+				}
+
+				if p.Content != "" {
+					fn(PredictResult{
+						Content: p.Content,
+					})
+				}
+
+				if p.Stop || bool(result.stop) {
+					fn(PredictResult{
+						Done:               true,
+						PromptEvalCount:    p.Timings.PromptN,
+						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+						EvalCount:          p.Timings.PredictedN,
+						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
+					})
+					return nil
+				}
+			case <-timeout:
+				if err := cancelCompletion(llm, resp); err != nil {
+					slog.Error("failed to cancel completion on predict timeout: ", err)
+				}
+				return ErrPredictTimeout
 			}
 		}
 	}
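
The core of this hunk is the standard Go recipe for bounding a blocking call that offers no cancellation hook of its own: run it in a goroutine, deliver its result over a channel, and select between that channel and time.After. A standalone sketch under assumed names (slowCall is hypothetical; in the diff the blocking call is C.dyn_llama_server_completion_next_result):

package main

import (
	"errors"
	"fmt"
	"time"
)

var ErrTimeout = errors.New("timed out waiting for result")

// slowCall is a hypothetical stand-in for a blocking call (here, the cgo
// next-result call) that cannot be interrupted directly.
func slowCall() string {
	time.Sleep(2 * time.Second)
	return "token"
}

// callWithTimeout races slowCall against a deadline. The buffer of 1 lets
// the producer goroutine complete its send and exit even when the timeout
// wins; with an unbuffered channel (as in the diff) the producer would stay
// blocked on its send after a timeout.
func callWithTimeout(d time.Duration) (string, error) {
	resultChan := make(chan string, 1)
	go func() {
		resultChan <- slowCall()
	}()

	select {
	case r := <-resultChan:
		return r, nil
	case <-time.After(d):
		return "", ErrTimeout
	}
}

func main() {
	if _, err := callWithTimeout(time.Second); errors.Is(err, ErrTimeout) {
		fmt.Println("timed out as expected")
	}
	if r, err := callWithTimeout(3 * time.Second); err == nil {
		fmt.Println("got:", r)
	}
}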
@@ -1362,6 +1362,15 @@ func ChatHandler(c *gin.Context) {
 			Options:  opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
+			if errors.Is(err, llm.ErrPredictTimeout) {
+				// the loaded runner may be unresponsive, stop it now
+				if loaded.runner != nil {
+					loaded.runner.Close()
+				}
+				loaded.runner = nil
+				loaded.Model = nil
+				loaded.Options = nil
+			}
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
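
On the caller side, the handler maps the sentinel to a teardown: a possibly wedged runner is closed and its references dropped so a later request loads a fresh one. A compressed sketch of that reset-on-sentinel shape (runner and loaded are hypothetical stand-ins for the server's state, not the diff's types):

package main

import (
	"errors"
	"fmt"
)

var ErrPredictTimeout = errors.New("timed out waiting for next token")

// runner is a hypothetical stand-in for the loaded model runner.
type runner struct{}

func (r *runner) Close()         { fmt.Println("runner closed") }
func (r *runner) Predict() error { return ErrPredictTimeout }

func main() {
	loaded := struct{ runner *runner }{runner: &runner{}}

	if err := loaded.runner.Predict(); err != nil {
		if errors.Is(err, ErrPredictTimeout) {
			// the runner may be unresponsive: stop it and drop the
			// reference so a later request triggers a fresh load
			if loaded.runner != nil {
				loaded.runner.Close()
			}
			loaded.runner = nil
		}
		fmt.Println("request error:", err)
	}
}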