From 30cc6cab23dea167d2c0dae735cd2af7e5dc99da Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 31 Oct 2024 08:36:32 +1100 Subject: [PATCH 1/5] feat: allow setting KV cache type --- cmd/cmd.go | 2 + docs/faq.md | 34 ++++++++++-- envconfig/config.go | 6 +++ llama/llama.go | 29 ++++++++++- llama/runner/runner.go | 8 ++- llm/memory.go | 48 ++++++++++++++++- llm/memory_test.go | 116 ++++++++++++++++++++++++++++++++++++++++- llm/server.go | 91 +++++++++++++++++++++++++++++--- 8 files changed, 317 insertions(+), 17 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index b8c9c640..d3f8fd56 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1482,6 +1482,8 @@ func NewCLI() *cobra.Command { envVars["OLLAMA_SCHED_SPREAD"], envVars["OLLAMA_TMPDIR"], envVars["OLLAMA_FLASH_ATTENTION"], + envVars["OLLAMA_CACHE_TYPE_K"], + envVars["OLLAMA_CACHE_TYPE_V"], envVars["OLLAMA_LLM_LIBRARY"], envVars["OLLAMA_GPU_OVERHEAD"], envVars["OLLAMA_LOAD_TIMEOUT"], diff --git a/docs/faq.md b/docs/faq.md index 0dbbb3ff..14022a43 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -151,7 +151,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e Ollama runs an HTTP server and can be exposed using a proxy server such as Nginx. To do so, configure the proxy to forward requests and optionally set required headers (if not exposing Ollama on the network). For example, with Nginx: -``` +```nginx server { listen 80; server_name example.com; # Replace with your domain or IP @@ -164,7 +164,7 @@ server { ## How can I use Ollama with ngrok? -Ollama can be accessed using a range of tools for tunneling tools. For example with Ngrok: +Ollama can be accessed using a range of tools for tunnelling tools. For example with Ngrok: ```shell ngrok http 11434 --host-header="localhost:11434" @@ -285,4 +285,32 @@ Note: Windows with Radeon GPUs currently default to 1 model maximum due to limit ## How does Ollama load models on multiple GPUs? -Installing multiple GPUs of the same brand can be a great way to increase your available VRAM to load larger models. When you load a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transfering across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs. +Installing multiple GPUs of the same brand can be a great way to increase your available VRAM to load larger models. When you load a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transferring across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs. + +## How can I enable Flash Attention? + +Flash Attention is a feature of most (but not all) modern models that can significantly reduce memory usage as the context size grows. To enable Flash Attention, set the `OLLAMA_FLASH_ATTENTION` environment variable to `1` when starting the Ollama server. + +> Note: If you're using an uncommon quantization type with CUDA, you may benefit from build Ollama with `LLAMA_CUDA_FA_ALL_QUANTS=1` to make llama.cpp build all flash attention quantization types. + +## How can I set the quantization type for the K/V cache? + +The K/V context cache can be quantized to significantly reduce memory usage when Flash Attention is enabled. + +To use quantized K/V cache with Ollama you can set the following environment variables: + +- `OLLAMA_CACHE_TYPE_K` - The quantization type for the key cache. Default is `f16`. +- `OLLAMA_CACHE_TYPE_V` - The quantization type for the value cache. Default is `f16`. + +> Note: Currently this is a global option - meaning all models will run with the specified quantization type. + +There are [a number of quantization types available](https://github.com/ggerganov/llama.cpp/pull/7527), the most commonly used are: + +- `f32` - full precision and memory usage. +- `f16` - high precision and memory usage (default). +- `q8_0` - 8-bit quantization, uses approximately 1/2 the memory of `f16` with a very small loss in precision, this usually has no noticeable impact on the model's quality (recommended if not using f16). +- `q4_0` - 4-bit quantization, uses approximately 1/4 the memory of `f16` with a small-medium loss in precision that may be more noticeable at higher context sizes. + +How much the cache quantization impacts the model's response quality will depend on the model and the task. Models have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count. + +You may need to experiment with different quantization types to find the best balance between memory usage and quality. diff --git a/envconfig/config.go b/envconfig/config.go index e80c67ba..73271f86 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -153,6 +153,10 @@ var ( Debug = Bool("OLLAMA_DEBUG") // FlashAttention enables the experimental flash attention feature. FlashAttention = Bool("OLLAMA_FLASH_ATTENTION") + // CacheTypeK is the quantization type for the K/V cache keys. + CacheTypeK = String("OLLAMA_CACHE_TYPE_K") + // CacheTypeV is the quantization type for the K/V cache values. + CacheTypeV = String("OLLAMA_CACHE_TYPE_V") // NoHistory disables readline history. NoHistory = Bool("OLLAMA_NOHISTORY") // NoPrune disables pruning of model blobs on startup. @@ -234,6 +238,8 @@ func AsMap() map[string]EnvVar { ret := map[string]EnvVar{ "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"}, + "OLLAMA_CACHE_TYPE_K": {"OLLAMA_CACHE_TYPE_K", CacheTypeK(), "Type of cache for keys (default: f16)"}, + "OLLAMA_CACHE_TYPE_V": {"OLLAMA_CACHE_TYPE_V", CacheTypeV(), "Type of cache for values (default: f16)"}, "OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"}, diff --git a/llama/llama.go b/llama/llama.go index 2fb19ae7..38186b2e 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -137,7 +137,7 @@ type ContextParams struct { c C.struct_llama_context_params } -func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams { +func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, cacheTypeK string, cacheTypeV string) ContextParams { params := C.llama_context_default_params() params.n_ctx = C.uint(numCtx) params.n_batch = C.uint(batchSize) @@ -146,6 +146,9 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla params.n_threads_batch = params.n_threads params.embeddings = C.bool(true) params.flash_attn = C.bool(flashAttention) + params.type_k = KvCacheTypeFromStr(cacheTypeK) + params.type_v = KvCacheTypeFromStr(cacheTypeV) + return ContextParams{c: params} } @@ -621,3 +624,27 @@ func (s *SamplingContext) Sample(llamaContext *Context, idx int) int { func (s *SamplingContext) Accept(id int, applyGrammar bool) { C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar)) } + +// KvCacheTypeFromStr converts a string cache type to the corresponding GGML type value +func KvCacheTypeFromStr(s string) C.enum_ggml_type { + switch s { + case "f32": + return C.GGML_TYPE_F32 + case "f16": + return C.GGML_TYPE_F16 + case "q8_0": + return C.GGML_TYPE_Q8_0 + case "q4_0": + return C.GGML_TYPE_Q4_0 + case "q4_1": + return C.GGML_TYPE_Q4_1 + case "iq4_nl": + return C.GGML_TYPE_IQ4_NL + case "q5_0": + return C.GGML_TYPE_Q5_0 + case "q5_1": + return C.GGML_TYPE_Q5_1 + default: + panic("Unsupported cache type: " + s) + } +} diff --git a/llama/runner/runner.go b/llama/runner/runner.go index a137f879..2e369268 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -758,6 +758,8 @@ func (s *Server) loadModel( flashAttention bool, threads int, multiUserCache bool, + cacheTypeK string, + cacheTypeV string, ) { llama.BackendInit() @@ -767,7 +769,7 @@ func (s *Server) loadModel( panic(err) } - ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention) + ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, cacheTypeK, cacheTypeV) s.lc, err = llama.NewContextWithModel(s.model, ctxParams) if err != nil { panic(err) @@ -817,6 +819,8 @@ func main() { multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") // Expose requirements as a JSON output to stdout requirements := flag.Bool("requirements", false, "print json requirement information") + cacheTypeK := flag.String("cache-type-k", "f16", "quantization type for key in cache (default: f16)") + cacheTypeV := flag.String("cache-type-v", "f16", "quantization type for value in cache (default: f16)") // These are either ignored by llama.cpp or have no significance to us _ = flag.Bool("embedding", false, "enable embedding vector output (default: disabled)") @@ -877,7 +881,7 @@ func main() { } server.ready.Add(1) - go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache) + go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *cacheTypeK, *cacheTypeV) server.cond = sync.NewCond(&server.mu) diff --git a/llm/memory.go b/llm/memory.go index 16f9a743..2bf2f11b 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -123,8 +123,16 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string, slog.Warn("model missing blk.0 layer size") } - // fp16 k,v = sizeof(float16) * n_ctx * n_layer * (n_embd_head_k + n_embd_head_v) * n_head_kv - var kv uint64 = 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * (ggml.KV().EmbeddingHeadCountK() + ggml.KV().EmbeddingHeadCountV()) * ggml.KV().HeadCountKV() + // Check if the model is an embedding model + isEmbeddingModel := false + if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok { + isEmbeddingModel = true + } + + // Estimate the memory required for K and V caches separately as they can have different quantization types + kSize := estimateKvCacheSize(envconfig.CacheTypeK(), uint64(opts.NumCtx), ggml.KV().BlockCount(), ggml.KV().EmbeddingHeadCountK(), ggml.KV().HeadCountKV(), isEmbeddingModel) + vSize := estimateKvCacheSize(envconfig.CacheTypeV(), uint64(opts.NumCtx), ggml.KV().BlockCount(), ggml.KV().EmbeddingHeadCountV(), ggml.KV().HeadCountKV(), isEmbeddingModel) + kv := kSize + vSize // KV is proportional to the number of layers layerSize += kv / ggml.KV().BlockCount() @@ -440,3 +448,39 @@ func projectorMemoryRequirements(filename string) (weights, graphSize uint64) { return weights, graphSize } + +// estimateKvCacheSize determines the memory required for K or V cache based on the quantization type +func estimateKvCacheSize(cacheType string, numCtx, blockCount, embeddingHeadCount, headCountKV uint64, isEmbeddingModel bool) uint64 { + var bytesPerElement float64 + + if isEmbeddingModel && cacheType != "f16" && cacheType != "f32" { + cacheType = "f16" // Default to f16 for embedding models if an unsupported type is specified + } + + switch cacheType { + case "f32", "fp32": + bytesPerElement = 4 // fp32 + case "", "f16", "fp16": + bytesPerElement = 2 // fp16 + case "q8_0": + bytesPerElement = 1 // 1/2 of fp16 + case "q5_1": + bytesPerElement = 0.65 + case "q5_0": + bytesPerElement = 0.625 + case "iq4_nl": + bytesPerElement = 0.6 // 3/4 of fp16 + case "q4_1": + bytesPerElement = 0.55 + case "q4_0": + bytesPerElement = 0.5 // 1/4 of fp16 + default: + // Default to fp16 if unknown + bytesPerElement = 2 + slog.Warn("Unknown cache type, defaulting to fp16", "type", cacheType) + } + + estimate := uint64(float64(numCtx*blockCount*embeddingHeadCount*headCountKV) * bytesPerElement) + // round up to the nearest multiple of 64 bytes + return ((estimate + 63) / 64) * 64 +} diff --git a/llm/memory_test.go b/llm/memory_test.go index 73e77d90..b0780b48 100644 --- a/llm/memory_test.go +++ b/llm/memory_test.go @@ -15,6 +15,8 @@ import ( func TestEstimateGPULayers(t *testing.T) { t.Setenv("OLLAMA_DEBUG", "1") + t.Setenv("OLLAMA_CACHE_TYPE_K", "") + t.Setenv("OLLAMA_CACHE_TYPE_V", "") modelName := "dummy" f, err := os.CreateTemp(t.TempDir(), modelName) @@ -57,6 +59,7 @@ func TestEstimateGPULayers(t *testing.T) { } projectors := []string{} opts := api.DefaultOptions() + t.Run("cpu", func(t *testing.T) { estimate := EstimateGPULayers(gpus, ggml, projectors, opts) assert.Equal(t, 0, estimate.Layers) @@ -70,7 +73,7 @@ func TestEstimateGPULayers(t *testing.T) { projectorSize := uint64(0) memoryLayerOutput := uint64(4) - // Dual CUDA scenario with assymetry + // Dual CUDA scenario with asymmetry gpuMinimumMemory := uint64(2048) gpus = []discover.GpuInfo{ { @@ -126,3 +129,114 @@ func TestEstimateGPULayers(t *testing.T) { }) } } + +func TestEstimateKvCacheSize(t *testing.T) { + tests := []struct { + name string + cacheType string + numCtx uint64 + blockCount uint64 + embeddingHeadCount uint64 + headCountKV uint64 + isEmbeddingModel bool + expected uint64 + }{ + { + name: "f32 cache type", + cacheType: "f32", + numCtx: 1024, + blockCount: 32, + embeddingHeadCount: 32, + headCountKV: 32, + isEmbeddingModel: false, + expected: 134217728, // 128 MB + }, + { + name: "f16 cache type", + cacheType: "f16", + numCtx: 1024, + blockCount: 32, + embeddingHeadCount: 32, + headCountKV: 32, + isEmbeddingModel: false, + expected: 67108864, // 64 MB + }, + { + name: "q4_0 cache type", + cacheType: "q4_0", + numCtx: 1024, + blockCount: 32, + embeddingHeadCount: 32, + headCountKV: 32, + isEmbeddingModel: false, + expected: 16777216, // 16 MB + }, + { + name: "q8_0 cache type", + cacheType: "q8_0", + numCtx: 1024, + blockCount: 32, + embeddingHeadCount: 32, + headCountKV: 32, + isEmbeddingModel: false, + expected: 33554432, // 32 MB + }, + { + name: "unknown cache type", + cacheType: "unknown", + numCtx: 1024, + blockCount: 32, + embeddingHeadCount: 32, + headCountKV: 32, + isEmbeddingModel: false, + expected: 67108864, // 64 MB (defaults to f16) + }, + { + name: "empty cache type", + cacheType: "", + numCtx: 1024, + blockCount: 32, + embeddingHeadCount: 32, + headCountKV: 32, + isEmbeddingModel: false, + expected: 67108864, // 64 MB (defaults to f16) + }, + { + name: "rounding test", + cacheType: "f32", + numCtx: 1000, + blockCount: 32, + embeddingHeadCount: 32, + headCountKV: 32, + isEmbeddingModel: false, + expected: 131072000, // Rounded up to nearest multiple of 64 + }, + { + name: "embedding model with q4_0 (should default to f16)", + cacheType: "q4_0", + numCtx: 1024, + blockCount: 32, + embeddingHeadCount: 32, + headCountKV: 32, + isEmbeddingModel: true, + expected: 67108864, // 64 MB (defaults to f16) + }, + { + name: "embedding model with f32", + cacheType: "f32", + numCtx: 1024, + blockCount: 32, + embeddingHeadCount: 32, + headCountKV: 32, + isEmbeddingModel: true, + expected: 134217728, // 128 MB + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := estimateKvCacheSize(tt.cacheType, tt.numCtx, tt.blockCount, tt.embeddingHeadCount, tt.headCountKV, tt.isEmbeddingModel) + assert.Equal(t, tt.expected, result, "Estimated KV cache size does not match expected value") + }) + } +} diff --git a/llm/server.go b/llm/server.go index a4c99dd9..e8ca9c6a 100644 --- a/llm/server.go +++ b/llm/server.go @@ -17,6 +17,7 @@ import ( "os/exec" "path/filepath" "runtime" + "slices" "strconv" "strings" "sync" @@ -218,19 +219,97 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter params = append(params, "--threads", strconv.Itoa(defaultThreads)) } - if !opts.F16KV { + if !opts.F16KV && envconfig.CacheTypeK() == "" && envconfig.CacheTypeV() == "" { params = append(params, "--memory-f32") } - flashAttnEnabled := envconfig.FlashAttention() + isEmbeddingModel := false + if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok { + isEmbeddingModel = true + } + setCacheTypeParam := func(paramName, cacheType string) { + if cacheType == "" { + return + } + + validCacheTypes := []string{"f32", "f16", "q8_0", "q5_1", "q5_0", "iq4_nl", "q4_1", "q4_0"} + if !slices.Contains(validCacheTypes, cacheType) { + slog.Warn("invalid cache type, ignoring", "param", paramName, "type", cacheType) + return + } + + // For embedding models, only allow f16 and f32 + if isEmbeddingModel && cacheType != "f16" && cacheType != "f32" { + slog.Warn("only f16 and f32 cache types are supported for embedding models, ignoring", + "param", paramName, "type", cacheType) + return + } + + params = append(params, paramName, cacheType) + slog.Debug("Setting cache type", "param", paramName, "type", cacheType) + } + + // Define cacheTypeK and cacheTypeV + cacheTypeK := envconfig.CacheTypeK() + cacheTypeV := envconfig.CacheTypeV() + + // Set cache types only if they are not empty + supportsFlashAttention := func(ggml *GGML) bool { + headCountK := ggml.KV().EmbeddingHeadCountK() + headCountV := ggml.KV().EmbeddingHeadCountV() + + if headCountK == 0 || headCountV == 0 { + slog.Debug("Model is missing embedding head count for K or V") + return false + } + + if headCountK != headCountV { + slog.Debug("Embedding head count K does not equal V", "K", headCountK, "V", headCountV) + return false + } + + slog.Debug("Model supports flash attention", "headCountK", headCountK, "headCountV", headCountV) + return true + } + + flashAttnSupported := supportsFlashAttention(ggml) + + hardwareSupportsFlashAttn := true for _, g := range gpus { // only cuda (compute capability 7+) and metal support flash attention if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) { - flashAttnEnabled = false + hardwareSupportsFlashAttn = false + break } + } - // mmap has issues with partial offloading on metal + flashAttnEnabled := envconfig.FlashAttention() && flashAttnSupported && hardwareSupportsFlashAttn && !isEmbeddingModel + + slog.Debug("Flash attention status", + "supported_by_model", flashAttnSupported, + "supported_by_hardware", hardwareSupportsFlashAttn, + "is_embedding_model", isEmbeddingModel, + "enabled", flashAttnEnabled) + + if flashAttnEnabled { + params = append(params, "--flash-attn") + slog.Info("Enabling flash attention") + + setCacheTypeParam("--cache-type-k", cacheTypeK) + setCacheTypeParam("--cache-type-v", cacheTypeV) + } else { + slog.Info("Flash attention not enabled") + quantizedCacheTypes := []string{"q8_0", "q5_1", "q5_0", "iq4_nl", "q4_1", "q4_0"} + if !isEmbeddingModel && (cacheTypeK != "" || cacheTypeV != "") { + if slices.Contains(quantizedCacheTypes, cacheTypeK) || slices.Contains(quantizedCacheTypes, cacheTypeV) { + slog.Warn("Quantized cache types require flash attention. Using default cache types.") + } + } + } + + // mmap has issues with partial offloading on metal + for _, g := range gpus { if g.Library == "metal" && uint64(opts.NumGPU) > 0 && uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 { @@ -239,10 +318,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter } } - if flashAttnEnabled { - params = append(params, "--flash-attn") - } - // Windows CUDA should not use mmap for best performance // Linux with a model larger than free space, mmap leads to thrashing // For CPU loads we want the memory to be allocated, not FS cache From af7d64b88741a0d6571c2c7c2c538b3410e2df2a Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 6 Nov 2024 08:10:52 +1100 Subject: [PATCH 2/5] feat: allow setting KV cache type --- llm/memory.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llm/memory.go b/llm/memory.go index ea19ccab..26abdcee 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -141,10 +141,10 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string, _, graphPartialOffload, graphFullOffload = ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch))) if graphPartialOffload == 0 { - graphPartialOffload = ggml.KV().GQA() * kv / 6 + graphPartialOffload = ggml.KV().GQA() * kv / 6 } if graphFullOffload == 0 { - graphFullOffload = graphPartialOffload + graphFullOffload = graphPartialOffload } // KV is proportional to the number of layers From cd0be17fbadf359126d966c58448d5e6a5f80668 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 6 Nov 2024 08:12:07 +1100 Subject: [PATCH 3/5] feat: allow setting KV cache type --- llm/memory.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/llm/memory.go b/llm/memory.go index 26abdcee..8eb99ca6 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -147,9 +147,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string, graphFullOffload = graphPartialOffload } - // KV is proportional to the number of layers - layerSize += kv / ggml.KV().BlockCount() - // on metal there's no partial offload overhead if gpus[0].Library == "metal" { graphPartialOffload = graphFullOffload From 7d787ba90d07ab0991eee393e62cf5fbe421542f Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 14 Nov 2024 07:00:43 +1100 Subject: [PATCH 4/5] fix(docs): update FA FAQ wording slightly refactor: only allow setting K and V cache types together --- cmd/cmd.go | 3 +-- docs/faq.md | 7 +++---- envconfig/config.go | 9 +++------ llama/llama.go | 6 +++--- llama/runner/runner.go | 12 +++++------- llm/memory.go | 6 ++---- llm/memory_test.go | 3 +-- llm/server.go | 19 +++++++++---------- 8 files changed, 27 insertions(+), 38 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index d3f8fd56..066ea067 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1482,8 +1482,7 @@ func NewCLI() *cobra.Command { envVars["OLLAMA_SCHED_SPREAD"], envVars["OLLAMA_TMPDIR"], envVars["OLLAMA_FLASH_ATTENTION"], - envVars["OLLAMA_CACHE_TYPE_K"], - envVars["OLLAMA_CACHE_TYPE_V"], + envVars["OLLAMA_KV_CACHE_TYPE"], envVars["OLLAMA_LLM_LIBRARY"], envVars["OLLAMA_GPU_OVERHEAD"], envVars["OLLAMA_LOAD_TIMEOUT"], diff --git a/docs/faq.md b/docs/faq.md index 14022a43..2d3cd1a8 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -291,16 +291,15 @@ Installing multiple GPUs of the same brand can be a great way to increase your a Flash Attention is a feature of most (but not all) modern models that can significantly reduce memory usage as the context size grows. To enable Flash Attention, set the `OLLAMA_FLASH_ATTENTION` environment variable to `1` when starting the Ollama server. -> Note: If you're using an uncommon quantization type with CUDA, you may benefit from build Ollama with `LLAMA_CUDA_FA_ALL_QUANTS=1` to make llama.cpp build all flash attention quantization types. +> Note: If you're using an uncommon quantization type with CUDA, advanced users may benefit from building Ollama and passing `GGML_CUDA_FA_ALL_QUANTS=1` to the llama.cpp build to enable FA for all combinations of quantisation types. More information on this can be found in [llama.cpp](https://github.com/ggerganov/llama.cpp/blob/fb4a0ec0833c71cff5a1a367ba375447ce6106eb/ggml/src/ggml-cuda/fattn-common.cuh#L575). ## How can I set the quantization type for the K/V cache? The K/V context cache can be quantized to significantly reduce memory usage when Flash Attention is enabled. -To use quantized K/V cache with Ollama you can set the following environment variables: +To use quantized K/V cache with Ollama you can set the following environment variable: -- `OLLAMA_CACHE_TYPE_K` - The quantization type for the key cache. Default is `f16`. -- `OLLAMA_CACHE_TYPE_V` - The quantization type for the value cache. Default is `f16`. +- `OLLAMA_KV_CACHE_TYPE` - The quantization type for the K/V cache. Default is `f16`. > Note: Currently this is a global option - meaning all models will run with the specified quantization type. diff --git a/envconfig/config.go b/envconfig/config.go index 73271f86..027608f7 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -153,10 +153,8 @@ var ( Debug = Bool("OLLAMA_DEBUG") // FlashAttention enables the experimental flash attention feature. FlashAttention = Bool("OLLAMA_FLASH_ATTENTION") - // CacheTypeK is the quantization type for the K/V cache keys. - CacheTypeK = String("OLLAMA_CACHE_TYPE_K") - // CacheTypeV is the quantization type for the K/V cache values. - CacheTypeV = String("OLLAMA_CACHE_TYPE_V") + // KvCacheType is the quantization type for the K/V cache. + KvCacheType = String("OLLAMA_KV_CACHE_TYPE") // NoHistory disables readline history. NoHistory = Bool("OLLAMA_NOHISTORY") // NoPrune disables pruning of model blobs on startup. @@ -238,8 +236,7 @@ func AsMap() map[string]EnvVar { ret := map[string]EnvVar{ "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"}, - "OLLAMA_CACHE_TYPE_K": {"OLLAMA_CACHE_TYPE_K", CacheTypeK(), "Type of cache for keys (default: f16)"}, - "OLLAMA_CACHE_TYPE_V": {"OLLAMA_CACHE_TYPE_V", CacheTypeV(), "Type of cache for values (default: f16)"}, + "OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantisation type for the K/V cache (default: f16)"}, "OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"}, diff --git a/llama/llama.go b/llama/llama.go index ec9fe0b3..04cab77c 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -140,7 +140,7 @@ type ContextParams struct { c C.struct_llama_context_params } -func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, cacheTypeK string, cacheTypeV string) ContextParams { +func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams { params := C.llama_context_default_params() params.n_ctx = C.uint(numCtx) params.n_batch = C.uint(batchSize) @@ -149,8 +149,8 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla params.n_threads_batch = params.n_threads params.embeddings = C.bool(true) params.flash_attn = C.bool(flashAttention) - params.type_k = KvCacheTypeFromStr(cacheTypeK) - params.type_v = KvCacheTypeFromStr(cacheTypeV) + params.type_k = KvCacheTypeFromStr(kvCacheType) + params.type_v = KvCacheTypeFromStr(kvCacheType) return ContextParams{c: params} } diff --git a/llama/runner/runner.go b/llama/runner/runner.go index bb434046..3c289ced 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -471,7 +471,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) // the last one generated wasn't submitted to Decode // - Remove any stop sequences that we stripped out // - If truncateStop removed a portion of a token, drop that - // - As defense-in-depth, if truncatedToken didn't find a stop token + // - As defence-in-depth, if truncatedToken didn't find a stop token // remove the extra one that we added to the cache len tokenLen := len(seq.cache.Inputs) + 1 tokenLen -= origLen - newLen @@ -762,8 +762,7 @@ func (s *Server) loadModel( flashAttention bool, threads int, multiUserCache bool, - cacheTypeK string, - cacheTypeV string, + kvCacheType string, ) { llama.BackendInit() @@ -773,7 +772,7 @@ func (s *Server) loadModel( panic(err) } - ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, cacheTypeK, cacheTypeV) + ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType) s.lc, err = llama.NewContextWithModel(s.model, ctxParams) if err != nil { panic(err) @@ -821,8 +820,7 @@ func main() { tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions") multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") requirements := flag.Bool("requirements", false, "print json requirement information") - cacheTypeK := flag.String("cache-type-k", "f16", "quantization type for key in cache (default: f16)") - cacheTypeV := flag.String("cache-type-v", "f16", "quantization type for value in cache (default: f16)") + kvCacheType := flag.String("kv-cache-type", "f16", "quantization type for KV cache (default: f16)") flag.Parse() if *requirements { @@ -878,7 +876,7 @@ func main() { } server.ready.Add(1) - go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *cacheTypeK, *cacheTypeV) + go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *kvCacheType) server.cond = sync.NewCond(&server.mu) diff --git a/llm/memory.go b/llm/memory.go index 8eb99ca6..ff1e09cc 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -129,10 +129,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string, isEmbeddingModel = true } - // Estimate the memory required for K and V caches separately as they can have different quantization types - kSize := estimateKvCacheSize(envconfig.CacheTypeK(), uint64(opts.NumCtx), ggml.KV().BlockCount(), ggml.KV().EmbeddingHeadCountK(), ggml.KV().HeadCountKV(), isEmbeddingModel) - vSize := estimateKvCacheSize(envconfig.CacheTypeV(), uint64(opts.NumCtx), ggml.KV().BlockCount(), ggml.KV().EmbeddingHeadCountV(), ggml.KV().HeadCountKV(), isEmbeddingModel) - kv := kSize + vSize + // Estimate the memory required for KV cache quantisation + kv := estimateKvCacheSize(envconfig.KvCacheType(), uint64(opts.NumCtx), ggml.KV().BlockCount(), ggml.KV().EmbeddingHeadCountK(), ggml.KV().HeadCountKV(), isEmbeddingModel) * 2 // KV is proportional to the number of layers layerSize += kv / ggml.KV().BlockCount() diff --git a/llm/memory_test.go b/llm/memory_test.go index b0780b48..73ee7915 100644 --- a/llm/memory_test.go +++ b/llm/memory_test.go @@ -15,8 +15,7 @@ import ( func TestEstimateGPULayers(t *testing.T) { t.Setenv("OLLAMA_DEBUG", "1") - t.Setenv("OLLAMA_CACHE_TYPE_K", "") - t.Setenv("OLLAMA_CACHE_TYPE_V", "") + t.Setenv("OLLAMA_KV_CACHE_TYPE", "") modelName := "dummy" f, err := os.CreateTemp(t.TempDir(), modelName) diff --git a/llm/server.go b/llm/server.go index 413a19c2..98b3c07a 100644 --- a/llm/server.go +++ b/llm/server.go @@ -218,11 +218,13 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter params = append(params, "--threads", strconv.Itoa(defaultThreads)) } + // isEmbeddingModel checks for common GGML attributes that help distinguish most embedding models from normal models. isEmbeddingModel := false if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok { isEmbeddingModel = true } + // Validates and applies KV cache parameters setCacheTypeParam := func(paramName, cacheType string) { if cacheType == "" { return @@ -245,9 +247,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter slog.Debug("Setting cache type", "param", paramName, "type", cacheType) } - // Define cacheTypeK and cacheTypeV - cacheTypeK := envconfig.CacheTypeK() - cacheTypeV := envconfig.CacheTypeV() + kvCacheType := envconfig.KvCacheType() // Set cache types only if they are not empty supportsFlashAttention := func(ggml *GGML) bool { @@ -255,12 +255,12 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter headCountV := ggml.KV().EmbeddingHeadCountV() if headCountK == 0 || headCountV == 0 { - slog.Debug("Model is missing embedding head count for K or V") + slog.Debug("Model is missing embedding head count for K or V, does not support flash attention") return false } if headCountK != headCountV { - slog.Debug("Embedding head count K does not equal V", "K", headCountK, "V", headCountV) + slog.Debug("Embedding head count K does not equal V, does not support flash attention", "K", headCountK, "V", headCountV) return false } @@ -291,14 +291,13 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter params = append(params, "--flash-attn") slog.Info("Enabling flash attention") - setCacheTypeParam("--cache-type-k", cacheTypeK) - setCacheTypeParam("--cache-type-v", cacheTypeV) + setCacheTypeParam("--kv-cache-type", kvCacheType) } else { slog.Info("Flash attention not enabled") quantizedCacheTypes := []string{"q8_0", "q5_1", "q5_0", "iq4_nl", "q4_1", "q4_0"} - if !isEmbeddingModel && (cacheTypeK != "" || cacheTypeV != "") { - if slices.Contains(quantizedCacheTypes, cacheTypeK) || slices.Contains(quantizedCacheTypes, cacheTypeV) { - slog.Warn("Quantized cache types require flash attention. Using default cache types.") + if !isEmbeddingModel && (kvCacheType != "") { + if slices.Contains(quantizedCacheTypes, kvCacheType) { + slog.Warn("Quantized cache types require flash attention. Falling back to default cache types.") } } } From 66839c3bd7c6f31b59a66b723550d999189fe0c2 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 14 Nov 2024 09:37:23 +1100 Subject: [PATCH 5/5] fix(docs): update FA FAQ wording slightly --- docs/faq.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/faq.md b/docs/faq.md index 2d3cd1a8..47cf7ce2 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -291,7 +291,7 @@ Installing multiple GPUs of the same brand can be a great way to increase your a Flash Attention is a feature of most (but not all) modern models that can significantly reduce memory usage as the context size grows. To enable Flash Attention, set the `OLLAMA_FLASH_ATTENTION` environment variable to `1` when starting the Ollama server. -> Note: If you're using an uncommon quantization type with CUDA, advanced users may benefit from building Ollama and passing `GGML_CUDA_FA_ALL_QUANTS=1` to the llama.cpp build to enable FA for all combinations of quantisation types. More information on this can be found in [llama.cpp](https://github.com/ggerganov/llama.cpp/blob/fb4a0ec0833c71cff5a1a367ba375447ce6106eb/ggml/src/ggml-cuda/fattn-common.cuh#L575). +> Note: Advanced users using CUDA may benefit from building Ollama and passing `GGML_CUDA_FA_ALL_QUANTS=1` to the llama.cpp build to enable FA for all combinations of quantisation types. More information on this can be found in [llama.cpp](https://github.com/ggerganov/llama.cpp/blob/fb4a0ec0833c71cff5a1a367ba375447ce6106eb/ggml/src/ggml-cuda/fattn-common.cuh#L575). ## How can I set the quantization type for the K/V cache?