From bd6e38fb1afed8c570d6a5f7eb87082f5221426a Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 12 Oct 2023 09:47:17 -0700 Subject: [PATCH] refactor memory check --- llm/llm.go | 41 +++++++++++++++-------------------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/llm/llm.go b/llm/llm.go index 4ae5dd2e..a2619382 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -58,38 +58,27 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error } } - totalResidentMemory := memory.TotalMemory() + var requiredMemory int64 + var f16Multiplier int64 = 2 + totalResidentMemory := int64(memory.TotalMemory()) switch ggml.ModelType() { case "3B", "7B": - if ggml.FileType() == "F16" && totalResidentMemory < 16*format.GigaByte { - return nil, fmt.Errorf("F16 model requires at least 16 GB of memory") - } else if totalResidentMemory < 8*format.GigaByte { - return nil, fmt.Errorf("model requires at least 8 GB of memory") - } + requiredMemory = 8 * format.GigaByte case "13B": - if ggml.FileType() == "F16" && totalResidentMemory < 32*format.GigaByte { - return nil, fmt.Errorf("F16 model requires at least 32 GB of memory") - } else if totalResidentMemory < 16*format.GigaByte { - return nil, fmt.Errorf("model requires at least 16 GB of memory") - } + requiredMemory = 16 * format.GigaByte case "30B", "34B", "40B": - if ggml.FileType() == "F16" && totalResidentMemory < 64*format.GigaByte { - return nil, fmt.Errorf("F16 model requires at least 64 GB of memory") - } else if totalResidentMemory < 32*format.GigaByte { - return nil, fmt.Errorf("model requires at least 32 GB of memory") - } + requiredMemory = 32 * format.GigaByte case "65B", "70B": - if ggml.FileType() == "F16" && totalResidentMemory < 128*format.GigaByte { - return nil, fmt.Errorf("F16 model requires at least 128 GB of memory") - } else if totalResidentMemory < 64*format.GigaByte { - return nil, fmt.Errorf("model requires at least 64 GB of memory") - } + requiredMemory = 64 * format.GigaByte case "180B": - if ggml.FileType() == "F16" && totalResidentMemory < 512*format.GigaByte { - return nil, fmt.Errorf("F16 model requires at least 512GB of memory") - } else if totalResidentMemory < 128*format.GigaByte { - return nil, fmt.Errorf("model requires at least 128GB of memory") - } + requiredMemory = 128 * format.GigaByte + f16Multiplier = 4 + } + + if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > totalResidentMemory { + return nil, fmt.Errorf("F16 model requires at least %s of memory", format.HumanBytes(requiredMemory)) + } else if requiredMemory > totalResidentMemory { + return nil, fmt.Errorf("model requires at least %s of memory", format.HumanBytes(requiredMemory)) } switch ggml.Name() {