llama.go: Advance through tokens when processing multiple batches
If the number of input tokens exceeds the batch size, multiple batches are submitted, but previously they all contained the same initial tokens. This change advances through the input so that each batch carries the next set of tokens, as expected.
commit 8aa97b5e83
parent 523d84c563
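As a rough standalone illustration of the pattern the fix introduces (not the runner code itself; token, fillBatch, and the fixed batchSize are hypothetical names made up for this sketch), each pass submits at most one batch worth of tokens and slices off what it consumed, so the next pass continues with the following tokens instead of resubmitting the first ones:

package main

import "fmt"

type token int

const batchSize = 4

// fillBatch takes at most batchSize tokens from pending and returns the
// batch for this pass plus the tokens that still need to be processed.
func fillBatch(pending []token) (batch, remaining []token) {
	n := len(pending)
	if n > batchSize {
		n = batchSize
	}
	return pending[:n], pending[n:]
}

func main() {
	pending := []token{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}

	for len(pending) > 0 {
		var batch []token
		batch, pending = fillBatch(pending)

		// Only the final prompt token needs logits, which is what the
		// numTokensProcessed+1 == len(seq.tokens) check in the diff expresses.
		wantLogits := len(pending) == 0
		fmt.Println("decode batch:", batch, "want logits:", wantLogits)
	}
}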
@@ -61,12 +61,6 @@ type Sequence struct {
 	n_prompt_tokens int
 }
 
-// prompt returns true if the prompt is still being processed
-// TODO (jmorganca): clean up this logic
-func (s *Sequence) prompt() bool {
-	return s.nPast < len(s.tokens)-1
-}
-
 func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
 	if err != nil {
@@ -176,14 +170,17 @@ func (s *Server) run(ctx context.Context) {
 				seq.t_start_process_prompt = time.Now()
 			}
 
+			var numTokensProcessed int
 			for j, t := range seq.tokens {
 				// todo: make this n_batch
 				if j >= s.batchSize {
 					break
 				}
-				batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
+				batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
 				seq.nPast++
+				numTokensProcessed++
 			}
+			seq.tokens = seq.tokens[numTokensProcessed:]
 			seq.iBatch = batch.NumTokens() - 1
 		}
 
@@ -199,7 +196,7 @@ func (s *Server) run(ctx context.Context) {
 			}
 
 			// don't sample prompt processing
-			if seq.prompt() {
+			if len(seq.tokens) != 0 {
 				continue
 			}
 