Move Go code out of llm package

This commit is contained in:
Daniel Hiltgen 2024-10-08 13:54:25 -07:00
parent 0ccc73251a
commit 4e988ad5d6
38 changed files with 243 additions and 225 deletions

View File

@ -9,7 +9,7 @@ import (
"log/slog"
"strings"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
type ModelParameters struct {
@ -27,8 +27,8 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
func (ModelParameters) KV(t *Tokenizer) llm.KV {
kv := llm.KV{
func (ModelParameters) KV(t *Tokenizer) fileutils.KV {
kv := fileutils.KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@ -54,7 +54,7 @@ func (ModelParameters) KV(t *Tokenizer) llm.KV {
return kv
}
func (p AdapterParameters) KV() llm.KV {
func (p AdapterParameters) KV() fileutils.KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@ -62,7 +62,7 @@ func (p AdapterParameters) KV() llm.KV {
alpha = p.LoraParameters.Alpha
}
kv := llm.KV{
kv := fileutils.KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@ -79,19 +79,19 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
return llm.WriteGGUF(ws, kv, ts)
func (ModelParameters) writeFile(ws io.WriteSeeker, kv fileutils.KV, ts []fileutils.Tensor) error {
return fileutils.WriteGGUF(ws, kv, ts)
}
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
return llm.WriteGGUF(ws, kv, ts)
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv fileutils.KV, ts []fileutils.Tensor) error {
return fileutils.WriteGGUF(ws, kv, ts)
}
type ModelConverter interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) llm.KV
KV(*Tokenizer) fileutils.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []llm.Tensor
Tensors([]Tensor) []fileutils.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
@ -99,7 +99,7 @@ type ModelConverter interface {
// specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string
// writeFile writes the model to the provided io.WriteSeeker
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
writeFile(io.WriteSeeker, fileutils.KV, []fileutils.Tensor) error
}
type moreParser interface {
@ -108,17 +108,17 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(llm.KV) llm.KV
KV(fileutils.KV) fileutils.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []llm.Tensor
Tensors([]Tensor) []fileutils.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
writeFile(io.WriteSeeker, fileutils.KV, []fileutils.Tensor) error
}
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV fileutils.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err

View File

@ -8,7 +8,7 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
type bertModel struct {
@ -85,7 +85,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) llm.KV {
func (p *bertModel) KV(t *Tokenizer) fileutils.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false
@ -132,8 +132,8 @@ func (p *bertModel) KV(t *Tokenizer) llm.KV {
return kv
}
func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
func (p *bertModel) Tensors(ts []Tensor) []fileutils.Tensor {
var out []fileutils.Tensor
for _, t := range ts {
if slices.Contains([]string{
"embeddings.position_ids",
@ -143,7 +143,7 @@ func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
continue
}
out = append(out, llm.Tensor{
out = append(out, fileutils.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -6,7 +6,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
type gemmaModel struct {
@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) llm.KV {
func (p *gemmaModel) KV(t *Tokenizer) fileutils.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings
@ -42,14 +42,14 @@ func (p *gemmaModel) KV(t *Tokenizer) llm.KV {
return kv
}
func (p *gemmaModel) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
func (p *gemmaModel) Tensors(ts []Tensor) []fileutils.Tensor {
var out []fileutils.Tensor
for _, t := range ts {
if strings.HasSuffix(t.Name(), "_norm.weight") {
t.SetRepacker(p.addOne)
}
out = append(out, llm.Tensor{
out = append(out, fileutils.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -1,7 +1,7 @@
package convert
import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
type gemma2Model struct {
@ -11,7 +11,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) llm.KV {
func (p *gemma2Model) KV(t *Tokenizer) fileutils.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@ -6,7 +6,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
type gemma2Adapter struct {
@ -15,14 +15,14 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV llm.KV) llm.KV {
func (p *gemma2Adapter) KV(baseKV fileutils.KV) fileutils.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv
}
func (p *gemma2Adapter) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
func (p *gemma2Adapter) Tensors(ts []Tensor) []fileutils.Tensor {
var out []fileutils.Tensor
for _, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@ -31,7 +31,7 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []llm.Tensor {
t.SetRepacker(p.repack)
}
out = append(out, llm.Tensor{
out = append(out, fileutils.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -9,7 +9,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
type llamaModel struct {
@ -46,7 +46,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) llm.KV {
func (p *llamaModel) KV(t *Tokenizer) fileutils.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize
@ -120,11 +120,11 @@ func (p *llamaModel) KV(t *Tokenizer) llm.KV {
return kv
}
func (p *llamaModel) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
func (p *llamaModel) Tensors(ts []Tensor) []fileutils.Tensor {
var out []fileutils.Tensor
if p.RopeScaling.factors != nil {
out = append(out, llm.Tensor{
out = append(out, fileutils.Tensor{
Name: "rope_freqs.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.factors))},
@ -138,7 +138,7 @@ func (p *llamaModel) Tensors(ts []Tensor) []llm.Tensor {
t.SetRepacker(p.repack)
}
out = append(out, llm.Tensor{
out = append(out, fileutils.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -7,7 +7,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
type llamaAdapter struct {
@ -18,7 +18,7 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV llm.KV) llm.KV {
func (p *llamaAdapter) KV(baseKV fileutils.KV) fileutils.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
@ -29,8 +29,8 @@ func (p *llamaAdapter) KV(baseKV llm.KV) llm.KV {
return kv
}
func (p *llamaAdapter) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
func (p *llamaAdapter) Tensors(ts []Tensor) []fileutils.Tensor {
var out []fileutils.Tensor
for _, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@ -41,7 +41,7 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []llm.Tensor {
t.SetRepacker(p.repack)
}
out = append(out, llm.Tensor{
out = append(out, fileutils.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: shape,

View File

@ -6,7 +6,7 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
type mixtralModel struct {
@ -15,7 +15,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) llm.KV {
func (p *mixtralModel) KV(t *Tokenizer) fileutils.KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {
@ -29,7 +29,7 @@ func (p *mixtralModel) KV(t *Tokenizer) llm.KV {
return kv
}
func (p *mixtralModel) Tensors(ts []Tensor) []llm.Tensor {
func (p *mixtralModel) Tensors(ts []Tensor) []fileutils.Tensor {
oldnew := []string{
"model.layers", "blk",
"w1", "ffn_gate_exps",
@ -56,10 +56,10 @@ func (p *mixtralModel) Tensors(ts []Tensor) []llm.Tensor {
return true
})
var out []llm.Tensor
var out []fileutils.Tensor
for n, e := range experts {
// TODO(mxyng): sanity check experts
out = append(out, llm.Tensor{
out = append(out, fileutils.Tensor{
Name: n,
Kind: e[0].Kind(),
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),

View File

@ -8,7 +8,7 @@ import (
"strings"
"sync"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
type phi3Model struct {
@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) llm.KV {
func (p *phi3Model) KV(t *Tokenizer) fileutils.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings
@ -68,19 +68,19 @@ func (p *phi3Model) KV(t *Tokenizer) llm.KV {
return kv
}
func (p *phi3Model) Tensors(ts []Tensor) []llm.Tensor {
func (p *phi3Model) Tensors(ts []Tensor) []fileutils.Tensor {
var addRopeFactors sync.Once
out := make([]llm.Tensor, 0, len(ts)+2)
out := make([]fileutils.Tensor, 0, len(ts)+2)
for _, t := range ts {
if strings.HasPrefix(t.Name(), "blk.0.") {
addRopeFactors.Do(func() {
out = append(out, llm.Tensor{
out = append(out, fileutils.Tensor{
Name: "rope_factors_long.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
WriterTo: p.RopeScaling.LongFactor,
}, llm.Tensor{
}, fileutils.Tensor{
Name: "rope_factors_short.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
@ -89,7 +89,7 @@ func (p *phi3Model) Tensors(ts []Tensor) []llm.Tensor {
})
}
out = append(out, llm.Tensor{
out = append(out, fileutils.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -20,7 +20,7 @@ import (
"golang.org/x/exp/maps"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
type tensorData struct {
@ -29,7 +29,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fileutils.KV, *fileutils.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@ -48,7 +48,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
}
t.Cleanup(func() { r.Close() })
m, _, err := llm.DecodeGGML(r, math.MaxInt)
m, _, err := fileutils.DecodeGGML(r, math.MaxInt)
if err != nil {
t.Fatal(err)
}
@ -60,7 +60,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors *llm.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv fileutils.KV, tensors *fileutils.Tensors) map[string]string {
actual := make(map[string]string)
for k, v := range kv {
if s, ok := v.(json.Marshaler); !ok {
@ -330,7 +330,7 @@ func TestConvertAdapter(t *testing.T) {
}
defer r.Close()
m, _, err := llm.DecodeGGML(r, math.MaxInt)
m, _, err := fileutils.DecodeGGML(r, math.MaxInt)
if err != nil {
t.Fatal(err)
}

3
discover/README.md Normal file
View File

@ -0,0 +1,3 @@
# `discover`
This package is responsible for discovering information about the system and the capabilities to run LLM. This includes GPU and CPU discovery so the optimal runner can be chosen for a given model. The ollama scheduler relies on up-to-date available memory information, so this package provides the ability to refresh free memory as efficiently as possible.

3
fileutils/README.md Normal file
View File

@ -0,0 +1,3 @@
# `modelfile`
This package provides utilities for loading and inspecting model files

View File

@ -1,9 +1,11 @@
package llm
package fileutils
import "fmt"
type fileType uint32
// TODO this should map over to the GGML CGO enum type
const (
fileTypeF32 fileType = iota
fileTypeF16

View File

@ -1,4 +1,4 @@
package llm
package fileutils
import (
"encoding/binary"

View File

@ -1,10 +1,11 @@
package llm
package fileutils
import (
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"slices"
"strings"
"sync"
@ -488,3 +489,23 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
return
}
// LoadModel will load a model from disk. The model must be in the GGML format.
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// the maxArraySize is negative, all arrays are collected.
func LoadModel(model string, maxArraySize int) (*GGML, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
f, err := os.Open(model)
if err != nil {
return nil, err
}
defer f.Close()
ggml, _, err := DecodeGGML(f, maxArraySize)
return ggml, err
}

1
fileutils/ggml_test.go Normal file
View File

@ -0,0 +1 @@
package fileutils

View File

@ -1,4 +1,4 @@
package llm
package fileutils
import (
"bytes"

View File

@ -1,4 +1,4 @@
package llm
package fileutils
import (
"fmt"
@ -329,7 +329,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
return estimate
}
func (m MemoryEstimate) log() {
func (m MemoryEstimate) Log() {
overhead := envconfig.GpuOverhead()
log := slog.With()

View File

@ -1,4 +1,4 @@
package llm
package fileutils
import (
"bytes"

View File

@ -1 +0,0 @@
package llm

3
runners/README.md Normal file
View File

@ -0,0 +1,3 @@
# `runners`
Ollama uses a subprocess model to run one or more child processes to load the LLM. On some platforms (Linux non-containerized, MacOS) these executables are carried as payloads inside the main executable via the ../build package. Extraction and discovery of these runners at runtime is implemented in this package. This package also provides the abstraction to communicate with these subprocesses.

View File

@ -2,6 +2,7 @@ package runners
import (
"compress/gzip"
"context"
"errors"
"fmt"
"io"
@ -15,9 +16,11 @@ import (
"strings"
"sync"
"syscall"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
)
@ -31,6 +34,36 @@ var (
runnersDir = ""
)
type CompletionRequest struct {
Prompt string
Format string
Images []ImageData
Options *api.Options
}
type CompletionResponse struct {
Content string
DoneReason string
Done bool
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
}
type LLMServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, input string) ([]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
EstimatedVRAM() uint64 // Total VRAM across all GPUs
EstimatedTotal() uint64
EstimatedVRAMByGPU(gpuID string) uint64
}
// Return the location where runners are stored
// If runners are payloads, this will either extract them
// or refresh them if any have disappeared due to tmp cleaners

View File

@ -1,4 +1,4 @@
package llm
package runners
import (
"bufio"
@ -28,24 +28,11 @@ import (
"github.com/ollama/ollama/build"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fileutils"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/runners"
)
type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, input string) ([]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
EstimatedVRAM() uint64 // Total VRAM across all GPUs
EstimatedTotal() uint64
EstimatedVRAMByGPU(gpuID string) uint64
}
// llmServer is an instance of the llama.cpp server
type llmServer struct {
port int
@ -58,7 +45,7 @@ type llmServer struct {
modelLock sync.Mutex // Temporary until we switch fully to Go server
model *llama.Model // If non-nil, the runner is a new Go server
estimate MemoryEstimate
estimate fileutils.MemoryEstimate
totalLayers uint64
// gpuCount int
gpus discover.GpuInfoList // Recorded just before the model loaded, free space will be incorrect
@ -68,32 +55,12 @@ type llmServer struct {
sem *semaphore.Weighted
}
// LoadModel will load a model from disk. The model must be in the GGML format.
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// the maxArraySize is negative, all arrays are collected.
func LoadModel(model string, maxArraySize int) (*GGML, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
f, err := os.Open(model)
if err != nil {
return nil, err
}
defer f.Close()
ggml, _, err := DecodeGGML(f, maxArraySize)
return ggml, err
}
// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *fileutils.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LLMServer, error) {
var err error
var cpuRunner string
var estimate MemoryEstimate
var estimate fileutils.MemoryEstimate
var systemTotalMemory uint64
var systemFreeMemory uint64
var systemSwapFreeMemory uint64
@ -109,10 +76,10 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
gpus = discover.GetCPUInfo()
}
if len(gpus) == 1 && gpus[0].Library == "cpu" {
cpuRunner = runners.ServerForCpu()
estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
cpuRunner = ServerForCpu()
estimate = fileutils.EstimateGPULayers(gpus, ggml, projectors, opts)
} else {
estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
estimate = fileutils.EstimateGPULayers(gpus, ggml, projectors, opts)
switch {
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
@ -121,7 +88,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
opts.NumGPU = 0
case gpus[0].Library != "metal" && estimate.Layers == 0:
// Don't bother loading into the GPU if no layers can fit
cpuRunner = runners.ServerForCpu()
cpuRunner = ServerForCpu()
gpus = discover.GetCPUInfo()
case opts.NumGPU < 0 && estimate.Layers > 0 && gpus[0].Library != "cpu":
opts.NumGPU = estimate.Layers
@ -139,7 +106,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
}
}
estimate.log()
estimate.Log()
// Loop through potential servers
finalErr := errors.New("no suitable llama servers found")
@ -148,12 +115,12 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
}
rDir, err := runners.Refresh(build.EmbedFS)
rDir, err := Refresh(build.EmbedFS)
if err != nil {
return nil, err
}
availableServers := runners.GetAvailableServers(rDir)
availableServers := GetAvailableServers(rDir)
if len(availableServers) == 0 {
return nil, finalErr
}
@ -161,7 +128,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
if cpuRunner != "" {
servers = []string{cpuRunner}
} else {
servers = runners.ServersForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
servers = ServersForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
}
demandLib := envconfig.LLMLibrary()
if demandLib != "" {
@ -325,7 +292,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
_, err := os.Stat(server)
if errors.Is(err, os.ErrNotExist) {
slog.Warn("llama server disappeared, reinitializing payloads", "path", server, "error", err)
_, err = runners.Refresh(build.EmbedFS)
_, err = Refresh(build.EmbedFS)
if err != nil {
slog.Warn("failed to reinitialize payloads", "error", err)
return nil, err
@ -673,23 +640,6 @@ type completion struct {
}
}
type CompletionRequest struct {
Prompt string
Format string
Images []ImageData
Options *api.Options
}
type CompletionResponse struct {
Content string
DoneReason string
Done bool
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)

View File

@ -1,4 +1,4 @@
package llm
package runners
import (
"bytes"

View File

@ -1,4 +1,4 @@
package llm
package runners
import (
"syscall"

View File

@ -1,4 +1,4 @@
package llm
package runners
import (
"syscall"

View File

@ -1,4 +1,4 @@
package llm
package runners
import (
"syscall"

View File

@ -25,9 +25,9 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fileutils"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
@ -91,7 +91,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
defer f.Close()
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
ggml, _, err := llm.DecodeGGML(f, 0)
ggml, _, err := fileutils.DecodeGGML(f, 0)
if err != nil {
slog.Error("couldn't decode ggml", "error", err)
continue
@ -431,7 +431,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
baseLayer.MediaType == "application/vnd.ollama.image.model" &&
baseLayer.GGML != nil &&
baseLayer.GGML.Name() == "gguf" {
want, err := llm.ParseFileType(quantization)
want, err := fileutils.ParseFileType(quantization)
if err != nil {
return err
}
@ -467,7 +467,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return err
}
ggml, _, err := llm.DecodeGGML(temp, 0)
ggml, _, err := fileutils.DecodeGGML(temp, 0)
if err != nil {
return err
}

View File

@ -18,7 +18,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
@ -27,7 +27,7 @@ var intermediateBlobs map[string]string = make(map[string]string)
type layerGGML struct {
Layer
*llm.GGML
*fileutils.GGML
}
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
@ -67,7 +67,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}
defer blob.Close()
ggml, _, err := llm.DecodeGGML(blob, 0)
ggml, _, err := fileutils.DecodeGGML(blob, 0)
if err != nil {
return nil, err
}
@ -112,7 +112,7 @@ func parseFromZipFile(_ context.Context, command string, baseLayers []*layerGGML
switch command {
case "adapter":
var baseModel *llm.GGML
var baseModel *fileutils.GGML
for _, l := range baseLayers {
if l.GGML != nil {
baseModel = l.GGML
@ -150,7 +150,7 @@ func parseFromZipFile(_ context.Context, command string, baseLayers []*layerGGML
}
defer bin.Close()
ggml, _, err := llm.DecodeGGML(bin, 0)
ggml, _, err := fileutils.DecodeGGML(bin, 0)
if err != nil {
return nil, err
}
@ -184,7 +184,7 @@ func parseFromFile(ctx context.Context, command string, baseLayers []*layerGGML,
var offset int64
for offset < stat.Size() {
ggml, n, err := llm.DecodeGGML(file, 0)
ggml, n, err := fileutils.DecodeGGML(file, 0)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
@ -263,7 +263,7 @@ func detectContentType(r io.Reader) (string, error) {
return "", err
}
if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" {
if contentType := fileutils.DetectGGMLType(b.Bytes()); contentType != "" {
return contentType, nil
}

View File

@ -13,7 +13,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
"github.com/ollama/ollama/template"
)
@ -147,7 +147,7 @@ func TestParseFromFileFromLayer(t *testing.T) {
t.Fatalf("failed to open file: %v", err)
}
defer file.Close()
if err := llm.WriteGGUF(file, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
if err := fileutils.WriteGGUF(file, fileutils.KV{"general.architecture": "gemma"}, []fileutils.Tensor{}); err != nil {
t.Fatalf("failed to write gguf: %v", err)
}
@ -200,7 +200,7 @@ func TestParseLayerFromCopy(t *testing.T) {
defer file2.Close()
for range 5 {
if err := llm.WriteGGUF(file2, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
if err := fileutils.WriteGGUF(file2, fileutils.KV{"general.architecture": "gemma"}, []fileutils.Tensor{}); err != nil {
t.Fatalf("failed to write gguf: %v", err)
}
}

View File

@ -10,7 +10,7 @@ import (
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/runners"
"github.com/ollama/ollama/server/imageproc"
"github.com/ollama/ollama/template"
)
@ -22,7 +22,7 @@ var errTooManyImages = errors.New("vision model only supports a single image per
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []runners.ImageData, _ error) {
var system []api.Message
isMllama := checkMllamaModelFamily(m)
@ -90,7 +90,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
return "", nil, err
}
imgData := llm.ImageData{
imgData := runners.ImageData{
Data: buf.Bytes(),
AspectRatioID: aspectRatioID,
}
@ -105,7 +105,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
prefix := ""
prompt := msg.Content
for _, i := range msg.Images {
imgData := llm.ImageData{
imgData := runners.ImageData{
ID: len(images),
Data: i,
}

View File

@ -29,7 +29,7 @@ import (
"github.com/ollama/ollama/build"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/runners"
@ -78,7 +78,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (runners.LLMServer, *Model, *api.Options, error) {
if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
}
@ -187,9 +187,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
images := make([]llm.ImageData, len(req.Images))
images := make([]runners.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
images[i] = runners.ImageData{ID: i, Data: req.Images[i]}
}
prompt := req.Prompt
@ -255,12 +255,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// TODO (jmorganca): avoid building the response twice both here and below
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
if err := r.Completion(c.Request.Context(), runners.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(cr llm.CompletionResponse) {
}, func(cr runners.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
@ -639,7 +639,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or fileutils are required"})
return
}
@ -647,7 +647,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
if r.Path != "" && r.Modelfile == "" {
f, err := os.Open(r.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading fileutils: %s", err)})
return
}
defer f.Close()
@ -851,12 +851,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
func getKVData(digest string, verbose bool) (llm.KV, error) {
func getKVData(digest string, verbose bool) (fileutils.KV, error) {
maxArraySize := 0
if verbose {
maxArraySize = -1
}
kvData, err := llm.LoadModel(digest, maxArraySize)
kvData, err := fileutils.LoadModel(digest, maxArraySize)
if err != nil {
return nil, err
}
@ -1436,12 +1436,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
if err := r.Completion(c.Request.Context(), runners.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
}, func(r runners.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),

View File

@ -16,12 +16,12 @@ import (
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
var stream bool = false
func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) string {
func createBinFile(t *testing.T, kv map[string]any, ti []fileutils.Tensor) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "")
@ -30,7 +30,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) string {
}
defer f.Close()
if err := llm.WriteGGUF(f, kv, ti); err != nil {
if err := fileutils.WriteGGUF(f, kv, ti); err != nil {
t.Fatal(err)
}
@ -581,7 +581,7 @@ func TestCreateDetectTemplate(t *testing.T) {
t.Run("matched", func(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, fileutils.KV{
"tokenizer.chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
}, nil)),
Stream: &stream,

View File

@ -16,18 +16,19 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
"github.com/ollama/ollama/runners"
)
type mockRunner struct {
llm.LlamaServer
runners.LLMServer
// CompletionRequest is only valid until the next call to Completion
llm.CompletionRequest
llm.CompletionResponse
runners.CompletionRequest
runners.CompletionResponse
}
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
func (m *mockRunner) Completion(_ context.Context, r runners.CompletionRequest, fn func(r runners.CompletionResponse)) error {
m.CompletionRequest = r
fn(m.CompletionResponse)
return nil
@ -41,8 +42,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
return
}
func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *fileutils.GGML, []string, []string, api.Options, int) (runners.LLMServer, error) {
return func(gpus discover.GpuInfoList, model string, ggml *fileutils.GGML, projectors, system []string, opts api.Options, numParallel int) (runners.LLMServer, error) {
return mock, nil
}
}
@ -51,7 +52,7 @@ func TestGenerateChat(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
CompletionResponse: runners.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
@ -72,7 +73,7 @@ func TestGenerateChat(t *testing.T) {
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
loadFn: func(req *LlmRequest, ggml *fileutils.GGML, gpus discover.GpuInfoList, numParallel int) {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
@ -91,7 +92,7 @@ func TestGenerateChat(t *testing.T) {
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
`, createBinFile(t, fileutils.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
@ -101,7 +102,7 @@ func TestGenerateChat(t *testing.T) {
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
}, []fileutils.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
@ -146,10 +147,10 @@ func TestGenerateChat(t *testing.T) {
t.Run("missing capabilities chat", func(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, fileutils.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
}, []fileutils.Tensor{})),
Stream: &stream,
})
@ -349,7 +350,7 @@ func TestGenerate(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
CompletionResponse: runners.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
@ -370,7 +371,7 @@ func TestGenerate(t *testing.T) {
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
loadFn: func(req *LlmRequest, ggml *fileutils.GGML, gpus discover.GpuInfoList, numParallel int) {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
@ -389,7 +390,7 @@ func TestGenerate(t *testing.T) {
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
`, createBinFile(t, fileutils.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
@ -399,7 +400,7 @@ func TestGenerate(t *testing.T) {
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
}, []fileutils.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
@ -444,10 +445,10 @@ func TestGenerate(t *testing.T) {
t.Run("missing capabilities generate", func(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, fileutils.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
}, []fileutils.Tensor{})),
Stream: &stream,
})

View File

@ -16,7 +16,7 @@ import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
@ -83,14 +83,14 @@ func Test_Routes(t *testing.T) {
fname := createTestFile(t, "ollama-model")
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
modelfile, err := parser.ParseFile(r)
fileutils, err := parser.ParseFile(r)
if err != nil {
t.Fatalf("failed to parse file: %v", err)
}
fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status)
}
err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
err = CreateModel(context.TODO(), model.ParseName(name), "", "", fileutils, fn)
if err != nil {
t.Fatalf("failed to create model: %v", err)
}
@ -561,8 +561,8 @@ func TestShow(t *testing.T) {
Name: "show-model",
Modelfile: fmt.Sprintf(
"FROM %s\nFROM %s",
createBinFile(t, llm.KV{"general.architecture": "test"}, nil),
createBinFile(t, llm.KV{"general.type": "projector", "general.architecture": "clip"}, nil),
createBinFile(t, fileutils.KV{"general.architecture": "test"}, nil),
createBinFile(t, fileutils.KV{"general.type": "projector", "general.architecture": "clip"}, nil),
),
})

View File

@ -17,8 +17,9 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fileutils"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/runners"
)
type LlmRequest struct {
@ -41,8 +42,8 @@ type Scheduler struct {
loaded map[string]*runnerRef
loadedMu sync.Mutex
loadFn func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int)
newServerFn func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
loadFn func(req *LlmRequest, ggml *fileutils.GGML, gpus discover.GpuInfoList, numParallel int)
newServerFn func(gpus discover.GpuInfoList, model string, ggml *fileutils.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (runners.LLMServer, error)
getGpuFn func() discover.GpuInfoList
getCpuFn func() discover.GpuInfoList
reschedDelay time.Duration
@ -68,7 +69,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
expiredCh: make(chan *runnerRef, maxQueue),
unloadedCh: make(chan interface{}, maxQueue),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
newServerFn: runners.NewLlamaServer,
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
@ -187,7 +188,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
// Load model for fitting
ggml, err := llm.LoadModel(pending.model.ModelPath, 0)
ggml, err := fileutils.LoadModel(pending.model.ModelPath, 0)
if err != nil {
pending.errCh <- err
break
@ -409,7 +410,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
}()
}
func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
func (s *Scheduler) load(req *LlmRequest, ggml *fileutils.GGML, gpus discover.GpuInfoList, numParallel int) {
if numParallel < 1 {
numParallel = 1
}
@ -422,7 +423,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoL
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(err, llm.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
if errors.Is(err, fileutils.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
@ -540,7 +541,7 @@ type runnerRef struct {
refCount uint // prevent unloading if > 0
// unloading bool // set to true when we are trying to unload the runner
llama llm.LlamaServer
llama runners.LLMServer
loading bool // True only during initial load, then false forever
gpus discover.GpuInfoList // Recorded at time of provisioning
estimatedVRAM uint64
@ -685,7 +686,7 @@ func (a ByDuration) Less(i, j int) bool {
// If the model can not be fit fully within the available GPU(s) nil is returned
// If numParallel is <= 0, this will attempt try to optimize parallism based on available VRAM, and adjust
// opts.NumCtx accordingly
func pickBestFullFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
func pickBestFullFitByLibrary(req *LlmRequest, ggml *fileutils.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
var estimatedVRAM uint64
var numParallelToTry []int
@ -710,7 +711,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.Gpu
req.opts.NumCtx = req.origNumCtx * p
if !envconfig.SchedSpread() {
for _, g := range sgl {
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
if ok, estimatedVRAM = fileutils.PredictServerFit([]discover.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
*numParallel = p
return []discover.GpuInfo{g}
@ -726,7 +727,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.Gpu
// Now try all the GPUs
for _, p := range numParallelToTry {
req.opts.NumCtx = req.origNumCtx * p
if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
if ok, estimatedVRAM = fileutils.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
*numParallel = p
return sgl
@ -737,7 +738,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.Gpu
}
// If multiple Libraries are detected, pick the Library which loads the most layers for the model
func pickBestPartialFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
func pickBestPartialFitByLibrary(req *LlmRequest, ggml *fileutils.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
if *numParallel <= 0 {
*numParallel = 1
req.opts.NumCtx = req.origNumCtx
@ -749,7 +750,7 @@ func pickBestPartialFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.
var bestEstimate uint64
var bestFit int
for i, gl := range byLibrary {
_, estimatedVRAM := llm.PredictServerFit(gl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
_, estimatedVRAM := fileutils.PredictServerFit(gl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
if estimatedVRAM > bestEstimate {
bestEstimate = estimatedVRAM
bestFit = i
@ -822,9 +823,9 @@ func (s *Scheduler) expireRunner(model *Model) {
// If other runners are loaded, make sure the pending request will fit in system memory
// If not, pick a runner to unload, else return nil and the request can be loaded
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList) *runnerRef {
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, ggml *fileutils.GGML, gpus discover.GpuInfoList) *runnerRef {
slog.Debug("evaluating if CPU model load will fit in available system memory")
estimate := llm.EstimateGPULayers(gpus, ggml, req.model.ProjectorPaths, req.opts)
estimate := fileutils.EstimateGPULayers(gpus, ggml, req.model.ProjectorPaths, req.opts)
if estimate.TotalSize <= gpus[0].FreeMemory {
slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
return nil

View File

@ -14,8 +14,9 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fileutils"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/runners"
)
func TestMain(m *testing.M) {
@ -37,7 +38,7 @@ func TestLoad(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer done()
s := InitScheduler(ctx)
var ggml *llm.GGML // value not used in tests
var ggml *fileutils.GGML // value not used in tests
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
@ -47,7 +48,7 @@ func TestLoad(t *testing.T) {
sessionDuration: &api.Duration{Duration: 2 * time.Second},
}
// Fail to load model first
s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *fileutils.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (runners.LLMServer, error) {
return nil, errors.New("something failed to load model blah")
}
gpus := discover.GpuInfoList{}
@ -61,7 +62,7 @@ func TestLoad(t *testing.T) {
require.Contains(t, err.Error(), "this model may be incompatible")
server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *fileutils.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (runners.LLMServer, error) {
return server, nil
}
s.load(req, ggml, gpus, 0)
@ -99,10 +100,10 @@ type reqBundle struct {
ctxDone func()
srv *mockLlm
req *LlmRequest
ggml *llm.GGML
ggml *fileutils.GGML
}
func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, ggml *fileutils.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (runners.LLMServer, error) {
return scenario.srv, nil
}
@ -115,7 +116,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
require.NoError(t, err)
defer f.Close()
require.NoError(t, llm.WriteGGUF(f, llm.KV{
require.NoError(t, fileutils.WriteGGUF(f, fileutils.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
@ -125,7 +126,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
}, []fileutils.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
}))
@ -133,7 +134,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
b.ggml, err = llm.LoadModel(model.ModelPath, 0)
b.ggml, err = fileutils.LoadModel(model.ModelPath, 0)
require.NoError(t, err)
if duration == nil {
@ -419,10 +420,10 @@ func TestExpireRunner(t *testing.T) {
sessionDuration: &api.Duration{Duration: 2 * time.Minute},
}
var ggml *llm.GGML
var ggml *fileutils.GGML
gpus := discover.GpuInfoList{}
server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *fileutils.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (runners.LLMServer, error) {
return server, nil
}
s.load(req, ggml, gpus, 0)
@ -729,7 +730,7 @@ func TestHomogeneousGPUs(t *testing.T) {
}
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *fileutils.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (runners.LLMServer, error) {
require.Len(t, gpus, 1)
return a.newServer(gpus, model, ggml, adapters, projectors, opts, numParallel)
}
@ -768,7 +769,7 @@ type mockLlm struct {
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
func (s *mockLlm) Completion(ctx context.Context, req runners.CompletionRequest, fn func(runners.CompletionResponse)) error {
return s.completionResp
}

View File

@ -14,7 +14,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/fileutils"
)
func TestNamed(t *testing.T) {
@ -33,7 +33,7 @@ func TestNamed(t *testing.T) {
for k, v := range ss {
t.Run(k, func(t *testing.T) {
kv := llm.KV{"tokenizer.chat_template": v}
kv := fileutils.KV{"tokenizer.chat_template": v}
s := kv.ChatTemplate()
r, err := Named(s)
if err != nil {