runner.go: Move pieces[] into sequence

pieces[] is used to cache pending responses and is currently being
passed around to different functions. Move it into the sequences
where it logically belongs.
This commit is contained in:
Jesse Gross 2024-08-27 10:24:33 -07:00 committed by jmorganca
parent 6ccd0644e1
commit d022cfc9e6

View File

@ -35,6 +35,10 @@ type Sequence struct {
// tokens left to evaluate
tokens []int
// tokens that have been generated but not returned yet (e.g. for stop sequences)
// TODO (jmorganca): simplify this
pendingResponses []string
// channel to send responses over
responses chan string
@ -105,16 +109,17 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence
}
return &Sequence{
tokens: tokens,
n_prompt_tokens: len(tokens),
numPredict: params.numPredict,
responses: make(chan string, 1),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
tokens: tokens,
n_prompt_tokens: len(tokens),
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 1),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
}
}
@ -201,34 +206,30 @@ func incompleteUnicode(token string) bool {
return incomplete
}
func (s *Server) removeSequence(seqIndex int, pieces *[][]string, reason string) {
func (s *Server) removeSequence(seqIndex int, reason string) {
seq := s.seqs[seqIndex]
seq.doneReason = reason
close(seq.responses)
close(seq.embedding)
(*pieces)[seqIndex] = []string{}
seq.pendingResponses = []string{}
seq.samplingCtx.Free()
s.lc.KvCacheSeqRm(seqIndex, 0, -1)
s.seqs[seqIndex] = nil
}
func (s *Server) run(ctx context.Context) {
// build up stop sequences as we recognize them
// TODO (jmorganca): simplify this
pieces := make([][]string, s.parallel)
for {
select {
case <-ctx.Done():
return
default:
pieces = s.processBatch(pieces)
s.processBatch()
}
}
}
func (s *Server) processBatch(pieces [][]string) [][]string {
func (s *Server) processBatch() {
batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
defer batch.Free()
@ -247,7 +248,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
s.removeSequence(i, &pieces, "limit")
s.removeSequence(i, "limit")
continue
}
@ -274,7 +275,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
}
if batch.NumTokens() == 0 {
return pieces
return
}
err := s.lc.Decode(batch)
@ -301,7 +302,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
}
seq.embedding <- embd
s.removeSequence(i, &pieces, "")
s.removeSequence(i, "")
continue
}
@ -329,14 +330,14 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
// seq.responses <- piece
// TODO: end the sequence instead of quitting the pool
s.removeSequence(i, &pieces, "stop")
s.removeSequence(i, "stop")
continue
}
seq.tokens = []int{token}
pieces[i] = append(pieces[i], piece)
sequence := strings.Join(pieces[i], "")
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
if incompleteUnicode(sequence) {
continue
@ -345,7 +346,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
if ok, stop := findStop(sequence, seq.stop); ok {
slog.Info("hit stop token", "stop", seq.stop)
truncated := truncateStop(pieces[i], stop)
truncated := truncateStop(seq.pendingResponses, stop)
for _, p := range truncated {
select {
@ -355,7 +356,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
}
}
s.removeSequence(i, &pieces, "stop")
s.removeSequence(i, "stop")
continue
}
@ -363,19 +364,17 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
continue
}
for _, p := range pieces[i] {
for _, p := range seq.pendingResponses {
select {
case seq.responses <- p:
case <-seq.quit:
s.removeSequence(i, &pieces, "connection")
s.removeSequence(i, "connection")
break
}
}
pieces[i] = []string{}
seq.pendingResponses = []string{}
}
return pieces
}
type Options struct {