From fbc85728597be69ec930703af51aa2a696abd1c9 Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Thu, 23 May 2024 18:22:15 -0700
Subject: [PATCH] add `llava` to `runner`

---
 llama/llama.go         |  36 ++++++-----
 llama/llava/main.go    |  14 ++---
 llama/runner/runner.go | 131 ++++++++++++++++++++++++++++++++++-------
 3 files changed, 137 insertions(+), 44 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index b2897582..48bfc119 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -38,10 +38,6 @@ import (
 	"github.com/ollama/ollama/llm"
 )
-type Token int32
-type Pos int32
-type SeqId int32
-
 // SystemInfo is an unused example of calling llama.cpp functions using CGo
 func PrintSystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }
@@ -78,6 +74,10 @@ type Context struct {
 	c *C.struct_llama_context
 }
 
+func (c *Context) KvCacheClear() {
+	C.llama_kv_cache_clear(c.c)
+}
+
 func (c *Context) Decode(batch Batch) error {
 	// Positive return values does not mean a fatal error, but rather a warning.
 	// 0 - success
@@ -90,18 +90,18 @@ func (c *Context) Decode(batch Batch) error {
 	}
 
 	if code > 0 {
-		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d\n", code)
+		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
 	}
 
 	return nil
 }
 
-func (c *Context) GetModel() *Model {
+func (c *Context) Model() *Model {
 	return &Model{c: C.llama_get_model(c.c)}
 }
 
-func (c *Context) SampleTokenGreedy(batch Batch) Token {
-	nv := c.GetModel().NumVocab()
+func (c *Context) SampleTokenGreedy(batch Batch) int {
+	nv := c.Model().NumVocab()
 
 	// TODO(jmorganca): split this up into different functions
 	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
@@ -116,7 +116,7 @@ func (c *Context) SampleTokenGreedy(batch Batch) Token {
 		ptr.p = 0.0
 	}
 
-	return Token(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
+	return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
 		data:   candidates,
 		size:   C.size_t(nv),
 		sorted: C.bool(false),
@@ -135,7 +135,7 @@ func (m *Model) NumVocab() int {
 	return int(C.llama_n_vocab(m.c))
 }
 
-func (m *Model) TokenIsEog(token Token) bool {
+func (m *Model) TokenIsEog(token int) bool {
 	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
 }
 
@@ -151,7 +151,7 @@ func (b *Batch) NumTokens() int {
 	return int(b.c.n_tokens)
 }
 
-func (b *Batch) Add(token Token, pos Pos, seqIds []SeqId, logits bool) {
+func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
 	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
 	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
 	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
@@ -171,13 +171,17 @@ func (b *Batch) Clear() {
 	b.c.n_tokens = 0
 }
 
+func (b *Batch) Free() {
+	C.llama_batch_free(b.c)
+}
+
 // LLAMA_API struct llama_batch llama_batch_get_one(
 //
 //	llama_token * tokens,
 //	int32_t n_tokens,
 //	llama_pos pos_0,
 //	llama_seq_id seq_id);
-func BatchGetOne(tokens []Token, pos0 Pos, seqId SeqId) Batch {
+func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
 	return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
 }
 
@@ -185,7 +189,7 @@ type Model struct {
 	c *C.struct_llama_model
 }
 
-func (m *Model) TokenToPiece(token Token) string {
+func (m *Model) TokenToPiece(token int) string {
 	buf := make([]byte, 12)
 	C.llama_token_to_piece(
 		m.c,
@@ -197,7 +201,7 @@ func (m *Model) TokenToPiece(token Token) string {
 	return strings.TrimRight(string(buf), "\x00")
 }
 
-func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]Token, error) {
+func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
 	cTokens := make([]C.llama_token, maxTokens)
 	cText := C.CString(text)
 	defer C.free(unsafe.Pointer(cText))
@@ -216,9 +220,9 @@ func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpeci
 		return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
 	}
 
-	tokens := make([]Token, result)
+	tokens := make([]int, result)
 	for i := 0; i < int(result); i++ {
-		tokens[i] = Token(cTokens[i])
+		tokens[i] = int(cTokens[i])
 	}
 
 	return tokens, nil
diff --git a/llama/llava/main.go b/llama/llava/main.go
index f28f4416..60331490 100644
--- a/llama/llava/main.go
+++ b/llama/llava/main.go
@@ -56,12 +56,12 @@ func main() {
 }
 
 func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, after string) error {
-	beforeTokens, err := lc.GetModel().Tokenize(before, 2048, true, true)
+	beforeTokens, err := lc.Model().Tokenize(before, 2048, true, true)
 	if err != nil {
 		return err
 	}
 
-	afterTokens, err := lc.GetModel().Tokenize(after, 2048, true, true)
+	afterTokens, err := lc.Model().Tokenize(after, 2048, true, true)
 	if err != nil {
 		return err
 	}
@@ -73,7 +73,7 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
 
 	// prompt eval
 	for _, t := range beforeTokens {
-		batch.Add(t, llama.Pos(nPast), []llama.SeqId{0}, true)
+		batch.Add(t, nPast, []int{0}, true)
 		nPast++
 	}
 
@@ -88,7 +88,7 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
 
 	batch = llama.NewBatch(512, 0, 1)
 	for _, t := range afterTokens {
-		batch.Add(t, llama.Pos(nPast), []llama.SeqId{0}, true)
+		batch.Add(t, nPast, []int{0}, true)
 	}
 
 	// main loop
@@ -102,15 +102,15 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
 		token := lc.SampleTokenGreedy(batch)
 
 		// if it's an end of sequence token, break
-		if lc.GetModel().TokenIsEog(token) {
+		if lc.Model().TokenIsEog(token) {
 			break
 		}
 
 		// print the token
-		str := lc.GetModel().TokenToPiece(token)
+		str := lc.Model().TokenToPiece(token)
 		fmt.Print(str)
 		batch.Clear()
-		batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
+		batch.Add(token, n, []int{0}, true)
 	}
 
 	return nil
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index c5d27c52..6cb9e5af 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -1,19 +1,24 @@
 package main
 
 import (
+	"encoding/base64"
 	"encoding/json"
 	"flag"
 	"fmt"
 	"log"
+	"log/slog"
 	"net"
 	"net/http"
+	"regexp"
+	"strconv"
 	"sync"
 
 	"github.com/ollama/ollama/llama"
 )
 
 type Request struct {
-	Prompt string `json:"prompt"`
+	Prompt string   `json:"prompt"`
+	Images []string `json:"images"`
 }
 
 type Response struct {
@@ -23,6 +28,7 @@
 type Server struct {
 	model *llama.Model
 	lc    *llama.Context
+	cc    *llama.ClipContext
 }
 
 var mu sync.Mutex
@@ -34,6 +40,9 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	mu.Lock()
+	defer mu.Unlock()
+
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
@@ -41,30 +50,69 @@
 
 	enc := json.NewEncoder(w)
 
-	// main loop
-	tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
-	if err != nil {
-		panic(err)
+	// create embeddings for each image
+	var embeddings []*llama.LlavaImageEmbed
+	if s.cc != nil {
+		for _, img := range request.Images {
+			data, err := base64.StdEncoding.DecodeString(img)
+			if err != nil {
+				http.Error(w, "Failed to decode image", http.StatusBadRequest)
+				return
+			}
+
+			embd := llama.NewLlavaImageEmbed(s.cc, data)
+			embeddings = append(embeddings, embd)
+		}
+	}
+
+	var nPast int
+
+	// eval the prompt
+	re := regexp.MustCompile(`\[\s*img-(\d+)\s*\]`)
+	matches := re.FindAllStringSubmatchIndex(request.Prompt, -1)
+
+	// eval each chunk including images
+	pos := 0
+	for _, match := range matches {
+		part := request.Prompt[pos:match[0]]
+		fmt.Println("Text part:", part)
+
+		// eval text before image
+		err := s.evalText(part, &nPast)
+		if err != nil {
+			log.Println("Failed to eval text:", err)
+			return
+		}
+
+		// eval image
+		imgIndexStr := request.Prompt[match[2]:match[3]]
+		imgIndex, err := strconv.Atoi(imgIndexStr)
+		if err != nil {
+			slog.Warn("Failed to parse image index", "index", imgIndexStr)
+			continue
+		}
+
+		fmt.Println("Tag index:", imgIndex)
+		if imgIndex <= len(embeddings) {
+			slog.Info("evaluating image", "index", imgIndex)
+			llama.LlavaEvalImageEmbed(s.lc, embeddings[imgIndex], 512, &nPast)
+		}
+
+		pos = match[1]
+	}
+
+	// eval remaining text
+	if pos < len(request.Prompt) {
+		s.evalText(request.Prompt[pos:], &nPast)
 	}
 
 	batch := llama.NewBatch(512, 0, 1)
-
-	// prompt eval
-	for i, t := range tokens {
-		batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
-	}
+	defer batch.Free()
 
 	// main loop
-	for n := batch.NumTokens(); n < 2048; n++ {
-		mu.Lock()
-		err = s.lc.Decode(batch)
-		if err != nil {
-			panic("Failed to decode")
-		}
-
+	for n := nPast; n < 2048; n++ {
 		// sample a token
 		token := s.lc.SampleTokenGreedy(batch)
-		mu.Unlock()
 
 		// if it's an end of sequence token, break
 		if s.model.TokenIsEog(token) {
@@ -81,27 +129,44 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 		w.(http.Flusher).Flush()
 
 		batch.Clear()
-		batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
+		batch.Add(token, n, []int{0}, true)
+
+		err := s.lc.Decode(batch)
+		if err != nil {
+			panic("Failed to decode")
+		}
 	}
+
+	s.lc.KvCacheClear()
 }
 
 func main() {
-	mp := flag.String("model", "", "Path to model binary file")
+	mpath := flag.String("model", "", "Path to model binary file")
+	ppath := flag.String("projector", "", "Path to projector binary file")
 	flag.Parse()
 
 	// load the model
 	llama.BackendInit()
 	params := llama.NewModelParams()
-	model := llama.LoadModelFromFile(*mp, params)
+	model := llama.LoadModelFromFile(*mpath, params)
 	ctxParams := llama.NewContextParams()
 	lc := llama.NewContextWithModel(model, ctxParams)
 	if lc == nil {
 		panic("Failed to create context")
 	}
 
+	var cc *llama.ClipContext
+	if ppath != nil {
+		cc = llama.NewClipContext(*ppath)
+		if cc == nil {
+			panic("Failed to create clip context")
+		}
+	}
+
 	server := &Server{
 		model: model,
 		lc:    lc,
+		cc:    cc,
 	}
 
 	addr := "127.0.0.1:8080"
@@ -121,3 +186,27 @@ func main() {
 		log.Fatal("server error:", err)
 	}
 }
+
+func (s *Server) evalText(text string, nPast *int) error {
+	// eval before
+	batch := llama.NewBatch(512, 0, 1)
+	defer batch.Free()
+
+	tokens, err := s.lc.Model().Tokenize(text, 2048, true, true)
+	if err != nil {
+		return fmt.Errorf("tokenize failed: %w", err)
+	}
+
+	// prompt eval
+	for _, t := range tokens {
+		batch.Add(t, *nPast, []int{0}, true)
+		*nPast++
+	}
+
+	err = s.lc.Decode(batch)
+	if err != nil {
+		return fmt.Errorf("decode failed: %w", err)
+	}
+
+	return nil
+}