add llava to runner

parent 87af27dac0
commit fbc8572859
@@ -38,10 +38,6 @@ import (
	"github.com/ollama/ollama/llm"
)

type Token int32
type Pos int32
type SeqId int32

// SystemInfo is an unused example of calling llama.cpp functions using CGo
func PrintSystemInfo() string {
	return C.GoString(C.llama_print_system_info())
@@ -78,6 +74,10 @@ type Context struct {
	c *C.struct_llama_context
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) Decode(batch Batch) error {
	// Positive return values does not mean a fatal error, but rather a warning.
	// 0 - success
@@ -90,18 +90,18 @@ func (c *Context) Decode(batch Batch) error {
	}

	if code > 0 {
		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d\n", code)
		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
	}

	return nil
}

func (c *Context) GetModel() *Model {
func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) SampleTokenGreedy(batch Batch) Token {
	nv := c.GetModel().NumVocab()
func (c *Context) SampleTokenGreedy(batch Batch) int {
	nv := c.Model().NumVocab()

	// TODO(jmorganca): split this up into different functions
	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
@@ -116,7 +116,7 @@ func (c *Context) SampleTokenGreedy(batch Batch) Token {
		ptr.p = 0.0
	}

	return Token(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
	return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
		data: candidates,
		size: C.size_t(nv),
		sorted: C.bool(false),
@@ -135,7 +135,7 @@ func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}

func (m *Model) TokenIsEog(token Token) bool {
func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}

@@ -151,7 +151,7 @@ func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) Add(token Token, pos Pos, seqIds []SeqId, logits bool) {
func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
@@ -171,13 +171,17 @@ func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	C.llama_batch_free(b.c)
}

// LLAMA_API struct llama_batch llama_batch_get_one(
//
// llama_token * tokens,
// int32_t n_tokens,
// llama_pos pos_0,
// llama_seq_id seq_id);
func BatchGetOne(tokens []Token, pos0 Pos, seqId SeqId) Batch {
func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
	return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
}

@@ -185,7 +189,7 @@ type Model struct {
	c *C.struct_llama_model
}

func (m *Model) TokenToPiece(token Token) string {
func (m *Model) TokenToPiece(token int) string {
	buf := make([]byte, 12)
	C.llama_token_to_piece(
		m.c,
@@ -197,7 +201,7 @@ func (m *Model) TokenToPiece(token Token) string {
	return strings.TrimRight(string(buf), "\x00")
}

func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]Token, error) {
func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))
@@ -216,9 +220,9 @@ func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpeci
		return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
	}

	tokens := make([]Token, result)
	tokens := make([]int, result)
	for i := 0; i < int(result); i++ {
		tokens[i] = Token(cTokens[i])
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
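For orientation, here is a minimal sketch of how the plain-int bindings above are meant to be driven once Token, Pos, and SeqId are gone. It uses only calls that appear in this diff (Tokenize, NewBatch, Add, Decode, SampleTokenGreedy, TokenIsEog, TokenToPiece); the function name, prompt text, and 512 batch capacity are illustrative, and it assumes a package that imports "fmt" and "github.com/ollama/ollama/llama" with a loaded *llama.Context.

// Sketch under the assumptions above; error handling kept minimal.
func generateOne(lc *llama.Context, prompt string) error {
	tokens, err := lc.Model().Tokenize(prompt, 2048, true, true)
	if err != nil {
		return err
	}

	batch := llama.NewBatch(512, 0, 1) // capacity must cover every Add below
	defer batch.Free()

	for i, t := range tokens {
		batch.Add(t, i, []int{0}, true) // token, position, sequence ids, logits
	}

	if err := lc.Decode(batch); err != nil {
		return err
	}

	tok := lc.SampleTokenGreedy(batch)
	if !lc.Model().TokenIsEog(tok) {
		fmt.Print(lc.Model().TokenToPiece(tok))
	}

	return nil
}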
@@ -56,12 +56,12 @@ func main() {
}

func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, after string) error {
	beforeTokens, err := lc.GetModel().Tokenize(before, 2048, true, true)
	beforeTokens, err := lc.Model().Tokenize(before, 2048, true, true)
	if err != nil {
		return err
	}

	afterTokens, err := lc.GetModel().Tokenize(after, 2048, true, true)
	afterTokens, err := lc.Model().Tokenize(after, 2048, true, true)
	if err != nil {
		return err
	}
@@ -73,7 +73,7 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af

	// prompt eval
	for _, t := range beforeTokens {
		batch.Add(t, llama.Pos(nPast), []llama.SeqId{0}, true)
		batch.Add(t, nPast, []int{0}, true)
		nPast++
	}

@@ -88,7 +88,7 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af

	batch = llama.NewBatch(512, 0, 1)
	for _, t := range afterTokens {
		batch.Add(t, llama.Pos(nPast), []llama.SeqId{0}, true)
		batch.Add(t, nPast, []int{0}, true)
	}

	// main loop
@@ -102,15 +102,15 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
		token := lc.SampleTokenGreedy(batch)

		// if it's an end of sequence token, break
		if lc.GetModel().TokenIsEog(token) {
		if lc.Model().TokenIsEog(token) {
			break
		}

		// print the token
		str := lc.GetModel().TokenToPiece(token)
		str := lc.Model().TokenToPiece(token)
		fmt.Print(str)
		batch.Clear()
		batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
		batch.Add(token, n, []int{0}, true)
	}

	return nil
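The runner below generalizes this eval() pattern to any number of [img-N] tags in a prompt. The interleaving step itself stays the same; a condensed sketch under assumed names (lc, cc, imgData, and an evalText helper standing in for the (*Server) evalText method defined at the end of the runner changes), with errors elided:

// Condensed sketch of the text / image / text interleaving used below.
// evalText is assumed to tokenize, batch, and Decode like (*Server).evalText.
func describeImage(lc *llama.Context, cc *llama.ClipContext, imgData []byte) {
	nPast := 0

	// text before the image tag
	evalText(lc, "describe this image: ", &nPast)

	// clip-encode the image bytes, then decode the embedding, advancing nPast
	embd := llama.NewLlavaImageEmbed(cc, imgData)
	llama.LlavaEvalImageEmbed(lc, embd, 512, &nPast)

	// text after the image tag
	evalText(lc, " in one sentence.", &nPast)
}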
@@ -1,19 +1,24 @@
package main

import (
	"encoding/base64"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"log/slog"
	"net"
	"net/http"
	"regexp"
	"strconv"
	"sync"

	"github.com/ollama/ollama/llama"
)

type Request struct {
	Prompt string `json:"prompt"`
	Prompt string   `json:"prompt"`
	Images []string `json:"images"`
}

type Response struct {
@@ -23,6 +28,7 @@ type Response struct {
type Server struct {
	model *llama.Model
	lc *llama.Context
	cc *llama.ClipContext
}

var mu sync.Mutex
@@ -34,6 +40,9 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
		return
	}

	mu.Lock()
	defer mu.Unlock()

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")
@@ -41,30 +50,69 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {

	enc := json.NewEncoder(w)

	// main loop
	tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
	if err != nil {
		panic(err)
	// create embeddings for each image
	var embeddings []*llama.LlavaImageEmbed
	if s.cc != nil {
		for _, img := range request.Images {
			data, err := base64.StdEncoding.DecodeString(img)
			if err != nil {
				http.Error(w, "Failed to decode image", http.StatusBadRequest)
				return
			}

			embd := llama.NewLlavaImageEmbed(s.cc, data)
			embeddings = append(embeddings, embd)
		}
	}

	var nPast int

	// eval the prompt
	re := regexp.MustCompile(`\[\s*img-(\d+)\s*\]`)
	matches := re.FindAllStringSubmatchIndex(request.Prompt, -1)

	// eval each chunk including images
	pos := 0
	for _, match := range matches {
		part := request.Prompt[pos:match[0]]
		fmt.Println("Text part:", part)

		// eval text before image
		err := s.evalText(part, &nPast)
		if err != nil {
			log.Println("Failed to eval text:", err)
			return
		}

		// eval image
		imgIndexStr := request.Prompt[match[2]:match[3]]
		imgIndex, err := strconv.Atoi(imgIndexStr)
		if err != nil {
			slog.Warn("Failed to parse image index", "index", imgIndexStr)
			continue
		}

		fmt.Println("Tag index:", imgIndex)
		if imgIndex <= len(embeddings) {
			slog.Info("evaluating image", "index", imgIndex)
			llama.LlavaEvalImageEmbed(s.lc, embeddings[imgIndex], 512, &nPast)
		}

		pos = match[1]
	}

	// eval remaining text
	if pos < len(request.Prompt) {
		s.evalText(request.Prompt[pos:], &nPast)
	}

	batch := llama.NewBatch(512, 0, 1)

	// prompt eval
	for i, t := range tokens {
		batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
	}
	defer batch.Free()

	// main loop
	for n := batch.NumTokens(); n < 2048; n++ {
		mu.Lock()
		err = s.lc.Decode(batch)
		if err != nil {
			panic("Failed to decode")
		}

	for n := nPast; n < 2048; n++ {
		// sample a token
		token := s.lc.SampleTokenGreedy(batch)
		mu.Unlock()

		// if it's an end of sequence token, break
		if s.model.TokenIsEog(token) {
@@ -81,27 +129,44 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
		w.(http.Flusher).Flush()

		batch.Clear()
		batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
		batch.Add(token, n, []int{0}, true)

		err := s.lc.Decode(batch)
		if err != nil {
			panic("Failed to decode")
		}
	}

	s.lc.KvCacheClear()
}

func main() {
	mp := flag.String("model", "", "Path to model binary file")
	mpath := flag.String("model", "", "Path to model binary file")
	ppath := flag.String("projector", "", "Path to projector binary file")
	flag.Parse()

	// load the model
	llama.BackendInit()
	params := llama.NewModelParams()
	model := llama.LoadModelFromFile(*mp, params)
	model := llama.LoadModelFromFile(*mpath, params)
	ctxParams := llama.NewContextParams()
	lc := llama.NewContextWithModel(model, ctxParams)
	if lc == nil {
		panic("Failed to create context")
	}

	var cc *llama.ClipContext
	if ppath != nil {
		cc = llama.NewClipContext(*ppath)
		if cc == nil {
			panic("Failed to create clip context")
		}
	}

	server := &Server{
		model: model,
		lc: lc,
		cc: cc,
	}

	addr := "127.0.0.1:8080"
@@ -121,3 +186,27 @@ func main() {
		log.Fatal("server error:", err)
	}
}

func (s *Server) evalText(text string, nPast *int) error {
	// eval before
	batch := llama.NewBatch(512, 0, 1)
	defer batch.Free()

	tokens, err := s.lc.Model().Tokenize(text, 2048, true, true)
	if err != nil {
		return fmt.Errorf("tokenize failed: %w", err)
	}

	// prompt eval
	for _, t := range tokens {
		batch.Add(t, *nPast, []int{0}, true)
		*nPast++
	}

	err = s.lc.Decode(batch)
	if err != nil {
		return fmt.Errorf("decode failed: %w", err)
	}

	return nil
}
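For reference, a sketch of what a client call against this runner could look like. The Request fields, the [img-N] tag syntax, and the 127.0.0.1:8080 address come from the diff above; the route path is not visible in these hunks, so "/" is an assumption, and the streamed body is simply copied to stdout because the Response fields are not shown here either.

package main

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"io"
	"net/http"
	"os"
)

func main() {
	// hypothetical image file; any image the projector can handle
	img, err := os.ReadFile("example.jpg")
	if err != nil {
		panic(err)
	}

	body, _ := json.Marshal(map[string]any{
		"prompt": "What is in [img-0]?", // tag matched by the runner's regexp
		"images": []string{base64.StdEncoding.EncodeToString(img)},
	})

	// "/" is assumed; the handler registration is outside the hunks shown above.
	resp, err := http.Post("http://127.0.0.1:8080/", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// the server streams chunked JSON; pass it through unchanged
	io.Copy(os.Stdout, resp.Body)
}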