fix(docs): update FA FAQ wording slightly

refactor: only allow setting K and V cache types together
This commit is contained in:
Sam 2024-11-14 07:00:43 +11:00
parent b637acb4e5
commit 7d787ba90d
8 changed files with 27 additions and 38 deletions


@@ -1482,8 +1482,7 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_SCHED_SPREAD"],
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_CACHE_TYPE_K"],
envVars["OLLAMA_CACHE_TYPE_V"],
envVars["OLLAMA_KV_CACHE_TYPE"],
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_GPU_OVERHEAD"],
envVars["OLLAMA_LOAD_TIMEOUT"],


@@ -291,16 +291,15 @@ Installing multiple GPUs of the same brand can be a great way to increase your a
Flash Attention is a feature of most (but not all) modern models that can significantly reduce memory usage as the context size grows. To enable Flash Attention, set the `OLLAMA_FLASH_ATTENTION` environment variable to `1` when starting the Ollama server.
-> Note: If you're using an uncommon quantization type with CUDA, you may benefit from build Ollama with `LLAMA_CUDA_FA_ALL_QUANTS=1` to make llama.cpp build all flash attention quantization types.
+> Note: If you're using an uncommon quantization type with CUDA, advanced users may benefit from building Ollama and passing `GGML_CUDA_FA_ALL_QUANTS=1` to the llama.cpp build to enable FA for all combinations of quantisation types. More information on this can be found in [llama.cpp](https://github.com/ggerganov/llama.cpp/blob/fb4a0ec0833c71cff5a1a367ba375447ce6106eb/ggml/src/ggml-cuda/fattn-common.cuh#L575).
## How can I set the quantization type for the K/V cache?
The K/V context cache can be quantized to significantly reduce memory usage when Flash Attention is enabled.
-To use quantized K/V cache with Ollama you can set the following environment variables:
+To use quantized K/V cache with Ollama you can set the following environment variable:
-- `OLLAMA_CACHE_TYPE_K` - The quantization type for the key cache. Default is `f16`.
-- `OLLAMA_CACHE_TYPE_V` - The quantization type for the value cache. Default is `f16`.
+- `OLLAMA_KV_CACHE_TYPE` - The quantization type for the K/V cache. Default is `f16`.
> Note: Currently this is a global option - meaning all models will run with the specified quantization type.
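Editor's note: the snippet below is an illustrative sketch, not part of this commit. It shows one way to start the server with the new variable set, assuming the `ollama` binary is on PATH; `q8_0` is used purely as an example value (it is one of the quantized types the commit lists as requiring flash attention).

```go
package main

import (
	"os"
	"os/exec"
)

func main() {
	// Launch the Ollama server with flash attention enabled and a single
	// quantized type applied to both the K and V caches.
	cmd := exec.Command("ollama", "serve")
	cmd.Env = append(os.Environ(),
		"OLLAMA_FLASH_ATTENTION=1",
		"OLLAMA_KV_CACHE_TYPE=q8_0", // replaces OLLAMA_CACHE_TYPE_K / OLLAMA_CACHE_TYPE_V
	)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		panic(err)
	}
}
```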


@@ -153,10 +153,8 @@ var (
Debug = Bool("OLLAMA_DEBUG")
// FlashAttention enables the experimental flash attention feature.
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
-// CacheTypeK is the quantization type for the K/V cache keys.
-CacheTypeK = String("OLLAMA_CACHE_TYPE_K")
-// CacheTypeV is the quantization type for the K/V cache values.
-CacheTypeV = String("OLLAMA_CACHE_TYPE_V")
+// KvCacheType is the quantization type for the K/V cache.
+KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
// NoHistory disables readline history.
NoHistory = Bool("OLLAMA_NOHISTORY")
// NoPrune disables pruning of model blobs on startup.
@@ -238,8 +236,7 @@ func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
-"OLLAMA_CACHE_TYPE_K": {"OLLAMA_CACHE_TYPE_K", CacheTypeK(), "Type of cache for keys (default: f16)"},
-"OLLAMA_CACHE_TYPE_V": {"OLLAMA_CACHE_TYPE_V", CacheTypeV(), "Type of cache for values (default: f16)"},
+"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantisation type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},


@@ -140,7 +140,7 @@ type ContextParams struct {
c C.struct_llama_context_params
}
-func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, cacheTypeK string, cacheTypeV string) ContextParams {
+func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
params := C.llama_context_default_params()
params.n_ctx = C.uint(numCtx)
params.n_batch = C.uint(batchSize)
@@ -149,8 +149,8 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
params.n_threads_batch = params.n_threads
params.embeddings = C.bool(true)
params.flash_attn = C.bool(flashAttention)
-params.type_k = KvCacheTypeFromStr(cacheTypeK)
-params.type_v = KvCacheTypeFromStr(cacheTypeV)
+params.type_k = KvCacheTypeFromStr(kvCacheType)
+params.type_v = KvCacheTypeFromStr(kvCacheType)
return ContextParams{c: params}
}
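Editor's note: `KvCacheTypeFromStr` itself is not shown in this diff. As a hypothetical illustration only (the type and constant names below are invented, not the real llama package API), a mapper of this shape is what lets the single `kvCacheType` string drive both `type_k` and `type_v`:

```go
package main

import "fmt"

// ggmlType stands in for the C-level ggml type enum that the real
// KvCacheTypeFromStr returns; the names here are illustrative only.
type ggmlType int

const (
	typeF16 ggmlType = iota
	typeQ8_0
	typeQ4_0
)

// cacheTypeFromStr maps the user-facing string onto a cache type,
// falling back to f16 for empty or unknown values (the documented default).
func cacheTypeFromStr(s string) ggmlType {
	switch s {
	case "q8_0":
		return typeQ8_0
	case "q4_0":
		return typeQ4_0
	default:
		return typeF16
	}
}

func main() {
	kvCacheType := "q8_0"
	// The same value now drives both halves of the cache, so the K and V
	// types can no longer diverge.
	typeK := cacheTypeFromStr(kvCacheType)
	typeV := cacheTypeFromStr(kvCacheType)
	fmt.Println(typeK, typeV)
}
```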


@@ -471,7 +471,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
// the last one generated wasn't submitted to Decode
// - Remove any stop sequences that we stripped out
// - If truncateStop removed a portion of a token, drop that
-// - As defense-in-depth, if truncatedToken didn't find a stop token
+// - As defence-in-depth, if truncatedToken didn't find a stop token
// remove the extra one that we added to the cache len
tokenLen := len(seq.cache.Inputs) + 1
tokenLen -= origLen - newLen
@@ -762,8 +762,7 @@ func (s *Server) loadModel(
flashAttention bool,
threads int,
multiUserCache bool,
-cacheTypeK string,
-cacheTypeV string,
+kvCacheType string,
) {
llama.BackendInit()
@@ -773,7 +772,7 @@
panic(err)
}
-ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, cacheTypeK, cacheTypeV)
+ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
if err != nil {
panic(err)
@@ -821,8 +820,7 @@ func main() {
tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
requirements := flag.Bool("requirements", false, "print json requirement information")
-cacheTypeK := flag.String("cache-type-k", "f16", "quantization type for key in cache (default: f16)")
-cacheTypeV := flag.String("cache-type-v", "f16", "quantization type for value in cache (default: f16)")
+kvCacheType := flag.String("kv-cache-type", "f16", "quantization type for KV cache (default: f16)")
flag.Parse()
if *requirements {
@@ -878,7 +876,7 @@
}
server.ready.Add(1)
-go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *cacheTypeK, *cacheTypeV)
+go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *kvCacheType)
server.cond = sync.NewCond(&server.mu)


@@ -129,10 +129,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
isEmbeddingModel = true
}
-// Estimate the memory required for K and V caches separately as they can have different quantization types
-kSize := estimateKvCacheSize(envconfig.CacheTypeK(), uint64(opts.NumCtx), ggml.KV().BlockCount(), ggml.KV().EmbeddingHeadCountK(), ggml.KV().HeadCountKV(), isEmbeddingModel)
-vSize := estimateKvCacheSize(envconfig.CacheTypeV(), uint64(opts.NumCtx), ggml.KV().BlockCount(), ggml.KV().EmbeddingHeadCountV(), ggml.KV().HeadCountKV(), isEmbeddingModel)
-kv := kSize + vSize
+// Estimate the memory required for KV cache quantisation
+kv := estimateKvCacheSize(envconfig.KvCacheType(), uint64(opts.NumCtx), ggml.KV().BlockCount(), ggml.KV().EmbeddingHeadCountK(), ggml.KV().HeadCountKV(), isEmbeddingModel) * 2
// KV is proportional to the number of layers
layerSize += kv / ggml.KV().BlockCount()
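Editor's note: a rough worked example of what the doubled estimate means in practice. The model dimensions below are made up rather than taken from a real GGUF, and the q8_0 figure approximates ggml's 34-byte blocks of 32 values.

```go
package main

import "fmt"

func main() {
	// Back-of-the-envelope K/V cache sizing in the spirit of the estimate above:
	// elements = 2 (K and V) * ctx * layers * kv_heads * head_dim,
	// bytes    = elements * bytes-per-element for the chosen cache type.
	const (
		numCtx  = 8192
		layers  = 32
		kvHeads = 8
		headDim = 128
	)
	elems := float64(2 * numCtx * layers * kvHeads * headDim)

	// Approximate bytes per element; q8_0 stores 32 values in 34 bytes.
	bytesPerElem := map[string]float64{
		"f16":  2.0,
		"q8_0": 34.0 / 32.0,
	}
	for _, t := range []string{"f16", "q8_0"} {
		fmt.Printf("%-5s ≈ %.2f GiB\n", t, elems*bytesPerElem[t]/(1<<30))
	}
}
```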


@@ -15,8 +15,7 @@ import (
func TestEstimateGPULayers(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", "1")
t.Setenv("OLLAMA_CACHE_TYPE_K", "")
t.Setenv("OLLAMA_CACHE_TYPE_V", "")
t.Setenv("OLLAMA_KV_CACHE_TYPE", "")
modelName := "dummy"
f, err := os.CreateTemp(t.TempDir(), modelName)


@@ -218,11 +218,13 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
params = append(params, "--threads", strconv.Itoa(defaultThreads))
}
+// isEmbeddingModel checks for common GGML attributes that help distinguish most embedding models from normal models.
isEmbeddingModel := false
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
isEmbeddingModel = true
}
+// Validates and applies KV cache parameters
setCacheTypeParam := func(paramName, cacheType string) {
if cacheType == "" {
return
@@ -245,9 +247,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
slog.Debug("Setting cache type", "param", paramName, "type", cacheType)
}
-// Define cacheTypeK and cacheTypeV
-cacheTypeK := envconfig.CacheTypeK()
-cacheTypeV := envconfig.CacheTypeV()
+kvCacheType := envconfig.KvCacheType()
// Set cache types only if they are not empty
supportsFlashAttention := func(ggml *GGML) bool {
@@ -255,12 +255,12 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
headCountV := ggml.KV().EmbeddingHeadCountV()
if headCountK == 0 || headCountV == 0 {
slog.Debug("Model is missing embedding head count for K or V")
slog.Debug("Model is missing embedding head count for K or V, does not support flash attention")
return false
}
if headCountK != headCountV {
slog.Debug("Embedding head count K does not equal V", "K", headCountK, "V", headCountV)
slog.Debug("Embedding head count K does not equal V, does not support flash attention", "K", headCountK, "V", headCountV)
return false
}
@@ -291,14 +291,13 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
params = append(params, "--flash-attn")
slog.Info("Enabling flash attention")
setCacheTypeParam("--cache-type-k", cacheTypeK)
setCacheTypeParam("--cache-type-v", cacheTypeV)
setCacheTypeParam("--kv-cache-type", kvCacheType)
} else {
slog.Info("Flash attention not enabled")
quantizedCacheTypes := []string{"q8_0", "q5_1", "q5_0", "iq4_nl", "q4_1", "q4_0"}
-if !isEmbeddingModel && (cacheTypeK != "" || cacheTypeV != "") {
-if slices.Contains(quantizedCacheTypes, cacheTypeK) || slices.Contains(quantizedCacheTypes, cacheTypeV) {
-slog.Warn("Quantized cache types require flash attention. Using default cache types.")
+if !isEmbeddingModel && (kvCacheType != "") {
+if slices.Contains(quantizedCacheTypes, kvCacheType) {
+slog.Warn("Quantized cache types require flash attention. Falling back to default cache types.")
}
}
}