From ed19fad8623807a38a572f18e3f3ceaad38003d4 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Tue, 13 Aug 2024 11:18:02 -0700
Subject: [PATCH] llama.go: Make batch memory allocation match configuration

Batch size defaults to 512 but is configurable. However, llama.go uses a
fixed-size buffer, causing crashes if the batch size is increased. This
changes the array size to follow the configuration.
---
 llama/llama.go | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index 24765132..35ac0d4e 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -209,11 +209,15 @@ func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath
 }
 
 type Batch struct {
-	c C.struct_llama_batch
+	c         C.struct_llama_batch
+	batchSize int
 }
 
 func NewBatch(nTokens int, embd int, maxSeq int) Batch {
-	return Batch{c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq))}
+	return Batch{
+		c:         C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq)),
+		batchSize: nTokens,
+	}
 }
 
 func (b *Batch) NumTokens() int {
@@ -223,16 +227,16 @@ func (b *Batch) NumTokens() int {
 // Add adds a token to the batch with the given position for the given
 // sequence ids, and optionally instructs to include logits.
 func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
-	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
-	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
-	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
+	unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
+	unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
+	unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
 
 	for i, s := range seqIds {
-		unsafe.Slice((unsafe.Slice(b.c.seq_id, 512)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
+		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
 	}
 
 	if logits {
-		unsafe.Slice(b.c.logits, 512)[b.c.n_tokens] = 1
+		unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
 	}
 
 	b.c.n_tokens += 1
@@ -243,6 +247,7 @@ func (b *Batch) Clear() {
 }
 
 func (b *Batch) Free() {
+	b.batchSize = 0
 	C.llama_batch_free(b.c)
 }
 
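
Note (reviewer sketch, not part of the patch): a minimal usage sketch of the resized batch. The import path and the way the configured batch size is obtained are assumptions for illustration; the point is that NewBatch now records nTokens as batchSize, so Add stays within the allocated arrays even when the configured batch size exceeds the old hard-coded 512.

package main

import (
	"github.com/ollama/ollama/llama" // assumed import path for this wrapper package
)

func main() {
	nBatch := 1024 // hypothetical configured batch size; values > 512 used to overflow the fixed buffers

	// Allocate the batch to match the configuration; with this patch,
	// NewBatch stores nBatch in batchSize so the later unsafe.Slice
	// calls in Add are bounded by that length instead of 512.
	batch := llama.NewBatch(nBatch, 0, 1)
	defer batch.Free()

	// Fill the batch up to its configured capacity. The token and
	// position values here are placeholders; real callers would pass
	// tokenized prompt data.
	for i := 0; i < nBatch; i++ {
		requestLogits := i == nBatch-1 // only request logits for the last token
		batch.Add(i, i, []int{0}, requestLogits)
	}

	_ = batch.NumTokens() // == nBatch
}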