From 53b600921e9c5db4d2915a46fe4cdd92b2fedcc5 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Fri, 23 Aug 2024 16:28:38 -0700
Subject: [PATCH] runner.go: Hold mutex for entire time when processing batch

It is not safe to hold a mutex only while we are waiting for the
condition variable to signal that a new sequence has been added. It's
possible that a sequence could be added in the middle of batch
processing. For example, if a new sequence is added while Decode() is
running, it will get picked up for sampling, despite not having been
added to the original batch.

This change holds a mutex for the majority of the time when active
processing is happening, releasing it only for a brief period each
time around the loop. Depending on the workload and the scheduler,
this may result in unfairness between different requests. However,
this was not actually observed in testing.

This addresses the correctness issue - better performance and fairness
can be achieved with additional improvements in the future.
---
 llama/runner/runner.go | 311 +++++++++++++++++++++--------------------
 1 file changed, 158 insertions(+), 153 deletions(-)

diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 29d59432..52087276 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -198,9 +198,6 @@ func incompleteUnicode(token string) bool {
 }
 
 func (s *Server) run(ctx context.Context) {
-    batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
-    defer batch.Free()
-
     // build up stop sequences as we recognize them
     // TODO (jmorganca): simplify this
     pieces := make([][]string, s.parallel)
@@ -210,160 +207,168 @@ func (s *Server) run(ctx context.Context) {
         case <-ctx.Done():
             return
         default:
-            slog.Debug("Processing batch", "seqs", len(s.seqs))
-            s.mu.Lock()
-            for s.allNil() {
-                s.cond.Wait() // Wait until an item is added
-            }
-            s.mu.Unlock()
-
-            for i, seq := range s.seqs {
-                if seq == nil {
-                    continue
-                }
-
-                // if past the num predict limit
-                if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
-                    seq.doneReason = "limit"
-                    close(seq.responses)
-                    s.lc.KvCacheSeqRm(i, 0, -1)
-                    s.seqs[i] = nil
-                    continue
-                }
-
-                if seq.nPast+len(seq.tokens) > s.numCtx {
-                    s.shiftContext(i)
-                }
-
-                if seq.t_start_process_prompt.IsZero() {
-                    seq.t_start_process_prompt = time.Now()
-                }
-
-                var numTokensProcessed int
-                for j, t := range seq.tokens {
-                    // todo: make this n_batch
-                    if j >= s.batchSize {
-                        break
-                    }
-                    batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
-                    seq.nPast++
-                    numTokensProcessed++
-                }
-                seq.tokens = seq.tokens[numTokensProcessed:]
-                seq.iBatch = batch.NumTokens() - 1
-            }
-
-            if batch.NumTokens() == 0 {
-                continue
-            }
-
-            err := s.lc.Decode(batch)
-            if err != nil {
-                slog.Error("failed to decode batch", "error", err)
-                panic("Failed to decode")
-            }
-
-            for i, seq := range s.seqs {
-                if seq == nil {
-                    continue
-                }
-
-                // don't sample prompt processing
-                if len(seq.tokens) != 0 {
-                    continue
-                }
-
-                // if done processing the prompt, generating an embedding and return
-                if seq.embeddingOnly {
-                    embd := s.lc.GetEmbeddingsSeq(i)
-                    if embd == nil {
-                        embd = s.lc.GetEmbeddingsIth(seq.iBatch)
-                    }
-
-                    seq.embedding <- embd
-                    close(seq.embedding)
-                    s.lc.KvCacheSeqRm(i, 0, -1)
-                    s.seqs[i] = nil
-                    continue
-                }
-
-                // sample a token
-                // logits := s.lc.GetLogitsIth(ibatch[i])
-                // token := s.lc.SampleTokenGreedy(logits)
-                token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
-
-                seq.samplingCtx.Accept(s.lc, token, true)
-                seq.n_decoded += 1
-                if seq.n_decoded == 1 {
-                    seq.t_start_genereration = time.Now()
-                }
-                piece := s.model.TokenToPiece(token)
-
-                seq.numPredicted++
-
-                slog.Debug("sampled", "piece", piece)
-
-                // if it's an end of sequence token, break
-                // TODO: just end this sequence
-                if s.model.TokenIsEog(token) {
-                    // TODO: end the sequence instead of quitting the pool
-                    s.lc.KvCacheSeqRm(i, 0, -1)
-
-                    // TODO (jmorganca): we should send this back
-                    // as it's important for the /api/generate context
-                    // seq.responses <- piece
-
-                    seq.doneReason = "stop"
-                    close(seq.responses)
-                    seq.samplingCtx.Free()
-                    pieces[i] = []string{}
-                    s.seqs[i] = nil
-                    continue
-                }
-
-                seq.tokens = []int{token}
-
-                pieces[i] = append(pieces[i], piece)
-                sequence := strings.Join(pieces[i], "")
-
-                if incompleteUnicode(sequence) {
-                    continue
-                }
-
-                if ok, stop := findStop(sequence, seq.stop); ok {
-                    slog.Info("hit stop token", "stop", seq.stop)
-
-                    truncated := truncateStop(pieces[i], stop)
-
-                    for _, p := range truncated {
-                        seq.responses <- p
-                    }
-
-                    s.lc.KvCacheSeqRm(i, 0, -1)
-                    seq.doneReason = "stop"
-                    close(seq.responses)
-                    seq.samplingCtx.Free()
-                    pieces[i] = []string{}
-                    s.seqs[i] = nil
-                    continue
-                }
-
-                if containsStopSuffix(sequence, seq.stop) {
-                    continue
-                }
-
-                for _, p := range pieces[i] {
-                    seq.responses <- p
-                }
-
-                pieces[i] = []string{}
-            }
-
-            batch.Clear()
+            pieces = s.processBatch(pieces)
         }
     }
 }
 
+func (s *Server) processBatch(pieces [][]string) [][]string {
+    batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
+    defer batch.Free()
+
+    s.mu.Lock()
+    for s.allNil() {
+        s.cond.Wait() // Wait until an item is added
+    }
+    defer s.mu.Unlock()
+
+    slog.Debug("Processing batch", "seqs", len(s.seqs))
+
+    for i, seq := range s.seqs {
+        if seq == nil {
+            continue
+        }
+
+        // if past the num predict limit
+        if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
+            seq.doneReason = "limit"
+            close(seq.responses)
+            s.lc.KvCacheSeqRm(i, 0, -1)
+            s.seqs[i] = nil
+            continue
+        }
+
+        if seq.nPast+len(seq.tokens) > s.numCtx {
+            s.shiftContext(i)
+        }
+
+        if seq.t_start_process_prompt.IsZero() {
+            seq.t_start_process_prompt = time.Now()
+        }
+
+        var numTokensProcessed int
+        for j, t := range seq.tokens {
+            // todo: make this n_batch
+            if j >= s.batchSize {
+                break
+            }
+            batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
+            seq.nPast++
+            numTokensProcessed++
+        }
+        seq.tokens = seq.tokens[numTokensProcessed:]
+        seq.iBatch = batch.NumTokens() - 1
+    }
+
+    if batch.NumTokens() == 0 {
+        return pieces
+    }
+
+    err := s.lc.Decode(batch)
+    if err != nil {
+        slog.Error("failed to decode batch", "error", err)
+        panic("Failed to decode")
+    }
+
+    for i, seq := range s.seqs {
+        if seq == nil {
+            continue
+        }
+
+        // don't sample prompt processing
+        if len(seq.tokens) != 0 {
+            continue
+        }
+
+        // if done processing the prompt, generating an embedding and return
+        if seq.embeddingOnly {
+            embd := s.lc.GetEmbeddingsSeq(i)
+            if embd == nil {
+                embd = s.lc.GetEmbeddingsIth(seq.iBatch)
+            }
+
+            seq.embedding <- embd
+            close(seq.embedding)
+            s.lc.KvCacheSeqRm(i, 0, -1)
+            s.seqs[i] = nil
+            continue
+        }
+
+        // sample a token
+        // logits := s.lc.GetLogitsIth(ibatch[i])
+        // token := s.lc.SampleTokenGreedy(logits)
+        token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
+
+        seq.samplingCtx.Accept(s.lc, token, true)
+        seq.n_decoded += 1
+        if seq.n_decoded == 1 {
+            seq.t_start_genereration = time.Now()
+        }
+        piece := s.model.TokenToPiece(token)
+
+        seq.numPredicted++
+
+        slog.Debug("sampled", "piece", piece)
+
+        // if it's an end of sequence token, break
+        // TODO: just end this sequence
+        if s.model.TokenIsEog(token) {
+            // TODO: end the sequence instead of quitting the pool
+            s.lc.KvCacheSeqRm(i, 0, -1)
+
+            // TODO (jmorganca): we should send this back
+            // as it's important for the /api/generate context
+            // seq.responses <- piece
+
+            seq.doneReason = "stop"
+            close(seq.responses)
+            seq.samplingCtx.Free()
+            pieces[i] = []string{}
+            s.seqs[i] = nil
+            continue
+        }
+
+        seq.tokens = []int{token}
+
+        pieces[i] = append(pieces[i], piece)
+        sequence := strings.Join(pieces[i], "")
+
+        if incompleteUnicode(sequence) {
+            continue
+        }
+
+        if ok, stop := findStop(sequence, seq.stop); ok {
+            slog.Info("hit stop token", "stop", seq.stop)
+
+            truncated := truncateStop(pieces[i], stop)
+
+            for _, p := range truncated {
+                seq.responses <- p
+            }
+
+            s.lc.KvCacheSeqRm(i, 0, -1)
+            seq.doneReason = "stop"
+            close(seq.responses)
+            seq.samplingCtx.Free()
+            pieces[i] = []string{}
+            s.seqs[i] = nil
+            continue
+        }
+
+        if containsStopSuffix(sequence, seq.stop) {
+            continue
+        }
+
+        for _, p := range pieces[i] {
+            seq.responses <- p
+        }
+
+        pieces[i] = []string{}
+    }
+
+    return pieces
+}
+
 type Options struct {
     api.Runner
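
The locking pattern the commit message describes can be illustrated outside the runner. Below is a minimal, self-contained Go sketch, assuming simplified stand-in types and a hypothetical processOnce helper rather than the runner's real Server, sequence, and batch APIs: the mutex is held for the entire pass (waiting on the condition variable, doing the work, clearing finished sequences) and released only between passes, so a sequence added concurrently can never be picked up mid-pass.

package main

import (
    "fmt"
    "sync"
    "time"
)

// seq is a toy stand-in for a decoding sequence; the name is illustrative only.
type seq struct{ id int }

type server struct {
    mu   sync.Mutex
    cond *sync.Cond
    seqs []*seq
}

// allNil reports whether there is currently no work queued.
func (s *server) allNil() bool {
    for _, q := range s.seqs {
        if q != nil {
            return false
        }
    }
    return true
}

// processOnce is one trip around the loop: the mutex is held for the whole
// pass, so new sequences can only be observed at a pass boundary.
func (s *server) processOnce() {
    s.mu.Lock()
    defer s.mu.Unlock()

    for s.allNil() {
        s.cond.Wait() // releases the lock while blocked, reacquires on wake
    }

    // stand-in for: build the batch, Decode(), sample and emit tokens
    for i, q := range s.seqs {
        if q == nil {
            continue
        }
        fmt.Println("processed seq", q.id)
        s.seqs[i] = nil
    }
}

// add queues a sequence under the same lock and wakes the processing loop.
func (s *server) add(q *seq) {
    s.mu.Lock()
    defer s.mu.Unlock()
    for i := range s.seqs {
        if s.seqs[i] == nil {
            s.seqs[i] = q
            break
        }
    }
    s.cond.Signal()
}

func main() {
    s := &server{seqs: make([]*seq, 2)}
    s.cond = sync.NewCond(&s.mu)

    go func() {
        time.Sleep(10 * time.Millisecond)
        s.add(&seq{id: 1})
    }()

    s.processOnce() // blocks in Wait until add() signals, then drains the queue
}

sync.Cond.Wait releases the mutex while blocked and reacquires it before returning, which is why add() can still make progress even though the processing pass otherwise holds the lock from start to finish.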