llama.go: Make batch memory allocation match configuration

Batch size defaults to 512 but is configurable. However, llama.go uses
a fixed-size buffer, causing crashes if the batch size is increased.
This changes the array sizes to follow the configuration.
Jesse Gross 2024-08-13 11:18:02 -07:00 committed by jmorganca
parent 5d34320b7c
commit ed19fad862
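
Why the fixed length crashed: llama_batch_init sizes each array to nTokens, but Add viewed those arrays through unsafe.Slice(ptr, 512). Once n_tokens reaches 512, indexing that fixed-length view panics with an index-out-of-range error, no matter how large the underlying C allocation is. Below is a minimal, runnable sketch of the failure mode; it uses a Go-allocated buffer as a stand-in for the C arrays, and every name in it is illustrative.

package main

import (
	"fmt"
	"unsafe"
)

func main() {
	// Stand-in for llama_batch_init: the token array is sized from the
	// configured batch size, not from a hard-coded constant.
	batchSize := 1024
	tokens := make([]int32, batchSize)
	ptr := &tokens[0]

	// The old code viewed the buffer through a fixed length of 512.
	// Indexing this view at 512 or beyond panics ("index out of range"),
	// even though the allocation actually holds batchSize entries.
	fixed := unsafe.Slice(ptr, 512)
	fmt.Println(len(fixed)) // 512: fixed[600] would panic here

	// The fix sizes the view from the same value used for the
	// allocation, so every legitimate batch index stays addressable.
	view := unsafe.Slice(ptr, batchSize)
	view[600] = 42 // fine: within the real allocation
	fmt.Println(view[600])
}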


@@ -209,11 +209,15 @@ func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath
 }
 
 type Batch struct {
-	c C.struct_llama_batch
+	c         C.struct_llama_batch
+	batchSize int
 }
 
 func NewBatch(nTokens int, embd int, maxSeq int) Batch {
-	return Batch{c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq))}
+	return Batch{
+		c:         C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq)),
+		batchSize: nTokens,
+	}
 }
 
 func (b *Batch) NumTokens() int {
@@ -223,16 +227,16 @@ func (b *Batch) NumTokens() int {
 // Add adds a token to the batch with the given position for the given
 // sequence ids, and optionally instructs to include logits.
 func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
-	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
-	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
-	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
+	unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
+	unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
+	unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
 
 	for i, s := range seqIds {
-		unsafe.Slice((unsafe.Slice(b.c.seq_id, 512)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
+		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
 	}
 
 	if logits {
-		unsafe.Slice(b.c.logits, 512)[b.c.n_tokens] = 1
+		unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
 	}
 
 	b.c.n_tokens += 1
@@ -243,6 +247,7 @@ func (b *Batch) Clear() {
 }
 
 func (b *Batch) Free() {
+	b.batchSize = 0
 	C.llama_batch_free(b.c)
 }
 
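
For completeness, a sketch of how a caller might exercise the batch after this change, using only the functions touched by the diff (the token values and sequence id are assumed for illustration):

// Arrays are now sized for 2048 tokens instead of the old fixed 512.
batch := NewBatch(2048, 0, 1)
defer batch.Free()

promptTokens := []int{1, 15043, 3186} // hypothetical tokenized prompt
for i, tok := range promptTokens {
	// Attach each token to sequence 0; request logits only for the last one.
	batch.Add(tok, i, []int{0}, i == len(promptTokens)-1)
}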