diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 52087276..77d7bdee 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -594,7 +594,7 @@ func main() {
 	nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
 	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
 	flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
-	numCtx := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
+	kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
 	lpath := flag.String("lora", "", "Path to lora layer file")
 	port := flag.Int("port", 8080, "Port to expose the server on")
 	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
@@ -647,7 +647,7 @@ func main() {
 	}
 
 	server := &Server{
-		numCtx:    *numCtx,
+		numCtx:    *kvSize / *parallel,
 		batchSize: *batchSize,
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
@@ -669,7 +669,7 @@ func main() {
 		}
 	}
 
-	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
+	ctxParams := llama.NewContextParams(*kvSize, *threads, *flashAttention)
 	server.lc = llama.NewContextWithModel(server.model, ctxParams)
 
 	if server.model.ShouldAddBOSToken() {
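
Note: with this change, -ctx-size describes the total KV cache handed to llama.NewContextParams, while each parallel sequence is limited to an equal share of it (server.numCtx becomes *kvSize / *parallel). Below is a minimal, standalone Go sketch of that arithmetic, not the runner's actual code; the -parallel flag name and its default are assumptions, since the diff only shows the *parallel dereference, not the flag definition.

package main

import (
	"flag"
	"fmt"
)

func main() {
	// Total KV cache size, matching the runner's -ctx-size flag above.
	kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
	// Assumed flag for the number of parallel sequences.
	parallel := flag.Int("parallel", 1, "Number of parallel sequences")
	flag.Parse()

	// Each sequence gets an equal share of the total KV cache (integer
	// division), mirroring numCtx: *kvSize / *parallel in the Server literal.
	perSeqCtx := *kvSize / *parallel
	fmt.Printf("total KV cache: %d tokens, per-sequence context: %d tokens\n",
		*kvSize, perSeqCtx)
}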