forked from third-party-mirrors/ollama
llama.go: Make batch memory allocation match configuration
Batch size defaults to 512 but is configurable. However, llama.go uses a fixed size buffer, causing crashes is the batch size is increase. This changes the array size to follow the configuration.
This commit is contained in:
parent
5d34320b7c
commit
ed19fad862
@ -209,11 +209,15 @@ func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath
|
||||
}
|
||||
|
||||
type Batch struct {
|
||||
c C.struct_llama_batch
|
||||
c C.struct_llama_batch
|
||||
batchSize int
|
||||
}
|
||||
|
||||
func NewBatch(nTokens int, embd int, maxSeq int) Batch {
|
||||
return Batch{c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq))}
|
||||
return Batch{
|
||||
c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq)),
|
||||
batchSize: nTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Batch) NumTokens() int {
|
||||
@ -223,16 +227,16 @@ func (b *Batch) NumTokens() int {
|
||||
// Add adds a token to the batch with the given position for the given
|
||||
// sequence ids, and optionally instructs to include logits.
|
||||
func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
|
||||
unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
|
||||
unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
|
||||
unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
|
||||
unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
|
||||
unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
|
||||
unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
|
||||
|
||||
for i, s := range seqIds {
|
||||
unsafe.Slice((unsafe.Slice(b.c.seq_id, 512)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
|
||||
unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
|
||||
}
|
||||
|
||||
if logits {
|
||||
unsafe.Slice(b.c.logits, 512)[b.c.n_tokens] = 1
|
||||
unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
|
||||
}
|
||||
|
||||
b.c.n_tokens += 1
|
||||
@ -243,6 +247,7 @@ func (b *Batch) Clear() {
|
||||
}
|
||||
|
||||
func (b *Batch) Free() {
|
||||
b.batchSize = 0
|
||||
C.llama_batch_free(b.c)
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user