add llava to runner

This commit is contained in:
jmorganca 2024-05-23 18:22:15 -07:00
parent 87af27dac0
commit fbc8572859
3 changed files with 137 additions and 44 deletions

View File

@ -38,10 +38,6 @@ import (
"github.com/ollama/ollama/llm"
)
type Token int32
type Pos int32
type SeqId int32
// SystemInfo is an unused example of calling llama.cpp functions using CGo
func PrintSystemInfo() string {
return C.GoString(C.llama_print_system_info())
@ -78,6 +74,10 @@ type Context struct {
c *C.struct_llama_context
}
func (c *Context) KvCacheClear() {
C.llama_kv_cache_clear(c.c)
}
func (c *Context) Decode(batch Batch) error {
// Positive return values does not mean a fatal error, but rather a warning.
// 0 - success
@ -90,18 +90,18 @@ func (c *Context) Decode(batch Batch) error {
}
if code > 0 {
return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d\n", code)
return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
}
return nil
}
func (c *Context) GetModel() *Model {
func (c *Context) Model() *Model {
return &Model{c: C.llama_get_model(c.c)}
}
func (c *Context) SampleTokenGreedy(batch Batch) Token {
nv := c.GetModel().NumVocab()
func (c *Context) SampleTokenGreedy(batch Batch) int {
nv := c.Model().NumVocab()
// TODO(jmorganca): split this up into different functions
candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
@ -116,7 +116,7 @@ func (c *Context) SampleTokenGreedy(batch Batch) Token {
ptr.p = 0.0
}
return Token(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
data: candidates,
size: C.size_t(nv),
sorted: C.bool(false),
@ -135,7 +135,7 @@ func (m *Model) NumVocab() int {
return int(C.llama_n_vocab(m.c))
}
func (m *Model) TokenIsEog(token Token) bool {
func (m *Model) TokenIsEog(token int) bool {
return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}
@ -151,7 +151,7 @@ func (b *Batch) NumTokens() int {
return int(b.c.n_tokens)
}
func (b *Batch) Add(token Token, pos Pos, seqIds []SeqId, logits bool) {
func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
@ -171,13 +171,17 @@ func (b *Batch) Clear() {
b.c.n_tokens = 0
}
func (b *Batch) Free() {
C.llama_batch_free(b.c)
}
// LLAMA_API struct llama_batch llama_batch_get_one(
//
// llama_token * tokens,
// int32_t n_tokens,
// llama_pos pos_0,
// llama_seq_id seq_id);
func BatchGetOne(tokens []Token, pos0 Pos, seqId SeqId) Batch {
func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
}
@ -185,7 +189,7 @@ type Model struct {
c *C.struct_llama_model
}
func (m *Model) TokenToPiece(token Token) string {
func (m *Model) TokenToPiece(token int) string {
buf := make([]byte, 12)
C.llama_token_to_piece(
m.c,
@ -197,7 +201,7 @@ func (m *Model) TokenToPiece(token Token) string {
return strings.TrimRight(string(buf), "\x00")
}
func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]Token, error) {
func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
cTokens := make([]C.llama_token, maxTokens)
cText := C.CString(text)
defer C.free(unsafe.Pointer(cText))
@ -216,9 +220,9 @@ func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpeci
return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
}
tokens := make([]Token, result)
tokens := make([]int, result)
for i := 0; i < int(result); i++ {
tokens[i] = Token(cTokens[i])
tokens[i] = int(cTokens[i])
}
return tokens, nil

View File

@ -56,12 +56,12 @@ func main() {
}
func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, after string) error {
beforeTokens, err := lc.GetModel().Tokenize(before, 2048, true, true)
beforeTokens, err := lc.Model().Tokenize(before, 2048, true, true)
if err != nil {
return err
}
afterTokens, err := lc.GetModel().Tokenize(after, 2048, true, true)
afterTokens, err := lc.Model().Tokenize(after, 2048, true, true)
if err != nil {
return err
}
@ -73,7 +73,7 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
// prompt eval
for _, t := range beforeTokens {
batch.Add(t, llama.Pos(nPast), []llama.SeqId{0}, true)
batch.Add(t, nPast, []int{0}, true)
nPast++
}
@ -88,7 +88,7 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
batch = llama.NewBatch(512, 0, 1)
for _, t := range afterTokens {
batch.Add(t, llama.Pos(nPast), []llama.SeqId{0}, true)
batch.Add(t, nPast, []int{0}, true)
}
// main loop
@ -102,15 +102,15 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
token := lc.SampleTokenGreedy(batch)
// if it's an end of sequence token, break
if lc.GetModel().TokenIsEog(token) {
if lc.Model().TokenIsEog(token) {
break
}
// print the token
str := lc.GetModel().TokenToPiece(token)
str := lc.Model().TokenToPiece(token)
fmt.Print(str)
batch.Clear()
batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
batch.Add(token, n, []int{0}, true)
}
return nil

View File

@ -1,19 +1,24 @@
package main
import (
"encoding/base64"
"encoding/json"
"flag"
"fmt"
"log"
"log/slog"
"net"
"net/http"
"regexp"
"strconv"
"sync"
"github.com/ollama/ollama/llama"
)
type Request struct {
Prompt string `json:"prompt"`
Prompt string `json:"prompt"`
Images []string `json:"images"`
}
type Response struct {
@ -23,6 +28,7 @@ type Response struct {
type Server struct {
model *llama.Model
lc *llama.Context
cc *llama.ClipContext
}
var mu sync.Mutex
@ -34,6 +40,9 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
return
}
mu.Lock()
defer mu.Unlock()
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
@ -41,30 +50,69 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
enc := json.NewEncoder(w)
// main loop
tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
if err != nil {
panic(err)
// create embeddings for each image
var embeddings []*llama.LlavaImageEmbed
if s.cc != nil {
for _, img := range request.Images {
data, err := base64.StdEncoding.DecodeString(img)
if err != nil {
http.Error(w, "Failed to decode image", http.StatusBadRequest)
return
}
embd := llama.NewLlavaImageEmbed(s.cc, data)
embeddings = append(embeddings, embd)
}
}
var nPast int
// eval the prompt
re := regexp.MustCompile(`\[\s*img-(\d+)\s*\]`)
matches := re.FindAllStringSubmatchIndex(request.Prompt, -1)
// eval each chunk including images
pos := 0
for _, match := range matches {
part := request.Prompt[pos:match[0]]
fmt.Println("Text part:", part)
// eval text before image
err := s.evalText(part, &nPast)
if err != nil {
log.Println("Failed to eval text:", err)
return
}
// eval image
imgIndexStr := request.Prompt[match[2]:match[3]]
imgIndex, err := strconv.Atoi(imgIndexStr)
if err != nil {
slog.Warn("Failed to parse image index", "index", imgIndexStr)
continue
}
fmt.Println("Tag index:", imgIndex)
if imgIndex <= len(embeddings) {
slog.Info("evaluating image", "index", imgIndex)
llama.LlavaEvalImageEmbed(s.lc, embeddings[imgIndex], 512, &nPast)
}
pos = match[1]
}
// eval remaining text
if pos < len(request.Prompt) {
s.evalText(request.Prompt[pos:], &nPast)
}
batch := llama.NewBatch(512, 0, 1)
// prompt eval
for i, t := range tokens {
batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
}
defer batch.Free()
// main loop
for n := batch.NumTokens(); n < 2048; n++ {
mu.Lock()
err = s.lc.Decode(batch)
if err != nil {
panic("Failed to decode")
}
for n := nPast; n < 2048; n++ {
// sample a token
token := s.lc.SampleTokenGreedy(batch)
mu.Unlock()
// if it's an end of sequence token, break
if s.model.TokenIsEog(token) {
@ -81,27 +129,44 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
w.(http.Flusher).Flush()
batch.Clear()
batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
batch.Add(token, n, []int{0}, true)
err := s.lc.Decode(batch)
if err != nil {
panic("Failed to decode")
}
}
s.lc.KvCacheClear()
}
func main() {
mp := flag.String("model", "", "Path to model binary file")
mpath := flag.String("model", "", "Path to model binary file")
ppath := flag.String("projector", "", "Path to projector binary file")
flag.Parse()
// load the model
llama.BackendInit()
params := llama.NewModelParams()
model := llama.LoadModelFromFile(*mp, params)
model := llama.LoadModelFromFile(*mpath, params)
ctxParams := llama.NewContextParams()
lc := llama.NewContextWithModel(model, ctxParams)
if lc == nil {
panic("Failed to create context")
}
var cc *llama.ClipContext
if ppath != nil {
cc = llama.NewClipContext(*ppath)
if cc == nil {
panic("Failed to create clip context")
}
}
server := &Server{
model: model,
lc: lc,
cc: cc,
}
addr := "127.0.0.1:8080"
@ -121,3 +186,27 @@ func main() {
log.Fatal("server error:", err)
}
}
func (s *Server) evalText(text string, nPast *int) error {
// eval before
batch := llama.NewBatch(512, 0, 1)
defer batch.Free()
tokens, err := s.lc.Model().Tokenize(text, 2048, true, true)
if err != nil {
return fmt.Errorf("tokenize failed: %w", err)
}
// prompt eval
for _, t := range tokens {
batch.Add(t, *nPast, []int{0}, true)
*nPast++
}
err = s.lc.Decode(batch)
if err != nil {
return fmt.Errorf("decode failed: %w", err)
}
return nil
}