diff --git a/llm/llm.go b/llm/llm.go index a2619382..49405250 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -60,7 +60,7 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error var requiredMemory int64 var f16Multiplier int64 = 2 - totalResidentMemory := int64(memory.TotalMemory()) + switch ggml.ModelType() { case "3B", "7B": requiredMemory = 8 * format.GigaByte @@ -75,10 +75,19 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error f16Multiplier = 4 } - if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > totalResidentMemory { - return nil, fmt.Errorf("F16 model requires at least %s of memory", format.HumanBytes(requiredMemory)) - } else if requiredMemory > totalResidentMemory { - return nil, fmt.Errorf("model requires at least %s of memory", format.HumanBytes(requiredMemory)) + systemMemory := int64(memory.TotalMemory()) + + videoMemory, err := CheckVRAM() + if err != nil{ + videoMemory = 0 + } + + totalMemory := systemMemory + videoMemory + + if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > totalMemory { + return nil, fmt.Errorf("F16 model requires at least %s of total memory", format.HumanBytes(requiredMemory)) + } else if requiredMemory > totalMemory { + return nil, fmt.Errorf("model requires at least %s of total memory", format.HumanBytes(requiredMemory)) } switch ggml.Name() {