add llava to runner

This commit is contained in:
jmorganca 2024-05-23 18:22:15 -07:00
parent 87af27dac0
commit fbc8572859
3 changed files with 137 additions and 44 deletions

View File

@ -38,10 +38,6 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
type Token int32
type Pos int32
type SeqId int32
// SystemInfo is an unused example of calling llama.cpp functions using CGo // SystemInfo is an unused example of calling llama.cpp functions using CGo
func PrintSystemInfo() string { func PrintSystemInfo() string {
return C.GoString(C.llama_print_system_info()) return C.GoString(C.llama_print_system_info())
@ -78,6 +74,10 @@ type Context struct {
c *C.struct_llama_context c *C.struct_llama_context
} }
func (c *Context) KvCacheClear() {
C.llama_kv_cache_clear(c.c)
}
func (c *Context) Decode(batch Batch) error { func (c *Context) Decode(batch Batch) error {
// Positive return values does not mean a fatal error, but rather a warning. // Positive return values does not mean a fatal error, but rather a warning.
// 0 - success // 0 - success
@ -90,18 +90,18 @@ func (c *Context) Decode(batch Batch) error {
} }
if code > 0 { if code > 0 {
return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d\n", code) return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
} }
return nil return nil
} }
func (c *Context) GetModel() *Model { func (c *Context) Model() *Model {
return &Model{c: C.llama_get_model(c.c)} return &Model{c: C.llama_get_model(c.c)}
} }
func (c *Context) SampleTokenGreedy(batch Batch) Token { func (c *Context) SampleTokenGreedy(batch Batch) int {
nv := c.GetModel().NumVocab() nv := c.Model().NumVocab()
// TODO(jmorganca): split this up into different functions // TODO(jmorganca): split this up into different functions
candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{})))) candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
@ -116,7 +116,7 @@ func (c *Context) SampleTokenGreedy(batch Batch) Token {
ptr.p = 0.0 ptr.p = 0.0
} }
return Token(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{ return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
data: candidates, data: candidates,
size: C.size_t(nv), size: C.size_t(nv),
sorted: C.bool(false), sorted: C.bool(false),
@ -135,7 +135,7 @@ func (m *Model) NumVocab() int {
return int(C.llama_n_vocab(m.c)) return int(C.llama_n_vocab(m.c))
} }
func (m *Model) TokenIsEog(token Token) bool { func (m *Model) TokenIsEog(token int) bool {
return bool(C.llama_token_is_eog(m.c, C.llama_token(token))) return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
} }
@ -151,7 +151,7 @@ func (b *Batch) NumTokens() int {
return int(b.c.n_tokens) return int(b.c.n_tokens)
} }
func (b *Batch) Add(token Token, pos Pos, seqIds []SeqId, logits bool) { func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token) unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos) unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds)) unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
@ -171,13 +171,17 @@ func (b *Batch) Clear() {
b.c.n_tokens = 0 b.c.n_tokens = 0
} }
func (b *Batch) Free() {
C.llama_batch_free(b.c)
}
// LLAMA_API struct llama_batch llama_batch_get_one( // LLAMA_API struct llama_batch llama_batch_get_one(
// //
// llama_token * tokens, // llama_token * tokens,
// int32_t n_tokens, // int32_t n_tokens,
// llama_pos pos_0, // llama_pos pos_0,
// llama_seq_id seq_id); // llama_seq_id seq_id);
func BatchGetOne(tokens []Token, pos0 Pos, seqId SeqId) Batch { func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))} return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
} }
@ -185,7 +189,7 @@ type Model struct {
c *C.struct_llama_model c *C.struct_llama_model
} }
func (m *Model) TokenToPiece(token Token) string { func (m *Model) TokenToPiece(token int) string {
buf := make([]byte, 12) buf := make([]byte, 12)
C.llama_token_to_piece( C.llama_token_to_piece(
m.c, m.c,
@ -197,7 +201,7 @@ func (m *Model) TokenToPiece(token Token) string {
return strings.TrimRight(string(buf), "\x00") return strings.TrimRight(string(buf), "\x00")
} }
func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]Token, error) { func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
cTokens := make([]C.llama_token, maxTokens) cTokens := make([]C.llama_token, maxTokens)
cText := C.CString(text) cText := C.CString(text)
defer C.free(unsafe.Pointer(cText)) defer C.free(unsafe.Pointer(cText))
@ -216,9 +220,9 @@ func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpeci
return nil, fmt.Errorf("tokenization failed, required %d tokens", -result) return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
} }
tokens := make([]Token, result) tokens := make([]int, result)
for i := 0; i < int(result); i++ { for i := 0; i < int(result); i++ {
tokens[i] = Token(cTokens[i]) tokens[i] = int(cTokens[i])
} }
return tokens, nil return tokens, nil

