diff --git a/llm/dyn_ext_server.go b/llm/dyn_ext_server.go
index 832d3c47..cc7f06b0 100644
--- a/llm/dyn_ext_server.go
+++ b/llm/dyn_ext_server.go
@@ -35,6 +35,7 @@ import (
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/gpu"
+	"golang.org/x/exp/slog"
 )
 
 type dynExtServer struct {
@@ -164,6 +165,9 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
 	return llm, nil
 }
 
+// ErrPredictTimeout is returned by Predict when no token arrives within the per-token timeout.
+var ErrPredictTimeout = fmt.Errorf("timed out waiting for next token")
+
 func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
@@ -237,56 +241,73 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 			case <-ctx.Done():
 				return cancelCompletion(llm, resp)
 			default:
-				var result C.ext_server_task_result_t
-				C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
-				json_resp := C.GoString(result.json_resp)
-				C.dyn_llama_server_release_task_result(llm.s, &result)
+				// buffered (cap 1) so the goroutine below can always deliver its result and exit, even after a timeout
+				resultChan := make(chan C.ext_server_task_result_t, 1)
+				// timeout waiting for a token from this specific call
+				timeout := time.After(30 * time.Second)
 
-				var p prediction
-				if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
-					C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
-					if resp.id < 0 {
-						return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
-					} else {
-						return fmt.Errorf("error unmarshaling llm prediction response: %w", err)
+				go func() {
+					var result C.ext_server_task_result_t
+					C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
+					resultChan <- result
+				}()
+
+				select {
+				case result := <-resultChan:
+					json_resp := C.GoString(result.json_resp)
+					C.dyn_llama_server_release_task_result(llm.s, &result)
+
+					var p prediction
+					if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
+						C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
+						if resp.id < 0 {
+							return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
+						} else {
+							return fmt.Errorf("error unmarshaling llm prediction response: %w", err)
+						}
 					}
-				}
 
-				if bool(result.error) && strings.Contains(json_resp, "slot unavailable") {
-					retryNeeded = true
-					// task will already be canceled
-					break out
-				}
+					if bool(result.error) && strings.Contains(json_resp, "slot unavailable") {
+						retryNeeded = true
+						// task will already be canceled
+						break out
+					}
 
-				switch {
-				case strings.TrimSpace(p.Content) == lastToken:
-					tokenRepeat++
-				default:
-					lastToken = strings.TrimSpace(p.Content)
-					tokenRepeat = 0
-				}
+					switch {
+					case strings.TrimSpace(p.Content) == lastToken:
+						tokenRepeat++
+					default:
+						lastToken = strings.TrimSpace(p.Content)
+						tokenRepeat = 0
+					}
 
-				// 30 picked as an arbitrary max token repeat limit, modify as needed
-				if tokenRepeat > 30 {
-					slog.Debug("prediction aborted, token repeat limit reached")
-					return cancelCompletion(llm, resp)
-				}
+					// 30 picked as an arbitrary max token repeat limit, modify as needed
+					if tokenRepeat > 30 {
+						slog.Debug("prediction aborted, token repeat limit reached")
+						return cancelCompletion(llm, resp)
+					}
 
-				if p.Content != "" {
-					fn(PredictResult{
-						Content: p.Content,
-					})
-				}
+					if p.Content != "" {
+						fn(PredictResult{
+							Content: p.Content,
+						})
+					}
 
-				if p.Stop || bool(result.stop) {
-					fn(PredictResult{
-						Done:               true,
-						PromptEvalCount:    p.Timings.PromptN,
-						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
-						EvalCount:          p.Timings.PredictedN,
-						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
-					})
-					return nil
+					if p.Stop || bool(result.stop) {
+						fn(PredictResult{
+							Done:               true,
+							PromptEvalCount:    p.Timings.PromptN,
+							PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+							EvalCount:          p.Timings.PredictedN,
+							EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
+						})
+						return nil
+					}
+				case <-timeout:
+					if err := cancelCompletion(llm, resp); err != nil {
+						slog.Error("failed to cancel completion on predict timeout", "error", err)
+					}
+					return ErrPredictTimeout
 				}
 			}
 		}
diff --git a/server/routes.go b/server/routes.go
index a03f39e7..35477e8d 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1362,6 +1362,15 @@ func ChatHandler(c *gin.Context) {
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
+			if errors.Is(err, llm.ErrPredictTimeout) {
+				// the loaded runner may be unresponsive, stop it now
+				if loaded.runner != nil {
+					loaded.runner.Close()
+				}
+				loaded.runner = nil
+				loaded.Model = nil
+				loaded.Options = nil
+			}
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()