forked from third-party-mirrors/ollama
runner.go: Move pieces[] into sequence
pieces[] is used to cache pending responses and is currently being passed around to different functions. Move it into the sequences where it logically belongs.
commit d022cfc9e6
parent 6ccd0644e1
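
In short, the patch replaces a parallel pieces [][]string slice, indexed by sequence number and threaded through run, processBatch, and removeSequence, with a pendingResponses []string field owned by each Sequence. The sketch below shows the before/after shape of the refactor; Server and Sequence here are stripped-down stand-ins for illustration only, with the real runner's batching, channels, and KV-cache handling omitted:

package main

import (
	"fmt"
	"strings"
)

// Stand-in types: only the state this commit touches is modeled.
type Sequence struct {
	// generated pieces not yet sent back (e.g. while checking stop strings)
	pendingResponses []string
}

type Server struct {
	seqs []*Sequence
}

// Before: removeSequence(seqIndex int, pieces *[][]string, reason string)
// had to reach into the caller's slice to clear per-sequence state.
// After: the state lives on the sequence, so cleanup is local.
func (s *Server) removeSequence(seqIndex int, reason string) {
	seq := s.seqs[seqIndex]
	seq.pendingResponses = []string{}
	s.seqs[seqIndex] = nil
	fmt.Println("sequence removed:", reason)
}

// Before: processBatch(pieces [][]string) [][]string threaded the cache
// in and out on every pass of the run loop. After: no parameters and no
// return value, since each sequence carries its own pending pieces.
func (s *Server) processBatch() {
	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}
		seq.pendingResponses = append(seq.pendingResponses, "token ")
		joined := strings.Join(seq.pendingResponses, "")
		if strings.Contains(joined, "token") { // stand-in for findStop
			s.removeSequence(i, "stop")
		}
	}
}

func main() {
	s := &Server{seqs: []*Sequence{{pendingResponses: make([]string, 0)}}}
	s.processBatch() // the run loop can now call this with no plumbing
}

One practical consequence: the buffer's lifetime now matches the sequence's, allocated in NewSequence and cleared in removeSequence, rather than living in a separate slice that had to stay index-aligned with s.seqs.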
@@ -35,6 +35,10 @@ type Sequence struct {
 	// tokens left to evaluate
 	tokens []int
 
+	// tokens that have been generated but not returned yet (e.g. for stop sequences)
+	// TODO (jmorganca): simplify this
+	pendingResponses []string
+
 	// channel to send responses over
 	responses chan string
 
@@ -105,16 +109,17 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence
 	}
 
 	return &Sequence{
-		tokens:          tokens,
-		n_prompt_tokens: len(tokens),
-		numPredict:      params.numPredict,
-		responses:       make(chan string, 1),
-		quit:            make(chan bool, 1),
-		embedding:       make(chan []float32, 1),
-		samplingCtx:     sc,
-		embeddingOnly:   params.embedding,
-		stop:            params.stop,
-		numKeep:         params.numKeep,
+		tokens:           tokens,
+		n_prompt_tokens:  len(tokens),
+		numPredict:       params.numPredict,
+		pendingResponses: make([]string, 0),
+		responses:        make(chan string, 1),
+		quit:             make(chan bool, 1),
+		embedding:        make(chan []float32, 1),
+		samplingCtx:      sc,
+		embeddingOnly:    params.embedding,
+		stop:             params.stop,
+		numKeep:          params.numKeep,
 	}
 }
 
@@ -201,34 +206,30 @@ func incompleteUnicode(token string) bool {
 	return incomplete
 }
 
-func (s *Server) removeSequence(seqIndex int, pieces *[][]string, reason string) {
+func (s *Server) removeSequence(seqIndex int, reason string) {
 	seq := s.seqs[seqIndex]
 
 	seq.doneReason = reason
 	close(seq.responses)
 	close(seq.embedding)
-	(*pieces)[seqIndex] = []string{}
+	seq.pendingResponses = []string{}
 	seq.samplingCtx.Free()
 	s.lc.KvCacheSeqRm(seqIndex, 0, -1)
 	s.seqs[seqIndex] = nil
 }
 
 func (s *Server) run(ctx context.Context) {
-	// build up stop sequences as we recognize them
-	// TODO (jmorganca): simplify this
-	pieces := make([][]string, s.parallel)
-
 	for {
 		select {
 		case <-ctx.Done():
 			return
 		default:
-			pieces = s.processBatch(pieces)
+			s.processBatch()
 		}
 	}
 }
 
-func (s *Server) processBatch(pieces [][]string) [][]string {
+func (s *Server) processBatch() {
 	batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
 	defer batch.Free()
 
@@ -247,7 +248,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
-			s.removeSequence(i, &pieces, "limit")
+			s.removeSequence(i, "limit")
 			continue
 		}
 
@@ -274,7 +275,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 	}
 
 	if batch.NumTokens() == 0 {
-		return pieces
+		return
 	}
 
 	err := s.lc.Decode(batch)
@@ -301,7 +302,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 			}
 
 			seq.embedding <- embd
-			s.removeSequence(i, &pieces, "")
+			s.removeSequence(i, "")
 			continue
 		}
 
@@ -329,14 +330,14 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 			// seq.responses <- piece
 
 			// TODO: end the sequence instead of quitting the pool
-			s.removeSequence(i, &pieces, "stop")
+			s.removeSequence(i, "stop")
 			continue
 		}
 
 		seq.tokens = []int{token}
 
-		pieces[i] = append(pieces[i], piece)
-		sequence := strings.Join(pieces[i], "")
+		seq.pendingResponses = append(seq.pendingResponses, piece)
+		sequence := strings.Join(seq.pendingResponses, "")
 
 		if incompleteUnicode(sequence) {
 			continue
@@ -345,7 +346,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 		if ok, stop := findStop(sequence, seq.stop); ok {
 			slog.Info("hit stop token", "stop", seq.stop)
 
-			truncated := truncateStop(pieces[i], stop)
+			truncated := truncateStop(seq.pendingResponses, stop)
 
 			for _, p := range truncated {
 				select {
@@ -355,7 +356,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 				}
 			}
 
-			s.removeSequence(i, &pieces, "stop")
+			s.removeSequence(i, "stop")
 			continue
 		}
 
@@ -363,19 +364,17 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 			continue
 		}
 
-		for _, p := range pieces[i] {
+		for _, p := range seq.pendingResponses {
 			select {
 			case seq.responses <- p:
 			case <-seq.quit:
-				s.removeSequence(i, &pieces, "connection")
+				s.removeSequence(i, "connection")
 				break
 			}
 		}
 
-		pieces[i] = []string{}
+		seq.pendingResponses = []string{}
 	}
-
-	return pieces
 }
 
 type Options struct {