add more runner params

jmorganca 2024-05-28 00:02:01 -07:00
parent 72f3fe4b94
commit 20afaae020
2 changed files with 52 additions and 16 deletions

View File

@@ -31,6 +31,7 @@ package llama
 // #include "sampling_ext.h"
 import "C"
 import (
+	"errors"
 	"fmt"
 	"runtime"
 	"strings"
@@ -49,13 +50,14 @@ type ContextParams struct {
 	c C.struct_llama_context_params
 }
 
-func NewContextParams() ContextParams {
+func NewContextParams(numCtx int, threads int, flashAttention bool) ContextParams {
 	params := C.llama_context_default_params()
 	params.seed = C.uint(1234)
-	params.n_ctx = C.uint(2048)
+	params.n_ctx = C.uint(numCtx)
-	params.n_threads = C.uint(runtime.NumCPU())
 	params.n_threads_batch = params.n_threads
 	params.embeddings = C.bool(true)
+	params.flash_attn = C.bool(flashAttention)
+	params.n_threads = C.uint(threads)
 	return ContextParams{c: params}
 }
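
The context hyper-parameters that were hard-coded before (a 2048-token window, runtime.NumCPU() threads) are now caller-supplied, and flash attention becomes opt-in. A minimal usage sketch of the new constructor (values are illustrative, not defaults from this commit):

    // 8K context window, 8 threads, flash attention enabled.
    ctxParams := llama.NewContextParams(8192, 8, true)
    lc := llama.NewContextWithModel(model, ctxParams)
    if lc == nil {
        panic("Failed to create context")
    }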
@@ -63,9 +65,10 @@ type ModelParams struct {
 	c C.struct_llama_model_params
 }
 
-func NewModelParams() ModelParams {
+func NewModelParams(numGpuLayers int, mainGpu int) ModelParams {
 	params := C.llama_model_default_params()
-	params.n_gpu_layers = 999
+	params.n_gpu_layers = C.int(numGpuLayers)
+	params.main_gpu = C.int32_t(mainGpu)
 	return ModelParams{c: params}
 }
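
The GPU placement knobs follow the same pattern: the old unconditional n_gpu_layers = 999 ("offload everything") becomes an explicit argument, alongside a main-GPU selector. A hedged sketch (layer count and model path are examples):

    // Offload 33 layers to GPU 0.
    params := llama.NewModelParams(33, 0)
    model := llama.LoadModelFromFile("model.gguf", params)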
@@ -155,6 +158,23 @@ func (m *Model) TokenIsEog(token int) bool {
 	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
 }
 
+func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath string, threads int) error {
+	cLoraPath := C.CString(loraPath)
+	defer C.free(unsafe.Pointer(cLoraPath))
+
+	var cBaseModelPath *C.char
+	if baseModelPath != "" {
+		cBaseModelPath = C.CString(baseModelPath)
+	}
+
+	code := int(C.llama_model_apply_lora_from_file(m.c, cLoraPath, C.float(scale), cBaseModelPath, C.int32_t(threads)))
+	if code != 0 {
+		return errors.New("error applying lora from file")
+	}
+
+	return nil
+}
+
 type Batch struct {
 	c C.struct_llama_batch
 }
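
The new ApplyLoraFromFile wrapper surfaces llama.cpp's status code as a Go error. A sketch of a call site that checks it (the one in main() below currently discards the error); the adapter path is hypothetical:

    // Apply a LoRA adapter at full scale, with no separate base model.
    if err := model.ApplyLoraFromFile("adapter.bin", 1.0, "", 8); err != nil {
        log.Fatal(err) // assumes "log" is imported
    }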

View File

@@ -9,6 +9,7 @@ import (
 	"log/slog"
 	"net"
 	"net/http"
+	"runtime"
 	"strconv"
 	"strings"
 	"sync"
@@ -73,6 +74,8 @@ type Server struct {
 	lc *llama.Context
 	cc *llama.ClipContext
+	batchSize int
+
 	// parallel is the number of parallel requests to handle
 	parallel int
@@ -154,7 +157,7 @@ func truncateStop(pieces []string, stop string) []string {
 }
 
 func (s *Server) run(ctx context.Context) {
-	batch := llama.NewBatch(512, 0, s.parallel)
+	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
 	defer batch.Free()
 
 	// build up stop sequences as we recognize them
@@ -182,7 +185,7 @@ func (s *Server) run(ctx context.Context) {
 		for j, t := range seq.tokens {
 			// todo: make this n_batch
-			if j > 512 {
+			if j > s.batchSize {
 				break
 			}
@@ -207,10 +210,10 @@ func (s *Server) run(ctx context.Context) {
 			// don't sample prompt processing
 			if seq.prompt() {
-				if len(seq.tokens) < 512 {
+				if len(seq.tokens) < s.batchSize {
 					seq.tokens = []int{}
 				} else {
-					seq.tokens = seq.tokens[512:]
+					seq.tokens = seq.tokens[s.batchSize:]
 				}
 				continue
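
These two hunks replace the hard-coded batch size of 512 with the configurable s.batchSize: a prompt longer than one batch is consumed batchSize tokens per pass of the run loop, trimming the token slice each time until nothing remains. The same slicing pattern in isolation (standalone sketch; assumes Go 1.21+ for the built-in min, and a hypothetical decode helper):

    // Consume a long prompt in batchSize-sized chunks.
    for len(tokens) > 0 {
        n := min(len(tokens), batchSize)
        decode(tokens[:n]) // stand-in for batch fill + llama_decode
        tokens = tokens[n:]
    }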
@@ -412,14 +415,26 @@ func main() {
 	mpath := flag.String("model", "", "Path to model binary file")
 	ppath := flag.String("projector", "", "Path to projector binary file")
 	parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
+	batchSize := flag.Int("batch-size", 512, "Batch size")
+	nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
+	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
+	flashAttention := flag.Bool("flash-attention", false, "Enable flash attention")
+	numCtx := flag.Int("num-ctx", 2048, "Context (or KV cache) size")
+	lpath := flag.String("lora", "", "Path to lora layer file")
 	port := flag.Int("port", 8080, "Port to expose the server on")
+	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
 	flag.Parse()
 
 	// load the model
 	llama.BackendInit()
-	params := llama.NewModelParams()
+	params := llama.NewModelParams(*nGpuLayers, *mainGpu)
 	model := llama.LoadModelFromFile(*mpath, params)
-	ctxParams := llama.NewContextParams()
+
+	if *lpath != "" {
+		model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
+	}
+
+	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
 	lc := llama.NewContextWithModel(model, ctxParams)
 	if lc == nil {
 		panic("Failed to create context")
@@ -434,11 +449,12 @@
 	}
 
 	server := &Server{
-		model:    model,
-		lc:       lc,
-		cc:       cc,
-		parallel: *parallel,
-		seqs:     make([]*Sequence, *parallel),
+		model:     model,
+		lc:        lc,
+		cc:        cc,
+		batchSize: *batchSize,
+		parallel:  *parallel,
+		seqs:      make([]*Sequence, *parallel),
 	}
 
 	server.cond = sync.NewCond(&server.mu)
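
With every flag wired through, an invocation of the runner might look like this (binary name and values are illustrative; each flag maps directly onto the parameters added above):

    ./runner -model model.gguf -num-ctx 4096 -batch-size 512 \
        -n-gpu-layers 33 -main-gpu 0 -flash-attention \
        -threads 8 -parallel 2 -port 8080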