diff --git a/llama/example/README.md b/llama/example/README.md
new file mode 100644
index 00000000..f1624884
--- /dev/null
+++ b/llama/example/README.md
@@ -0,0 +1,31 @@
+# `example`
+
+Demo app for the `llama` package
+
+Pull a model:
+
+```
+ollama pull mistral:7b-instruct-v0.3-q4_0
+```
+
+Then run it:
+
+```
+go run -x . \
+    -model ~/.ollama/models/blobs/sha256-ff82381e2bea77d91c1b824c7afb83f6fb73e9f7de9dda631bcdbca564aa5435 \
+    -prompt "[INST] Why is the sky blue? [/INST]"
+```
+
+## Vision
+
+```
+ollama pull llava:7b-v1.6-mistral-q4_0
+```
+
+```
+go run -x . \
+    -model ~/.ollama/models/blobs/sha256-170370233dd5c5415250a2ecd5c71586352850729062ccef1496385647293868 \
+    -projector ~/.ollama/models/blobs/sha256-72d6f08a42f656d36b356dbe0920675899a99ce21192fd66266fb7d82ed07539 \
+    -image ./alonso.jpg \
+    -prompt "[INST] <image> What is in this image? [/INST]"
+```
diff --git a/llama/llava/alonso.jpg b/llama/example/alonso.jpg
similarity index 100%
rename from llama/llava/alonso.jpg
rename to llama/example/alonso.jpg
diff --git a/llama/example/main.go b/llama/example/main.go
new file mode 100644
index 00000000..dde52440
--- /dev/null
+++ b/llama/example/main.go
@@ -0,0 +1,128 @@
+package main
+
+import (
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"os"
+	"strings"
+
+	"github.com/ollama/ollama/llama"
+)
+
+func main() {
+	mpath := flag.String("model", "", "Path to model binary file")
+	ppath := flag.String("projector", "", "Path to projector binary file")
+	image := flag.String("image", "", "Path to image file")
+	prompt := flag.String("prompt", "", "Prompt including <image> tag")
+	flag.Parse()
+
+	if *mpath == "" {
+		panic("model path is required")
+	}
+
+	if *prompt == "" {
+		panic("prompt is required")
+	}
+
+	// load the model
+	llama.BackendInit()
+	params := llama.NewModelParams()
+	model := llama.LoadModelFromFile(*mpath, params)
+	ctxParams := llama.NewContextParams()
+
+	// language model context
+	lc := llama.NewContextWithModel(model, ctxParams)
+
+	// batch for prompt and generated tokens
+	batch := llama.NewBatch(512, 0, 1)
+	var nPast int
+
+	// clip context
+	var clipCtx *llama.ClipContext
+
+	// multi-modal: if a projector is provided, embed the image
+	if *ppath != "" {
+		clipCtx = llama.NewClipContext(*ppath)
+
+		// open image file
+		file, err := os.Open(*image)
+		if err != nil {
+			panic(err)
+		}
+		defer file.Close()
+
+		data, err := io.ReadAll(file)
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		embedding := llama.NewLlavaImageEmbed(clipCtx, data)
+
+		parts := strings.Split(*prompt, "<image>")
+		if len(parts) != 2 {
+			panic("prompt must contain exactly one <image>")
+		}
+
+		beforeTokens, err := lc.Model().Tokenize(parts[0], 2048, true, true)
+		if err != nil {
+			panic(err)
+		}
+
+		for _, t := range beforeTokens {
+			batch.Add(t, nPast, []int{0}, true)
+			nPast++
+		}
+
+		err = lc.Decode(batch)
+		if err != nil {
+			panic(err)
+		}
+
+		llama.LlavaEvalImageEmbed(lc, embedding, 512, &nPast)
+
+		afterTokens, err := lc.Model().Tokenize(parts[1], 2048, true, true)
+		if err != nil {
+			panic(err)
+		}
+
+		for _, t := range afterTokens {
+			batch.Add(t, nPast, []int{0}, true)
+			nPast++
+		}
+	} else {
+		tokens, err := lc.Model().Tokenize(*prompt, 2048, true, true)
+		if err != nil {
+			panic(err)
+		}
+
+		for _, t := range tokens {
+			batch.Add(t, nPast, []int{0}, true)
+			nPast++
+		}
+	}
+
+	// main loop
+	for n := nPast; n < 4096; n++ {
+		err := lc.Decode(batch)
+		if err != nil {
+			panic(err)
+		}
+
+		// sample a token
+		logits := lc.GetLogitsIth(batch.NumTokens() - 1)
+		token := lc.SampleTokenGreedy(logits)
+
+		// if it's an end of sequence token, break
+		if lc.Model().TokenIsEog(token) {
+			break
+		}
+
+		// print the token
+		str := lc.Model().TokenToPiece(token)
+		fmt.Print(str)
+		batch.Clear()
+		batch.Add(token, n, []int{0}, true)
+	}
+}
diff --git a/llama/llama.go b/llama/llama.go
index 421cff0c..27cf7516 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -99,26 +99,24 @@ func (c *Context) Model() *Model {
 	return &Model{c: C.llama_get_model(c.c)}
 }
 
