diff --git a/llama/runner/runner.go b/llama/runner/runner.go index e495b98d..56d60ec8 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -35,6 +35,10 @@ type Sequence struct { // tokens left to evaluate tokens []int + // tokens that have been generated but not returned yet (e.g. for stop sequences) + // TODO (jmorganca): simplify this + pendingResponses []string + // channel to send responses over responses chan string @@ -105,16 +109,17 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence } return &Sequence{ - tokens: tokens, - n_prompt_tokens: len(tokens), - numPredict: params.numPredict, - responses: make(chan string, 1), - quit: make(chan bool, 1), - embedding: make(chan []float32, 1), - samplingCtx: sc, - embeddingOnly: params.embedding, - stop: params.stop, - numKeep: params.numKeep, + tokens: tokens, + n_prompt_tokens: len(tokens), + numPredict: params.numPredict, + pendingResponses: make([]string, 0), + responses: make(chan string, 1), + quit: make(chan bool, 1), + embedding: make(chan []float32, 1), + samplingCtx: sc, + embeddingOnly: params.embedding, + stop: params.stop, + numKeep: params.numKeep, } } @@ -201,34 +206,30 @@ func incompleteUnicode(token string) bool { return incomplete } -func (s *Server) removeSequence(seqIndex int, pieces *[][]string, reason string) { +func (s *Server) removeSequence(seqIndex int, reason string) { seq := s.seqs[seqIndex] seq.doneReason = reason close(seq.responses) close(seq.embedding) - (*pieces)[seqIndex] = []string{} + seq.pendingResponses = []string{} seq.samplingCtx.Free() s.lc.KvCacheSeqRm(seqIndex, 0, -1) s.seqs[seqIndex] = nil } func (s *Server) run(ctx context.Context) { - // build up stop sequences as we recognize them - // TODO (jmorganca): simplify this - pieces := make([][]string, s.parallel) - for { select { case <-ctx.Done(): return default: - pieces = s.processBatch(pieces) + s.processBatch() } } } -func (s *Server) processBatch(pieces [][]string) [][]string { +func (s *Server) processBatch() { batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs)) defer batch.Free() @@ -247,7 +248,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string { // if past the num predict limit if seq.numPredict > 0 && seq.numPredicted > seq.numPredict { - s.removeSequence(i, &pieces, "limit") + s.removeSequence(i, "limit") continue } @@ -274,7 +275,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string { } if batch.NumTokens() == 0 { - return pieces + return } err := s.lc.Decode(batch) @@ -301,7 +302,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string { } seq.embedding <- embd - s.removeSequence(i, &pieces, "") + s.removeSequence(i, "") continue } @@ -329,14 +330,14 @@ func (s *Server) processBatch(pieces [][]string) [][]string { // seq.responses <- piece // TODO: end the sequence instead of quitting the pool - s.removeSequence(i, &pieces, "stop") + s.removeSequence(i, "stop") continue } seq.tokens = []int{token} - pieces[i] = append(pieces[i], piece) - sequence := strings.Join(pieces[i], "") + seq.pendingResponses = append(seq.pendingResponses, piece) + sequence := strings.Join(seq.pendingResponses, "") if incompleteUnicode(sequence) { continue @@ -345,7 +346,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string { if ok, stop := findStop(sequence, seq.stop); ok { slog.Info("hit stop token", "stop", seq.stop) - truncated := truncateStop(pieces[i], stop) + truncated := truncateStop(seq.pendingResponses, stop) for _, p := range truncated { select { @@ -355,7 +356,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string { } } - s.removeSequence(i, &pieces, "stop") + s.removeSequence(i, "stop") continue } @@ -363,19 +364,17 @@ func (s *Server) processBatch(pieces [][]string) [][]string { continue } - for _, p := range pieces[i] { + for _, p := range seq.pendingResponses { select { case seq.responses <- p: case <-seq.quit: - s.removeSequence(i, &pieces, "connection") + s.removeSequence(i, "connection") break } } - pieces[i] = []string{} + seq.pendingResponses = []string{} } - - return pieces } type Options struct {