runner.go: Hold mutex for entire time when processing batch

It is not safe to hold a mutex only while we are waiting for the
condition variable to signal that a new sequence has been added. It's
possible that a sequence could be added in the middle of batch
processing. For example, if a new sequence is added while Decode()
is running, it will get picked up for sampling, despite not having
been added to the original batch.
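
As a concrete illustration, here is a small self-contained sketch of the old pattern. The server and seq types below are hypothetical stand-ins, not the runner's real ones; the point is only that the lock guards the wait, after which batching, decoding, and sampling run unprotected, so a sequence added mid-decode is visible to the sampling step even though it was never batched.

// Minimal, hypothetical sketch of the race described above -- NOT the real
// runner code. The lock is held only while waiting for work; everything
// after the Unlock runs without it.
package main

import (
	"fmt"
	"sync"
	"time"
)

type seq struct{ id int }

type server struct {
	mu   sync.Mutex
	cond *sync.Cond
	seqs []*seq
}

func (s *server) processOnce() {
	s.mu.Lock()
	for len(s.seqs) == 0 {
		s.cond.Wait() // wait until a sequence is added
	}
	s.mu.Unlock() // lock released here; the rest runs unprotected

	batched := len(s.seqs)            // "add these sequences to the batch"
	time.Sleep(10 * time.Millisecond) // stand-in for Decode()

	// Sampling step: it re-reads s.seqs, so it can see sequences that were
	// never added to the batch above.
	if len(s.seqs) != batched {
		fmt.Printf("sampling %d sequences, but only %d were in the batch\n",
			len(s.seqs), batched)
	}
}

func main() {
	s := &server{}
	s.cond = sync.NewCond(&s.mu)

	s.mu.Lock()
	s.seqs = append(s.seqs, &seq{id: 0})
	s.mu.Unlock()

	// A request that arrives in the middle of "decoding".
	go func() {
		time.Sleep(5 * time.Millisecond)
		s.mu.Lock()
		s.seqs = append(s.seqs, &seq{id: 1})
		s.mu.Unlock()
		s.cond.Signal()
	}()

	s.processOnce()
}

Running a sketch like this under the Go race detector should flag the unsynchronized reads of the sequence list, which is exactly the window this change closes.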

This change holds the mutex for the majority of the time while active
processing is happening, releasing it only for a brief period each
time around the loop. Depending on the workload and the scheduler,
this may result in unfairness between different requests; however,
that was not observed in testing.

This addresses the correctness issue; better performance and fairness
can be achieved with additional improvements in the future.
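
For reference, here is a rough self-contained sketch of the locking discipline this change adopts, again with hypothetical types rather than the real Server and processBatch shown in the diff below: the mutex is taken for the whole batch (waiting, batching, decoding, sampling) and is free only in the brief gap between calls.

// Hypothetical sketch of the new pattern -- NOT the real runner code.
// The mutex is held across the entire batch and released only between
// iterations of the outer loop, so sequences cannot change mid-batch.
package main

import (
	"fmt"
	"sync"
)

type seq struct{ tokens []int }

type server struct {
	mu   sync.Mutex
	cond *sync.Cond
	seqs []*seq
}

func (s *server) processBatch() {
	s.mu.Lock()
	defer s.mu.Unlock() // held across batching, decode, and sampling

	for len(s.seqs) == 0 {
		s.cond.Wait() // wait until a sequence is added
	}

	// Build the batch, decode, and sample here, all under s.mu, so no
	// sequence can appear or disappear mid-batch. (Stand-in for the real
	// work; this sketch just drains the queue.)
	s.seqs = s.seqs[:0]
}

func main() {
	s := &server{}
	s.cond = sync.NewCond(&s.mu)

	// A producer adds a sequence and wakes the batch loop.
	go func() {
		s.mu.Lock()
		s.seqs = append(s.seqs, &seq{tokens: []int{1, 2, 3}})
		s.mu.Unlock()
		s.cond.Signal()
	}()

	// run() would call this in a loop, checking ctx.Done() in between;
	// the lock is only free during that brief gap between calls.
	s.processBatch()
	fmt.Println("processed one batch while holding the lock")
}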
Jesse Gross 2024-08-23 16:28:38 -07:00 committed by jmorganca
parent 8e1554c91d
commit 53b600921e


@@ -198,9 +198,6 @@ func incompleteUnicode(token string) bool {
}

func (s *Server) run(ctx context.Context) {
	batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
	defer batch.Free()

	// build up stop sequences as we recognize them
	// TODO (jmorganca): simplify this
	pieces := make([][]string, s.parallel)
@@ -210,160 +207,168 @@ func (s *Server) run(ctx context.Context) {
		case <-ctx.Done():
			return
		default:
			slog.Debug("Processing batch", "seqs", len(s.seqs))
			s.mu.Lock()
			for s.allNil() {
				s.cond.Wait() // Wait until an item is added
			}
			s.mu.Unlock()

			for i, seq := range s.seqs {
				if seq == nil {
					continue
				}

				// if past the num predict limit
				if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
					seq.doneReason = "limit"
					close(seq.responses)
					s.lc.KvCacheSeqRm(i, 0, -1)
					s.seqs[i] = nil
					continue
				}

				if seq.nPast+len(seq.tokens) > s.numCtx {
					s.shiftContext(i)
				}

				if seq.t_start_process_prompt.IsZero() {
					seq.t_start_process_prompt = time.Now()
				}

				var numTokensProcessed int
				for j, t := range seq.tokens {
					// todo: make this n_batch
					if j >= s.batchSize {
						break
					}
					batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
					seq.nPast++
					numTokensProcessed++
				}
				seq.tokens = seq.tokens[numTokensProcessed:]
				seq.iBatch = batch.NumTokens() - 1
			}

			if batch.NumTokens() == 0 {
				continue
			}

			err := s.lc.Decode(batch)
			if err != nil {
				slog.Error("failed to decode batch", "error", err)
				panic("Failed to decode")
			}

			for i, seq := range s.seqs {
				if seq == nil {
					continue
				}

				// don't sample prompt processing
				if len(seq.tokens) != 0 {
					continue
				}

				// if done processing the prompt, generating an embedding and return
				if seq.embeddingOnly {
					embd := s.lc.GetEmbeddingsSeq(i)
					if embd == nil {
						embd = s.lc.GetEmbeddingsIth(seq.iBatch)
					}
					seq.embedding <- embd
					close(seq.embedding)
					s.lc.KvCacheSeqRm(i, 0, -1)
					s.seqs[i] = nil
					continue
				}

				// sample a token
				// logits := s.lc.GetLogitsIth(ibatch[i])
				// token := s.lc.SampleTokenGreedy(logits)
				token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
				seq.samplingCtx.Accept(s.lc, token, true)

				seq.n_decoded += 1
				if seq.n_decoded == 1 {
					seq.t_start_genereration = time.Now()
				}

				piece := s.model.TokenToPiece(token)
				seq.numPredicted++
				slog.Debug("sampled", "piece", piece)

				// if it's an end of sequence token, break
				// TODO: just end this sequence
				if s.model.TokenIsEog(token) {
					// TODO: end the sequence instead of quitting the pool
					s.lc.KvCacheSeqRm(i, 0, -1)
					// TODO (jmorganca): we should send this back
					// as it's important for the /api/generate context
					// seq.responses <- piece
					seq.doneReason = "stop"
					close(seq.responses)
					seq.samplingCtx.Free()
					pieces[i] = []string{}
					s.seqs[i] = nil
					continue
				}

				seq.tokens = []int{token}

				pieces[i] = append(pieces[i], piece)
				sequence := strings.Join(pieces[i], "")
				if incompleteUnicode(sequence) {
					continue
				}

				if ok, stop := findStop(sequence, seq.stop); ok {
					slog.Info("hit stop token", "stop", seq.stop)

					truncated := truncateStop(pieces[i], stop)
					for _, p := range truncated {
						seq.responses <- p
					}

					s.lc.KvCacheSeqRm(i, 0, -1)
					seq.doneReason = "stop"
					close(seq.responses)
					seq.samplingCtx.Free()
					pieces[i] = []string{}
					s.seqs[i] = nil
					continue
				}

				if containsStopSuffix(sequence, seq.stop) {
					continue
				}

				for _, p := range pieces[i] {
					seq.responses <- p
				}
				pieces[i] = []string{}
			}

			batch.Clear()
			pieces = s.processBatch(pieces)
		}
	}
}
func (s *Server) processBatch(pieces [][]string) [][]string {
	batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
	defer batch.Free()

	s.mu.Lock()
	for s.allNil() {
		s.cond.Wait() // Wait until an item is added
	}
	defer s.mu.Unlock()

	slog.Debug("Processing batch", "seqs", len(s.seqs))

	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}

		// if past the num predict limit
		if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
			seq.doneReason = "limit"
			close(seq.responses)
			s.lc.KvCacheSeqRm(i, 0, -1)
			s.seqs[i] = nil
			continue
		}

		if seq.nPast+len(seq.tokens) > s.numCtx {
			s.shiftContext(i)
		}

		if seq.t_start_process_prompt.IsZero() {
			seq.t_start_process_prompt = time.Now()
		}

		var numTokensProcessed int
		for j, t := range seq.tokens {
			// todo: make this n_batch
			if j >= s.batchSize {
				break
			}
			batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
			seq.nPast++
			numTokensProcessed++
		}
		seq.tokens = seq.tokens[numTokensProcessed:]
		seq.iBatch = batch.NumTokens() - 1
	}

	if batch.NumTokens() == 0 {
		return pieces
	}

	err := s.lc.Decode(batch)
	if err != nil {
		slog.Error("failed to decode batch", "error", err)
		panic("Failed to decode")
	}

	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}

		// don't sample prompt processing
		if len(seq.tokens) != 0 {
			continue
		}

		// if done processing the prompt, generating an embedding and return
		if seq.embeddingOnly {
			embd := s.lc.GetEmbeddingsSeq(i)
			if embd == nil {
				embd = s.lc.GetEmbeddingsIth(seq.iBatch)
			}
			seq.embedding <- embd
			close(seq.embedding)
			s.lc.KvCacheSeqRm(i, 0, -1)
			s.seqs[i] = nil
			continue
		}

		// sample a token
		// logits := s.lc.GetLogitsIth(ibatch[i])
		// token := s.lc.SampleTokenGreedy(logits)
		token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
		seq.samplingCtx.Accept(s.lc, token, true)

		seq.n_decoded += 1
		if seq.n_decoded == 1 {
			seq.t_start_genereration = time.Now()
		}

		piece := s.model.TokenToPiece(token)
		seq.numPredicted++
		slog.Debug("sampled", "piece", piece)

		// if it's an end of sequence token, break
		// TODO: just end this sequence
		if s.model.TokenIsEog(token) {
			// TODO: end the sequence instead of quitting the pool
			s.lc.KvCacheSeqRm(i, 0, -1)
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece
			seq.doneReason = "stop"
			close(seq.responses)
			seq.samplingCtx.Free()
			pieces[i] = []string{}
			s.seqs[i] = nil
			continue
		}

		seq.tokens = []int{token}

		pieces[i] = append(pieces[i], piece)
		sequence := strings.Join(pieces[i], "")
		if incompleteUnicode(sequence) {
			continue
		}

		if ok, stop := findStop(sequence, seq.stop); ok {
			slog.Info("hit stop token", "stop", seq.stop)

			truncated := truncateStop(pieces[i], stop)
			for _, p := range truncated {
				seq.responses <- p
			}

			s.lc.KvCacheSeqRm(i, 0, -1)
			seq.doneReason = "stop"
			close(seq.responses)
			seq.samplingCtx.Free()
			pieces[i] = []string{}
			s.seqs[i] = nil
			continue
		}

		if containsStopSuffix(sequence, seq.stop) {
			continue
		}

		for _, p := range pieces[i] {
			seq.responses <- p
		}
		pieces[i] = []string{}
	}

	return pieces
}
type Options struct {
	api.Runner