From ec17359a68c19a8cc6a4a908258cacff698fa4f9 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Fri, 24 May 2024 10:09:35 -0700 Subject: [PATCH] wip --- llama/ggml-metal-darwin_arm64.m | 34 ++-- llama/llama.go | 18 +- llama/runner/runner.go | 298 +++++++++++++++++++------------- 3 files changed, 202 insertions(+), 148 deletions(-) diff --git a/llama/ggml-metal-darwin_arm64.m b/llama/ggml-metal-darwin_arm64.m index d465e522..f4c2c412 100644 --- a/llama/ggml-metal-darwin_arm64.m +++ b/llama/ggml-metal-darwin_arm64.m @@ -1499,27 +1499,27 @@ static enum ggml_status ggml_metal_graph_compute( // to the matrix-vector kernel int ne11_mm_min = 1; -#if 0 + // the numbers below are measured on M2 Ultra for 7B and 13B models // these numbers do not translate to other devices or model sizes // TODO: need to find a better approach - if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { - switch (src0t) { - case GGML_TYPE_F16: ne11_mm_min = 2; break; - case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; - case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; - case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: ne11_mm_min = 15; break; - case GGML_TYPE_Q4_K: ne11_mm_min = 11; break; - case GGML_TYPE_Q5_0: // not tested yet - case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet - case GGML_TYPE_Q5_K: ne11_mm_min = 7; break; - case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; - default: ne11_mm_min = 1; break; - } + // if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { + switch (src0t) { + case GGML_TYPE_F16: ne11_mm_min = 2; break; + case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; + case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; + case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: ne11_mm_min = 15; break; + case GGML_TYPE_Q4_K: ne11_mm_min = 11; break; + case GGML_TYPE_Q5_0: // not tested yet + case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet + case GGML_TYPE_Q5_K: ne11_mm_min = 7; break; + case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; + default: ne11_mm_min = 1; break; } -#endif + // } + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel diff --git a/llama/llama.go b/llama/llama.go index 48bfc119..421cff0c 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -38,15 +38,14 @@ import ( "github.com/ollama/ollama/llm" ) -// SystemInfo is an unused example of calling llama.cpp functions using CGo -func PrintSystemInfo() string { - return C.GoString(C.llama_print_system_info()) -} - func BackendInit() { C.llama_backend_init() } +func PrintSystemInfo() string { + return C.GoString(C.llama_print_system_info()) +} + type ContextParams struct { c C.struct_llama_context_params } @@ -100,7 +99,8 @@ func (c *Context) Model() *Model { return &Model{c: C.llama_get_model(c.c)} } -func (c *Context) SampleTokenGreedy(batch Batch) int { +// TODO: break this up +func (c *Context) SampleTokenGreedy(batch Batch, i int) int { nv := c.Model().NumVocab() // TODO(jmorganca): split this up into different functions @@ -108,7 +108,7 @@ func (c *Context) SampleTokenGreedy(batch Batch) int { defer C.free(unsafe.Pointer(candidates)) // get most recent logits - logits := C.llama_get_logits_ith(c.c, C.int(batch.NumTokens()-1)) + logits := C.llama_get_logits_ith(c.c, C.int(i)) for i := 0; i < int(nv); i++ { ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + 
uintptr(i)*unsafe.Sizeof(C.struct_llama_token_data{}))) ptr.id = C.int(i) @@ -123,6 +123,10 @@ func (c *Context) SampleTokenGreedy(batch Batch) int { })) } +func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool { + return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1))) +} + func LoadModelFromFile(modelPath string, params ModelParams) *Model { return &Model{c: C.llama_load_model_from_file(C.CString(modelPath), params.c)} } diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 6cb9e5af..85d03f0f 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -1,7 +1,7 @@ package main import ( - "encoding/base64" + "context" "encoding/json" "flag" "fmt" @@ -9,13 +9,149 @@ import ( "log/slog" "net" "net/http" - "regexp" - "strconv" "sync" "github.com/ollama/ollama/llama" ) +type Sequence struct { + // number of tokens evaluated + nPast int + + // tokens left to evaluate + tokens []int + + responses chan string +} + +// prompt returns true if the prompt is still being processed +func (s *Sequence) prompt() bool { + return s.nPast < len(s.tokens)-1 +} + +func (s *Server) NewSequence(text string, w http.ResponseWriter) *Sequence { + tokens, err := s.lc.Model().Tokenize(text, 2048, true, true) + if err != nil { + panic(err) + } + + return &Sequence{ + tokens: tokens, + responses: make(chan string, 1), + } +} + +type Server struct { + model *llama.Model + lc *llama.Context + cc *llama.ClipContext + + // parallel is the number of parallel requests to handle + parallel int + + // seqs is the list of parallel sequences being evaluated + seqs []*Sequence + + mu sync.Mutex + + cond *sync.Cond +} + +func (s *Server) allNil() bool { + for _, item := range s.seqs { + if item != nil { + return false + } + } + return true +} + +func (s *Server) run(ctx context.Context) { + batch := llama.NewBatch(512, 0, s.parallel) + defer batch.Free() + + for { + select { + case <-ctx.Done(): + return + default: + slog.Info("Processing batch", "seqs", len(s.seqs)) + s.mu.Lock() + for s.allNil() { + fmt.Println("wait") + s.cond.Wait() // Wait until an item is added + } + s.mu.Unlock() + + fmt.Println("seqs", s.seqs, len(s.seqs)) + + // prepare the batch + ibatch := make([]int, s.parallel) + for i, seq := range s.seqs { + if seq == nil { + continue + } + + for j, t := range seq.tokens { + // todo: make this n_batch + if j > 512 { + break + } + + batch.Add(t, seq.nPast, []int{i}, !seq.prompt()) + seq.nPast++ + + if seq.prompt() { + ibatch[i] = batch.NumTokens() + 1 + } + } + } + + err := s.lc.Decode(batch) + if err != nil { + panic("Failed to decode") + } + + for i, seq := range s.seqs { + if seq == nil { + continue + } + + // don't sample prompt processing + if seq.prompt() { + if len(seq.tokens) < 512 { + seq.tokens = []int{} + } else { + seq.tokens = seq.tokens[512:] + } + + continue + } + + // sample a token + // TODO: sample based on the sequence + fmt.Println("Sampling token", i, ibatch[i]) + token := s.lc.SampleTokenGreedy(batch, ibatch[i]) + + // if it's an end of sequence token, break + // TODO: just end this sequence + if s.model.TokenIsEog(token) { + // TODO: end the sequence instead of quitting the pool + s.lc.KvCacheSeqRm(i, 0, -1) + close(seq.responses) + s.seqs[i] = nil + continue + } + + seq.responses <- s.model.TokenToPiece(token) + seq.tokens = []int{token} + } + + batch.Clear() + } + } +} + type Request struct { Prompt string `json:"prompt"` Images []string `json:"images"` @@ -25,124 +161,53 @@ type Response struct { Token string `json:"token"` } -type Server 
struct { - model *llama.Model - lc *llama.Context - cc *llama.ClipContext -} - -var mu sync.Mutex - -func (s *Server) stream(w http.ResponseWriter, r *http.Request) { +func (s *Server) handler(w http.ResponseWriter, r *http.Request) { var request Request if err := json.NewDecoder(r.Body).Decode(&request); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } - mu.Lock() - defer mu.Unlock() - // Set the headers to indicate streaming w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") w.WriteHeader(http.StatusOK) - enc := json.NewEncoder(w) + seq := s.NewSequence(request.Prompt, w) - // create embeddings for each image - var embeddings []*llama.LlavaImageEmbed - if s.cc != nil { - for _, img := range request.Images { - data, err := base64.StdEncoding.DecodeString(img) - if err != nil { - http.Error(w, "Failed to decode image", http.StatusBadRequest) - return - } - - embd := llama.NewLlavaImageEmbed(s.cc, data) - embeddings = append(embeddings, embd) - } - } - - var nPast int - - // eval the prompt - re := regexp.MustCompile(`\[\s*img-(\d+)\s*\]`) - matches := re.FindAllStringSubmatchIndex(request.Prompt, -1) - - // eval each chunk including images - pos := 0 - for _, match := range matches { - part := request.Prompt[pos:match[0]] - fmt.Println("Text part:", part) - - // eval text before image - err := s.evalText(part, &nPast) - if err != nil { - log.Println("Failed to eval text:", err) - return - } - - // eval image - imgIndexStr := request.Prompt[match[2]:match[3]] - imgIndex, err := strconv.Atoi(imgIndexStr) - if err != nil { - slog.Warn("Failed to parse image index", "index", imgIndexStr) - continue - } - - fmt.Println("Tag index:", imgIndex) - if imgIndex <= len(embeddings) { - slog.Info("evaluating image", "index", imgIndex) - llama.LlavaEvalImageEmbed(s.lc, embeddings[imgIndex], 512, &nPast) - } - - pos = match[1] - } - - // eval remaining text - if pos < len(request.Prompt) { - s.evalText(request.Prompt[pos:], &nPast) - } - - batch := llama.NewBatch(512, 0, 1) - defer batch.Free() - - // main loop - for n := nPast; n < 2048; n++ { - // sample a token - token := s.lc.SampleTokenGreedy(batch) - - // if it's an end of sequence token, break - if s.model.TokenIsEog(token) { + s.mu.Lock() + for i, sq := range s.seqs { + if sq == nil { + s.seqs[i] = seq + fmt.Println("signal") + s.cond.Signal() break } + } + s.mu.Unlock() - // print the token - str := s.model.TokenToPiece(token) - - if err := enc.Encode(&Response{Token: str}); err != nil { + for token := range seq.responses { + if err := json.NewEncoder(w).Encode(&Response{ + Token: token, + }); err != nil { log.Println("Failed to encode result:", err) return } - w.(http.Flusher).Flush() - batch.Clear() - batch.Add(token, n, []int{0}, true) - - err := s.lc.Decode(batch) - if err != nil { - panic("Failed to decode") + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return } - } - s.lc.KvCacheClear() + flusher.Flush() + } } func main() { mpath := flag.String("model", "", "Path to model binary file") ppath := flag.String("projector", "", "Path to projector binary file") + parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously") flag.Parse() // load the model @@ -156,7 +221,7 @@ func main() { } var cc *llama.ClipContext - if ppath != nil { + if *ppath != "" { cc = llama.NewClipContext(*ppath) if cc == nil { panic("Failed to create clip context") @@ -164,11 +229,18 @@ func main() { } 
server := &Server{ - model: model, - lc: lc, - cc: cc, + model: model, + lc: lc, + cc: cc, + parallel: *parallel, + seqs: make([]*Sequence, *parallel), } + server.cond = sync.NewCond(&server.mu) + + ctx, cancel := context.WithCancel(context.Background()) + go server.run(ctx) + addr := "127.0.0.1:8080" listener, err := net.Listen("tcp", addr) if err != nil { @@ -178,35 +250,13 @@ func main() { defer listener.Close() httpServer := http.Server{ - Handler: http.HandlerFunc(server.stream), + Handler: http.HandlerFunc(server.handler), } log.Println("Server listening on", addr) if err := httpServer.Serve(listener); err != nil { log.Fatal("server error:", err) } -} - -func (s *Server) evalText(text string, nPast *int) error { - // eval before - batch := llama.NewBatch(512, 0, 1) - defer batch.Free() - - tokens, err := s.lc.Model().Tokenize(text, 2048, true, true) - if err != nil { - return fmt.Errorf("tokenize failed: %w", err) - } - - // prompt eval - for _, t := range tokens { - batch.Add(t, *nPast, []int{0}, true) - *nPast++ - } - - err = s.lc.Decode(batch) - if err != nil { - return fmt.Errorf("decode failed: %w", err) - } - - return nil + + cancel() }
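
For reference, a minimal client sketch against the runner this patch sets up. It assumes only what the diff above shows: the runner listening on `127.0.0.1:8080`, the `Request`/`Response` JSON shapes (`prompt`, `images`, `token`), and the handler streaming one JSON object per sampled token with a flush after each. The `streamCompletion` helper and the example prompt are illustrative and are not part of the patch.

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// request/response mirror the runner's Request/Response types in runner.go above.
type request struct {
	Prompt string   `json:"prompt"`
	Images []string `json:"images"`
}

type response struct {
	Token string `json:"token"`
}

// streamCompletion posts a prompt to the runner and prints tokens as they arrive.
// The runner encodes each Response with json.Encoder, which appends a newline,
// so a line-oriented scanner is enough to split the chunked stream.
func streamCompletion(addr, prompt string) error {
	body, err := json.Marshal(request{Prompt: prompt})
	if err != nil {
		return err
	}

	// The runner registers a single HandlerFunc, so any path works; "/" is used here.
	resp, err := http.Post("http://"+addr+"/", "application/json", bytes.NewReader(body))
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var r response
		if err := json.Unmarshal(scanner.Bytes(), &r); err != nil {
			return err
		}
		fmt.Print(r.Token)
	}
	return scanner.Err()
}

func main() {
	if err := streamCompletion("127.0.0.1:8080", "Why is the sky blue?"); err != nil {
		fmt.Println("stream error:", err)
	}
}
```

With the runner started with `-parallel` greater than 1, several such clients can stream at once: each request is placed into a free slot of `seqs`, and the `run` loop batches their tokens together via `batch.Add` with per-sequence ids.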