View File

@ -56,12 +56,12 @@ func main() {
} }
func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, after string) error { func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, after string) error {
beforeTokens, err := lc.GetModel().Tokenize(before, 2048, true, true) beforeTokens, err := lc.Model().Tokenize(before, 2048, true, true)
if err != nil { if err != nil {
return err return err
} }
afterTokens, err := lc.GetModel().Tokenize(after, 2048, true, true) afterTokens, err := lc.Model().Tokenize(after, 2048, true, true)
if err != nil { if err != nil {
return err return err
} }
@ -73,7 +73,7 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
// prompt eval // prompt eval
for _, t := range beforeTokens { for _, t := range beforeTokens {
batch.Add(t, llama.Pos(nPast), []llama.SeqId{0}, true) batch.Add(t, nPast, []int{0}, true)
nPast++ nPast++
} }
@ -88,7 +88,7 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
batch = llama.NewBatch(512, 0, 1) batch = llama.NewBatch(512, 0, 1)
for _, t := range afterTokens { for _, t := range afterTokens {
batch.Add(t, llama.Pos(nPast), []llama.SeqId{0}, true) batch.Add(t, nPast, []int{0}, true)
} }
// main loop // main loop
@ -102,15 +102,15 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
token := lc.SampleTokenGreedy(batch) token := lc.SampleTokenGreedy(batch)
// if it's an end of sequence token, break // if it's an end of sequence token, break
if lc.GetModel().TokenIsEog(token) { if lc.Model().TokenIsEog(token) {
break break
} }
// print the token // print the token
str := lc.GetModel().TokenToPiece(token) str := lc.Model().TokenToPiece(token)
fmt.Print(str) fmt.Print(str)
batch.Clear() batch.Clear()
batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true) batch.Add(token, n, []int{0}, true)
} }
return nil return nil

View File