-// TODO: break this up
-func (c *Context) SampleTokenGreedy(batch Batch, i int) int {
-	nv := c.Model().NumVocab()
+func (c *Context) GetLogitsIth(i int) []float32 {
+	return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i)))), c.Model().NumVocab())
+}
 
-	// TODO(jmorganca): split this up into different functions
-	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
+func (c *Context) SampleTokenGreedy(logits []float32) int {
+	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(len(logits)) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
 	defer C.free(unsafe.Pointer(candidates))
 
-	// get most recent logits
-	logits := C.llama_get_logits_ith(c.c, C.int(i))
-	for i := 0; i < int(nv); i++ {
+	for i, logit := range logits {
 		ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + uintptr(i)*unsafe.Sizeof(C.struct_llama_token_data{})))
 		ptr.id = C.int(i)
-		ptr.logit = unsafe.Slice(logits, nv)[i]
+		ptr.logit = C.float(logit)
 		ptr.p = 0.0
 	}
 
 	return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
 		data:   candidates,
-		size:   C.size_t(nv),
+		size:   C.size_t(len(logits)),
 		sorted: C.bool(false),
 	}))
 }
@@ -155,6 +153,8 @@ func (b *Batch) NumTokens() int {
 	return int(b.c.n_tokens)
 }
 
+// Add adds a token to the batch with the given position for the given
+// sequence ids, optionally requesting logits to be computed for it.
 func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
 	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
 	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
@@ -179,12 +179,6 @@ func (b *Batch) Free() {
 	C.llama_batch_free(b.c)
 }
 
-// LLAMA_API struct llama_batch llama_batch_get_one(
-//
-//	llama_token * tokens,
-//	int32_t n_tokens,
-//	llama_pos pos_0,
-//	llama_seq_id seq_id);
 func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
 	return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
 }
