diff --git a/llm/llm.go b/llm/llm.go index 193d5241..ef424b5d 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "os" + "runtime" "github.com/pbnjay/memory" @@ -37,20 +38,22 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error return nil, err } - switch ggml.FileType() { - case "Q8_0": - if ggml.Name() != "gguf" && opts.NumGPU != 0 { - // GGML Q8_0 do not support Metal API and will - // cause the runner to segmentation fault so disable GPU - log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0") - opts.NumGPU = 0 - } - case "F32", "Q5_0", "Q5_1": - if opts.NumGPU != 0 { - // F32, Q5_0, Q5_1, and Q8_0 do not support Metal API and will - // cause the runner to segmentation fault so disable GPU - log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0") - opts.NumGPU = 0 + if runtime.GOOS == "darwin" { + switch ggml.FileType() { + case "Q8_0": + if ggml.Name() != "gguf" && opts.NumGPU != 0 { + // GGML Q8_0 do not support Metal API and will + // cause the runner to segmentation fault so disable GPU + log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0") + opts.NumGPU = 0 + } + case "F32", "Q5_0", "Q5_1": + if opts.NumGPU != 0 { + // F32, Q5_0, Q5_1, and Q8_0 do not support Metal API and will + // cause the runner to segmentation fault so disable GPU + log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0") + opts.NumGPU = 0 + } } }