diff --git a/llama/llama.go b/llama/llama.go
index 9b75e388..2acc4f58 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -4,10 +4,10 @@ package llama
 // #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
 // #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
 // #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
-// #cgo darwin,arm64 LDFLAGS: -ld_classic ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
+// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
 // #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
 // #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
-// #cgo darwin,amd64 LDFLAGS: -ld_classic -framework Foundation -framework Accelerate
+// #cgo darwin,amd64 LDFLAGS: -framework Foundation -framework Accelerate
 // #cgo linux CFLAGS: -D_GNU_SOURCE
 // #cgo linux CXXFLAGS: -D_GNU_SOURCE
 // #cgo windows LDFLAGS: -lmsvcrt
@@ -29,11 +29,14 @@ package llama
 // #include "clip.h"
 // #include "llava.h"
 // #include "sampling_ext.h"
+//
+// bool llamaProgressCallback(float progress, void *user_data);
 import "C"
 import (
 	"errors"
 	"fmt"
 	"runtime"
+	"runtime/cgo"
 	"strings"
 	"unsafe"
 )
@@ -65,10 +68,26 @@ type ModelParams struct {
 	c C.struct_llama_model_params
 }
 
-func NewModelParams(numGpuLayers int, mainGpu int) ModelParams {
+//export llamaProgressCallback
+func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
+	handle := cgo.Handle(userData)
+	callback := handle.Value().(func(float32))
+	callback(float32(progress))
+	return true
+}
+
+func NewModelParams(numGpuLayers int, mainGpu int, callback func(float32)) ModelParams {
 	params := C.llama_model_default_params()
 	params.n_gpu_layers = C.int(numGpuLayers)
 	params.main_gpu = C.int32_t(mainGpu)
+
+	handle := cgo.NewHandle(callback)
+	params.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
+	params.progress_callback_user_data = unsafe.Pointer(handle)
+	runtime.SetFinalizer(&params, func(p *C.struct_llama_model_params) {
+		handle.Delete()
+	})
+
 	return ModelParams{c: params}
 }
 
@@ -233,7 +252,8 @@ func (m *Model) TokenToPiece(token int) string {
 	return strings.TrimRight(string(buf), "\x00")
 }
 
-func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
+func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
+	maxTokens := len(text) + 2
 	cTokens := make([]C.llama_token, maxTokens)
 	cText := C.CString(text)
 	defer C.free(unsafe.Pointer(cText))
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index f029473b..7692d1c4 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"runtime"
@@ -28,6 +29,9 @@ type Sequence struct {
 	// channel to send responses over
 	responses chan string
 
+	// number of tokens to predict
+	numPredict int
+
 	samplingCtx *llama.SamplingContext
 
 	// channel to send back the embedding if embedding only
@@ -38,6 +42,8 @@ type Sequence struct {
 
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
+
+	doneReason string
 }
 
 // prompt returns true if the prompt is still being processed
@@ -46,11 +52,18 @@ func (s *Sequence) prompt() bool {
 }
 
 func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
-	tokens, err := s.lc.Model().Tokenize(prompt, 2048, false, true)
+	tokens, err := s.lc.Model().Tokenize(prompt, false, true)
 	if err != nil {
 		panic(err)
 	}
 
+	// truncate to last n tokens
+	// TODO: this shouldn't happen and will severely impact generation
+	// quality. instead we should ensure to cut prompt in the API.
+	if len(tokens) > s.numCtx {
+		tokens = tokens[:s.numCtx]
+	}
+
 	var sc *llama.SamplingContext
 	if params != nil {
 		sc = llama.NewSamplingContext(*params)
@@ -83,9 +96,16 @@ type Server struct {
 	// TODO (jmorganca): this can probably be moved into run()
 	seqs []*Sequence
 
+	// context window size
+	numCtx int
+
 	mu sync.Mutex
 
 	cond *sync.Cond
+
+	progress float32
+
+	status string
 }
 
 func (s *Server) allNil() bool {
@@ -183,6 +203,15 @@ func (s *Server) run(ctx context.Context) {
 				continue
 			}
 
+			// we've reached the context limit
+			if seq.nPast > s.numCtx {
+				seq.doneReason = "limit"
+				close(seq.responses)
+				s.lc.KvCacheSeqRm(i, 0, -1)
+				s.seqs[i] = nil
+				continue
+			}
+
 			for j, t := range seq.tokens {
 				// todo: make this n_batch
 				if j > s.batchSize {
@@ -252,6 +281,7 @@ func (s *Server) run(ctx context.Context) {
 				// as it's important for the /api/generate context
 				// seq.responses <- piece
 
+				seq.doneReason = "stop"
 				close(seq.responses)
 				seq.samplingCtx.Free()
 				pieces[i] = []string{}
@@ -273,6 +303,7 @@ func (s *Server) run(ctx context.Context) {
 			}
 
 			s.lc.KvCacheSeqRm(i, 0, -1)
+			seq.doneReason = "stop"
 			close(seq.responses)
 			seq.samplingCtx.Free()
 			pieces[i] = []string{}
@@ -411,6 +442,24 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+type HealthResponse struct {
+	Status   string  `json:"status"`
+	Progress float32 `json:"progress"`
+}
+
+// TODO (jmorganca): is it safe to do this concurrently with decoding?
+func (s *Server) health(w http.ResponseWriter, r *http.Request) {
+	w.Header().Set("Content-Type", "application/json")
+
+	if err := json.NewEncoder(w).Encode(&HealthResponse{
+		Status:   s.status,
+		Progress: s.progress,
+	}); err != nil {
+		log.Println("Failed to encode result:", err)
+		return
+	}
+}
+
 func main() {
 	mpath := flag.String("model", "", "Path to model binary file")
 	ppath := flag.String("projector", "", "Path to projector binary file")
@@ -425,36 +474,31 @@ func main() {
 	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
 	flag.Parse()
 
-	// load the model
-	llama.BackendInit()
-	params := llama.NewModelParams(*nGpuLayers, *mainGpu)
-	model := llama.LoadModelFromFile(*mpath, params)
-
-	if *lpath != "" {
-		model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
-	}
-
-	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
-	lc := llama.NewContextWithModel(model, ctxParams)
-	if lc == nil {
-		panic("Failed to create context")
-	}
-
-	var cc *llama.ClipContext
-	if *ppath != "" {
-		cc = llama.NewClipContext(*ppath)
-		if cc == nil {
-			panic("Failed to create clip context")
-		}
-	}
-
 	server := &Server{
-		model:     model,
-		lc:        lc,
-		cc:        cc,
+		numCtx:    *numCtx,
 		batchSize: *batchSize,
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
+		status:    "loading",
+	}
+
+	// load the model
+	llama.BackendInit()
+	params := llama.NewModelParams(*nGpuLayers, *mainGpu, func(progress float32) {
+		slog.Info("Loading model", "progress %", math.Round(float64(progress*100)))
+		server.progress = progress
+	})
+	server.model = llama.LoadModelFromFile(*mpath, params)
+
+	if *lpath != "" {
+		server.model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
+	}
+
+	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
+	server.lc = llama.NewContextWithModel(server.model, ctxParams)
+
+	if *ppath != "" {
+		server.cc = llama.NewClipContext(*ppath)
 	}
 
 	server.cond = sync.NewCond(&server.mu)
@@ -473,11 +517,14 @@ func main() {
 	mux := http.NewServeMux()
 	mux.HandleFunc("/embeddings", server.embeddings)
 	mux.HandleFunc("/completion", server.completion)
+	mux.HandleFunc("/health", server.health)
 
 	httpServer := http.Server{
 		Handler: mux,
 	}
 
+	server.status = "ready"
+
 	log.Println("Server listening on", addr)
 	if err := httpServer.Serve(listener); err != nil {
 		log.Fatal("server error:", err)
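
Note: the progress callback added to NewModelParams crosses the C boundary by wrapping the Go closure in a runtime/cgo Handle, passing the handle through llama.cpp as the opaque progress_callback_user_data pointer, and unwrapping it inside the exported llamaProgressCallback. Below is a pure-Go sketch of that handle round trip; the C caller is simulated by an ordinary Go function, and all names are illustrative rather than part of this change.

```go
package main

import (
	"fmt"
	"runtime/cgo"
	"unsafe"
)

// simulateCCallback stands in for the C side, which only sees an opaque
// user_data pointer and invokes the callback as loading progresses.
func simulateCCallback(progress float32, userData unsafe.Pointer) {
	handle := cgo.Handle(userData)             // recover the handle from the opaque pointer
	callback := handle.Value().(func(float32)) // recover the Go closure it wraps
	callback(progress)
}

func main() {
	progressFn := func(p float32) {
		fmt.Printf("progress: %.0f%%\n", p*100)
	}

	// Wrap the closure in a handle so it can be handed to C as a plain
	// pointer without violating the cgo pointer-passing rules.
	handle := cgo.NewHandle(progressFn)
	defer handle.Delete() // handles must be deleted explicitly or they leak

	simulateCCallback(0.5, unsafe.Pointer(handle))
	simulateCCallback(1.0, unsafe.Pointer(handle))
}
```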
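Note: the new /health endpoint exposes the status and progress fields set during loading, so whatever spawns the runner can wait until the model is loaded before sending completion requests. A minimal client-side sketch, assuming the runner is listening on http://127.0.0.1:8080; the address, timeout, and polling interval are illustrative, not part of this change.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

// HealthResponse mirrors the JSON shape returned by the runner's /health handler.
type HealthResponse struct {
	Status   string  `json:"status"`
	Progress float32 `json:"progress"`
}

// waitUntilReady polls /health until the runner reports "ready" or the deadline passes.
func waitUntilReady(baseURL string, timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		resp, err := http.Get(baseURL + "/health")
		if err == nil {
			var health HealthResponse
			if decodeErr := json.NewDecoder(resp.Body).Decode(&health); decodeErr == nil {
				fmt.Printf("status=%s progress=%.0f%%\n", health.Status, health.Progress*100)
				if health.Status == "ready" {
					resp.Body.Close()
					return nil
				}
			}
			resp.Body.Close()
		}
		time.Sleep(250 * time.Millisecond)
	}
	return fmt.Errorf("runner not ready after %s", timeout)
}

func main() {
	if err := waitUntilReady("http://127.0.0.1:8080", 2*time.Minute); err != nil {
		fmt.Println(err)
	}
}
```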