@ -1,19 +1,24 @@
package main package main
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"log" "log"
"log/slog"
"net" "net"
"net/http" "net/http"
"regexp"
"strconv"
"sync" "sync"
"github.com/ollama/ollama/llama" "github.com/ollama/ollama/llama"
) )
type Request struct { type Request struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Images []string `json:"images"`
} }
type Response struct { type Response struct {
@ -23,6 +28,7 @@ type Response struct {
type Server struct { type Server struct {
model *llama.Model model *llama.Model
lc *llama.Context lc *llama.Context
cc *llama.ClipContext
} }
var mu sync.Mutex var mu sync.Mutex
@ -34,6 +40,9 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
return return
} }
mu.Lock()
defer mu.Unlock()
// Set the headers to indicate streaming // Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Transfer-Encoding", "chunked")
@ -41,30 +50,69 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
// main loop // create embeddings for each image
tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true) var embeddings []*llama.LlavaImageEmbed
if err != nil { if s.cc != nil {
panic(err) for _, img := range request.Images {
data, err := base64.StdEncoding.DecodeString(img)
if err != nil {
http.Error(w, "Failed to decode image", http.StatusBadRequest)
return
}
embd := llama.NewLlavaImageEmbed(s.cc, data)
embeddings = append(embeddings, embd)
}
}
var nPast int
// eval the prompt
re := regexp.MustCompile(`\[\s*img-(\d+)\s*\]`)
matches := re.FindAllStringSubmatchIndex(request.Prompt, -1)
// eval each chunk including images
pos := 0
for _, match := range matches {
part := request.Prompt[pos:match[0]]
fmt.Println("Text part:", part)
// eval text before image
err := s.evalText(part, &nPast)
if err != nil {
log.Println("Failed to eval text:", err)
return
}
// eval image
imgIndexStr := request.Prompt[match[2]:match[3]]
imgIndex, err := strconv.Atoi(imgIndexStr)
if err != nil {
slog.Warn("Failed to parse image index", "index", imgIndexStr)
continue
}
fmt.Println("Tag index:", imgIndex)
if imgIndex <= len(embeddings) {
slog.Info("evaluating image", "index", imgIndex)
llama.LlavaEvalImageEmbed(s.lc, embeddings[imgIndex], 512, &nPast)
}
pos = match[1]
}
// eval remaining text
if pos < len(request.Prompt) {
s.evalText(request.Prompt[pos:], &nPast)
} }
batch := llama.NewBatch(512, 0, 1) batch := llama.NewBatch(512, 0, 1)
defer batch.Free()
// prompt eval
for i, t := range tokens {
batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
}
// main loop // main loop
for n := batch.NumTokens(); n < 2048; n++ { for n := nPast; n < 2048; n++ {
mu.Lock()
err = s.lc.Decode(batch)
if err != nil {
panic("Failed to decode")
}
// sample a token // sample a token
token := s.lc.SampleTokenGreedy(batch) token := s.lc.SampleTokenGreedy(batch)
mu.Unlock()
// if it's an end of sequence token, break // if it's an end of sequence token, break
if s.model.TokenIsEog(token) { if s.model.TokenIsEog(token) {
@ -81,27 +129,44 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
w.(http.Flusher).Flush() w.(http.Flusher).Flush()
batch.Clear() batch.Clear()
batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true) batch.Add(token, n, []int{0}, true)
err := s.lc.Decode(batch)
if err != nil {
panic("Failed to decode")
}
} }
s.lc.KvCacheClear()
} }
func main() { func main() {
mp := flag.String("model", "", "Path to model binary file") mpath := flag.String("model", "", "Path to model binary file")
ppath := flag.String("projector", "", "Path to projector binary file")
flag.Parse() flag.Parse()
// load the model // load the model
llama.BackendInit() llama.BackendInit()
params := llama.NewModelParams() params := llama.NewModelParams()
model := llama.LoadModelFromFile(*mp, params) model := llama.LoadModelFromFile(*mpath, params)
ctxParams := llama.NewContextParams() ctxParams := llama.NewContextParams()
lc := llama.NewContextWithModel(model, ctxParams) lc := llama.NewContextWithModel(model, ctxParams)
if lc == nil { if lc == nil {
panic("Failed to create context") panic("Failed to create context")
} }
var cc *llama.ClipContext
if ppath != nil {
cc = llama.NewClipContext(*ppath)
if cc == nil {
panic("Failed to create clip context")
}
}
server := &Server{ server := &Server{
model: model, model: model,
lc: lc, lc: lc,
cc: cc,
} }
addr := "127.0.0.1:8080" addr := "127.0.0.1:8080"
@ -121,3 +186,27 @@ func main() {
log.Fatal("server error:", err) log.Fatal("server error:", err)
} }
} }
func (s *Server) evalText(text string, nPast *int) error {
// eval before
batch := llama.NewBatch(512, 0, 1)
defer batch.Free()
tokens, err := s.lc.Model().Tokenize(text, 2048, true, true)
if err != nil {
return fmt.Errorf("tokenize failed: %w", err)
}
// prompt eval
for _, t := range tokens {
batch.Add(t, *nPast, []int{0}, true)
*nPast++
}
err = s.lc.Decode(batch)
if err != nil {
return fmt.Errorf("decode failed: %w", err)
}
return nil
}