diff --git a/llama/llava/README.md b/llama/llava/README.md
deleted file mode 100644
index 979f3869..00000000
--- a/llama/llava/README.md
+++ /dev/null
@@ -1,14 +0,0 @@
-# `llava`
-
-Demo app for running Llava and other clip-based vision models.
-
-```
-ollama pull llava
-```
-
-```
-go run -x . \
-    -model ~/.ollama/models/blobs/sha256-170370233dd5c5415250a2ecd5c71586352850729062ccef1496385647293868 \
-    -projector ~/.ollama/models/blobs/sha256-72d6f08a42f656d36b356dbe0920675899a99ce21192fd66266fb7d82ed07539 \
-    -image ./alonso.jpg
-```
diff --git a/llama/llava/main.go b/llama/llava/main.go
deleted file mode 100644
index 60331490..00000000
--- a/llama/llava/main.go
+++ /dev/null
@@ -1,117 +0,0 @@
-package main
-
-import (
-	"flag"
-	"fmt"
-	"io"
-	"log"
-	"os"
-	"strings"
-
-	"github.com/ollama/ollama/llama"
-)
-
-func main() {
-	mp := flag.String("model", "", "Path to model binary file")
-	pp := flag.String("projector", "", "Path to projector binary file")
-	image := flag.String("image", "", "Path to image file")
-	prompt := flag.String("prompt", " [INST] What is in the picture? <image> [/INST]", "Prompt including <image> tag")
[/INST]", "Prompt including tag") - flag.Parse() - - // load the model - llama.BackendInit() - params := llama.NewModelParams() - model := llama.LoadModelFromFile(*mp, params) - ctxParams := llama.NewContextParams() - - // language model context - lc := llama.NewContextWithModel(model, ctxParams) - - // clip context - clipCtx := llama.NewClipContext(*pp) - - // open image file - file, err := os.Open(*image) - if err != nil { - panic(err) - } - defer file.Close() - - data, err := io.ReadAll(file) - if err != nil { - log.Fatal(err) - } - - embedding := llama.NewLlavaImageEmbed(clipCtx, data) - - parts := strings.Split(*prompt, "") - if len(parts) != 2 { - panic("prompt must contain exactly one ") - } - - err = eval(lc, parts[0], embedding, parts[1]) - if err != nil { - panic(err) - } -} - -func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, after string) error { - beforeTokens, err := lc.Model().Tokenize(before, 2048, true, true) - if err != nil { - return err - } - - afterTokens, err := lc.Model().Tokenize(after, 2048, true, true) - if err != nil { - return err - } - - // eval before - batch := llama.NewBatch(512, 0, 1) - - var nPast int - - // prompt eval - for _, t := range beforeTokens { - batch.Add(t, nPast, []int{0}, true) - nPast++ - } - - err = lc.Decode(batch) - if err != nil { - return err - } - - // batch.Clear() - - llama.LlavaEvalImageEmbed(lc, embedding, 512, &nPast) - - batch = llama.NewBatch(512, 0, 1) - for _, t := range afterTokens { - batch.Add(t, nPast, []int{0}, true) - } - - // main loop - for n := nPast; n < 4096; n++ { - err = lc.Decode(batch) - if err != nil { - panic("Failed to decode") - } - - // sample a token - token := lc.SampleTokenGreedy(batch) - - // if it's an end of sequence token, break - if lc.Model().TokenIsEog(token) { - break - } - - // print the token - str := lc.Model().TokenToPiece(token) - fmt.Print(str) - batch.Clear() - batch.Add(token, n, []int{0}, true) - } - - return nil -} diff --git a/llama/runner/README.md b/llama/runner/README.md index 7695dc1b..e7cf51c0 100644 --- a/llama/runner/README.md +++ b/llama/runner/README.md @@ -1,5 +1,7 @@ # `runner` +A subprocess runner for loading a model and running inference via a small http web server. 
+
 ```
 ./runner -model <model binary>
 ```
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 85d03f0f..e75ec671 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -9,8 +9,10 @@ import (
 	"log/slog"
 	"net"
 	"net/http"
+	"strconv"
 	"sync"
 
+	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
 )
 
@@ -131,7 +133,8 @@ func (s *Server) run(ctx context.Context) {
 			// sample a token
 			// TODO: sample based on the sequence
 			fmt.Println("Sampling token", i, ibatch[i])
-			token := s.lc.SampleTokenGreedy(batch, ibatch[i])
+			logits := s.lc.GetLogitsIth(ibatch[i])
+			token := s.lc.SampleTokenGreedy(logits)
 
 			// if it's an end of sequence token, break
 			// TODO: just end this sequence
@@ -155,6 +158,8 @@
 type Request struct {
 	Prompt string   `json:"prompt"`
 	Images []string `json:"images"`
+
+	api.Options
 }
 
 type Response struct {
@@ -208,6 +213,7 @@ func main() {
 	mpath := flag.String("model", "", "Path to model binary file")
 	ppath := flag.String("projector", "", "Path to projector binary file")
 	parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
+	port := flag.Int("port", 8080, "Port to expose the server on")
 	flag.Parse()
 
 	// load the model
@@ -241,7 +247,7 @@ func main() {
 	ctx, cancel := context.WithCancel(context.Background())
 	go server.run(ctx)
 
-	addr := "127.0.0.1:8080"
+	addr := "127.0.0.1:" + strconv.Itoa(*port)
 	listener, err := net.Listen("tcp", addr)
 	if err != nil {
 		fmt.Println("Listen error:", err)