From 6ccd0644e1e09624a6de504cc4d13bd373e90b72 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 26 Aug 2024 14:20:50 -0700 Subject: [PATCH] runner.go: Fix deadlock if a connection is closed during decoding If a connection is closed while a sequence is being decoded, tokens will continue to be added to the channel without anyone to read them. This will result in the sender blocking, which will in turn block all other decoding and sending for other sequences. This is not limited to just the connection between Ollama and the runner process. If the connection to the Ollama API is closed by the user then Ollama will close the connection to the runner, triggering this issue. --- llama/runner/runner.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 549f29af..e495b98d 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -38,6 +38,9 @@ type Sequence struct { // channel to send responses over responses chan string + // channel to stop decoding (such as if the remote connection is closed) + quit chan bool + // number of tokens to predict numPredict int @@ -106,6 +109,7 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence n_prompt_tokens: len(tokens), numPredict: params.numPredict, responses: make(chan string, 1), + quit: make(chan bool, 1), embedding: make(chan []float32, 1), samplingCtx: sc, embeddingOnly: params.embedding, @@ -344,7 +348,11 @@ func (s *Server) processBatch(pieces [][]string) [][]string { truncated := truncateStop(pieces[i], stop) for _, p := range truncated { - seq.responses <- p + select { + case seq.responses <- p: + case <-seq.quit: + break + } } s.removeSequence(i, &pieces, "stop") @@ -356,7 +364,12 @@ func (s *Server) processBatch(pieces [][]string) [][]string { } for _, p := range pieces[i] { - seq.responses <- p + select { + case seq.responses <- p: + case <-seq.quit: + s.removeSequence(i, &pieces, "connection") + break + } } pieces[i] = []string{} @@ -475,12 +488,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { Content: content, }); err != nil { log.Println("Failed to encode result:", err) + close(seq.quit) return } flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "Streaming not supported", http.StatusInternalServerError) + close(seq.quit) return }