diff --git a/llama/llama.go b/llama/llama.go
index 2fb19ae7..89943380 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -315,20 +315,30 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 type Batch struct {
 	c         C.struct_llama_batch
 	batchSize int
+	maxSeq    int
 	embedSize int
 }
 
-// Creates a new batch for either word tokens if embed is 0 or
-// image embeddings if embed is specified. Batches cannot contain
-// both types at the same time
-func NewBatch(nTokens int, embed int, maxSeq int) *Batch {
+// Creates a new batch for either word tokens or image embeddings (if embedSize is non-zero).
+// Batches cannot contain both types at the same time. batchSize is the maximum number of entries
+// that can be added per sequence
+func NewBatch(batchSize int, maxSeq int, embedSize int) *Batch {
 	return &Batch{
-		c:         C.llama_batch_init(C.int(nTokens), C.int(embed), C.int(maxSeq)),
-		batchSize: nTokens,
-		embedSize: embed,
+		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
+		batchSize: batchSize,
+		maxSeq:    maxSeq,
+		embedSize: embedSize,
 	}
 }
 
+func (b *Batch) Size() int {
+	return b.batchSize
+}
+
+func (b *Batch) allocSize() int {
+	return b.batchSize * b.maxSeq
+}
+
 func (b *Batch) NumTokens() int {
 	return int(b.c.n_tokens)
 }
@@ -341,21 +351,21 @@ func (b *Batch) IsEmbedding() bool {
 // when the batch was initialized. The other argument will be ignored. Adds to the
 // batch with the given position for the given sequence ids, and optionally instructs
 // to include logits.
-func (b *Batch) Add(token int, embed []float32, pos int, seqIds []int, logits bool) {
+func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
 	if !b.IsEmbedding() {
-		unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
+		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
 	} else {
-		copy(unsafe.Slice((*float32)(b.c.embd), b.batchSize*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
+		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
 	}
-	unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
-	unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
+	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
+	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))
 
 	for i, s := range seqIds {
-		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
+		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
 	}
 
 	if logits {
-		unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
+		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
 	}
 
 	b.c.n_tokens += 1
diff --git a/llama/runner/image.go b/llama/runner/image.go
index 3b562186..ee76f47a 100644
--- a/llama/runner/image.go
+++ b/llama/runner/image.go
@@ -89,6 +89,23 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspect
 	return embed
 }
 
+func (c *ImageContext) BatchSize(configuredBatchSize int) int {
+	// If images are not supported, we don't need to allocate embedding batches
+	if c == nil {
+		return 0
+	}
+
+	// Mllama maps an image to 1 embedding token (llava creates many tokens)
+	// and doesn't support more than a single image per request.
+	// The embeddings are large (100 MB), so allocating a big batch can fail
+	// on some systems
+	if c.mllama != nil {
+		return 1
+	}
+
+	return configuredBatchSize
+}
+
 func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
 	if c != nil && c.mllama != nil {
 		return c.mllama.EmbedSize(llamaContext)
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index a7e0e3b0..041bafb3 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -211,6 +211,7 @@ type Server struct {
 	// required for image embeddings
 	image *ImageContext
 
+	// TODO (jmorganca): make this n_batch
 	batchSize int
 
 	// parallel is the number of parallel requests to handle
@@ -302,13 +303,19 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
 func (s *Server) run(ctx context.Context) {
 	s.ready.Wait()
 
-	// logically these batches are used only within the context of processBatch
+	// Logically these batches are used only within the context of processBatch
 	// but it is better for performance to allocate them once here
-	tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
+	tokenBatch := llama.NewBatch(s.batchSize, len(s.seqs), 0)
 	defer tokenBatch.Free()
-	embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs))
-	defer embedBatch.Free()
+
+	var embedBatch *llama.Batch
+	embedBatchSize := s.image.BatchSize(s.batchSize)
+	if embedBatchSize != 0 {
+		embedBatch = llama.NewBatch(embedBatchSize, len(s.seqs), s.image.EmbedSize(s.lc))
+		defer embedBatch.Free()
+	} else {
+		embedBatch = &llama.Batch{}
+	}
 
 	for {
 		select {
@@ -378,13 +385,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			break
 		}
 
-		// todo: make this n_batch
-		if i >= s.batchSize {
+		if i >= batch.Size() {
 			break
 		}
 
 		crossAttention = seq.crossAttention
-		batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
+		batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id)
		seq.numPast++
		numInputsProcessed++
	}
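For callers, the visible changes are the new `NewBatch` argument order (`batchSize, maxSeq, embedSize`, with the underlying allocation sized to `batchSize*maxSeq`) and `seqIds` becoming a variadic trailing parameter of `Add`. A minimal sketch of the new call pattern, assuming the package builds under the `github.com/ollama/ollama/llama` module path; the batch size, sequence count, and token values below are illustrative, not taken from the runner's configuration:

```go
package main

import "github.com/ollama/ollama/llama"

func example() {
	// Token mode: an embedSize of 0 selects tokens rather than image embeddings.
	// The underlying llama_batch is allocated with batchSize*maxSeq entries,
	// so each of the 4 sequences can contribute up to 512 entries per decode.
	batch := llama.NewBatch(512, 4, 0)
	defer batch.Free()

	// seqIds are now variadic and come after the logits flag, so the common
	// single-sequence case no longer needs a slice literal:
	//   old: batch.Add(tok, nil, pos, []int{0}, logits)
	//   new: batch.Add(tok, nil, pos, logits, 0)
	tokens := []int{101, 102, 103} // placeholder token ids
	for pos, tok := range tokens {
		// Request logits only for the final token of sequence 0.
		batch.Add(tok, nil, pos, pos == len(tokens)-1, 0)
	}
}
```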
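The mllama cap of 1 in `BatchSize` is the heart of the memory fix: per the comment in the diff, a single image embedding is on the order of 100 MB, so sizing the embedding batch at the configured batch size multiplied across sequences is what could fail to allocate. A rough back-of-the-envelope sketch of that arithmetic; the batch size and sequence count are assumed values, not figures from the source:

```go
package main

import "fmt"

func main() {
	const embedBytes = int64(100 << 20) // ~100 MB per mllama image embedding (per the diff's comment)
	const batchSize = 512               // assumed configured batch size
	const numSeqs = 4                   // assumed number of parallel sequences

	// Before: the embedding batch was allocated with batchSize*numSeqs entries.
	fmt.Printf("old worst case: %d GB\n", (embedBytes*batchSize*numSeqs)>>30) // 200 GB

	// After: BatchSize returns 1 for mllama, so only numSeqs entries are needed.
	fmt.Printf("new worst case: %d MB\n", (embedBytes*numSeqs)>>20) // 400 MB
}
```

Under those assumptions the worst-case embedding allocation drops by a factor of the configured batch size, which is consistent with the diff's stated motivation that the old sizing "can fail on some systems".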