diff --git a/llama/llama.go b/llama/llama.go index 1315fbe2..92adef86 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -429,7 +429,9 @@ func NewSamplingContext(params SamplingParams) *SamplingContext { } func (s *SamplingContext) Free() { - C.llama_sampling_cfree(s.c) + if s.c != nil { + C.llama_sampling_cfree(s.c) + } } func (s *SamplingContext) Reset() { diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 77d7bdee..549f29af 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -197,6 +197,18 @@ func incompleteUnicode(token string) bool { return incomplete } +func (s *Server) removeSequence(seqIndex int, pieces *[][]string, reason string) { + seq := s.seqs[seqIndex] + + seq.doneReason = reason + close(seq.responses) + close(seq.embedding) + (*pieces)[seqIndex] = []string{} + seq.samplingCtx.Free() + s.lc.KvCacheSeqRm(seqIndex, 0, -1) + s.seqs[seqIndex] = nil +} + func (s *Server) run(ctx context.Context) { // build up stop sequences as we recognize them // TODO (jmorganca): simplify this @@ -231,10 +243,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string { // if past the num predict limit if seq.numPredict > 0 && seq.numPredicted > seq.numPredict { - seq.doneReason = "limit" - close(seq.responses) - s.lc.KvCacheSeqRm(i, 0, -1) - s.seqs[i] = nil + s.removeSequence(i, &pieces, "limit") continue } @@ -288,9 +297,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string { } seq.embedding <- embd - close(seq.embedding) - s.lc.KvCacheSeqRm(i, 0, -1) - s.seqs[i] = nil + s.removeSequence(i, &pieces, "") continue } @@ -313,18 +320,12 @@ func (s *Server) processBatch(pieces [][]string) [][]string { // if it's an end of sequence token, break // TODO: just end this sequence if s.model.TokenIsEog(token) { - // TODO: end the sequence instead of quitting the pool - s.lc.KvCacheSeqRm(i, 0, -1) - // TODO (jmorganca): we should send this back // as it's important for the /api/generate context // seq.responses <- piece - seq.doneReason = "stop" - close(seq.responses) - seq.samplingCtx.Free() - pieces[i] = []string{} - s.seqs[i] = nil + // TODO: end the sequence instead of quitting the pool + s.removeSequence(i, &pieces, "stop") continue } @@ -346,12 +347,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string { seq.responses <- p } - s.lc.KvCacheSeqRm(i, 0, -1) - seq.doneReason = "stop" - close(seq.responses) - seq.samplingCtx.Free() - pieces[i] = []string{} - s.seqs[i] = nil + s.removeSequence(i, &pieces, "stop") continue }