diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 549f29af..e495b98d 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -38,6 +38,9 @@ type Sequence struct { // channel to send responses over responses chan string + // channel to stop decoding (such as if the remote connection is closed) + quit chan bool + // number of tokens to predict numPredict int @@ -106,6 +109,7 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence n_prompt_tokens: len(tokens), numPredict: params.numPredict, responses: make(chan string, 1), + quit: make(chan bool, 1), embedding: make(chan []float32, 1), samplingCtx: sc, embeddingOnly: params.embedding, @@ -344,7 +348,11 @@ func (s *Server) processBatch(pieces [][]string) [][]string { truncated := truncateStop(pieces[i], stop) for _, p := range truncated { - seq.responses <- p + select { + case seq.responses <- p: + case <-seq.quit: + break + } } s.removeSequence(i, &pieces, "stop") @@ -356,7 +364,12 @@ func (s *Server) processBatch(pieces [][]string) [][]string { } for _, p := range pieces[i] { - seq.responses <- p + select { + case seq.responses <- p: + case <-seq.quit: + s.removeSequence(i, &pieces, "connection") + break + } } pieces[i] = []string{} @@ -475,12 +488,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { Content: content, }); err != nil { log.Println("Failed to encode result:", err) + close(seq.quit) return } flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "Streaming not supported", http.StatusInternalServerError) + close(seq.quit) return }