From a379d68aa9a3b2c3722f70ed6678944006e6a992 Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Mon, 27 May 2024 14:38:44 -0700
Subject: [PATCH] wip stop tokens

---
 llama/runner/runner.go | 91 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 85 insertions(+), 6 deletions(-)

diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 0401bb88..0d34febf 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -10,6 +10,7 @@ import (
 	"net"
 	"net/http"
 	"strconv"
+	"strings"
 	"sync"
 
 	"github.com/ollama/ollama/api"
@@ -31,6 +32,9 @@ type Sequence struct {
 	// channel to send back the embedding if embedding only
 	embedding chan []float32
 
+	// stop sequences
+	stop []string
+
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 }
@@ -40,7 +44,7 @@ func (s *Sequence) prompt() bool {
 	return s.nPast < len(s.tokens)-1
 }
 
-func (s *Server) NewSequence(prompt string, params *llama.SamplingParams, embedding bool) *Sequence {
+func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, 2048, false, true)
 	if err != nil {
 		panic(err)
@@ -60,6 +64,7 @@ func (s *Server) NewSequence(prompt string, params *llama.SamplingParams, embedd
 		embedding:     make(chan []float32, 1),
 		samplingCtx:   sc,
 		embeddingOnly: embedding,
+		stop:          stop,
 	}
 }
 
@@ -72,6 +77,7 @@ type Server struct {
 	parallel int
 
 	// seqs is the list of parallel sequences being evaluated
+	// TODO (jmorganca): this can probably be moved into run()
 	seqs []*Sequence
 
 	mu sync.Mutex
@@ -88,10 +94,36 @@ func (s *Server) allNil() bool {
 	return true
 }
 
+func contains(sequence string, stops []string) (bool, string) {
+	for _, stop := range stops {
+		if strings.Contains(sequence, stop) {
+			return true, stop
+		}
+	}
+
+	return false, ""
+}
+
+func overlap(sequence string, stops []string) bool {
+	for _, stop := range stops {
+		for i := 1; i < len(stop); i++ {
+			if strings.HasSuffix(sequence, stop[:i]) {
+				return true
+			}
+		}
+	}
+
+	return false
+}
+
 func (s *Server) run(ctx context.Context) {
 	batch := llama.NewBatch(512, 0, s.parallel)
 	defer batch.Free()
 
+	// build up stop sequences as we recognize them
+	// TODO (jmorganca): simplify this
+	sofar := make([][]string, s.parallel)
+
 	for {
 		select {
 		case <-ctx.Done():
@@ -165,21 +197,67 @@
 			// logits := s.lc.GetLogitsIth(ibatch[i])
 			// token := s.lc.SampleTokenGreedy(logits)
 			token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
-			seq.samplingCtx.Accept(s.lc, token, true)
-			seq.responses <- s.model.TokenToPiece(token)
-			seq.tokens = []int{token}
+			seq.samplingCtx.Accept(s.lc, token, true)
+			piece := s.model.TokenToPiece(token)
+			slog.Info("sampled", "piece", piece)
 
 			// if it's an end of sequence token, break
 			// TODO: just end this sequence
 			if s.model.TokenIsEog(token) {
 				// TODO: end the sequence instead of quitting the pool
 				s.lc.KvCacheSeqRm(i, 0, -1)
+
+				// TODO (jmorganca): we should send this back
+				// as it's important for the /api/generate context
+				// seq.responses <- piece
+
 				close(seq.responses)
 				seq.samplingCtx.Free()
+				sofar[i] = []string{}
 				s.seqs[i] = nil
 				continue
 			}
+
+			seq.tokens = []int{token}
+
+			// recognize stop sequences
+			// TODO (jmorganca): add tests around this
+			// TODO (jmorganca): send back partial piece
+
+			sequence := strings.Join(append(sofar[i], piece), "")
+			if ok, stop := contains(sequence, seq.stop); ok {
+				slog.Info("hit stop token", "stop", seq.stop)
+				for _, p := range sofar[i] {
+					seq.responses <- p
+				}
+
+				piece, _, _ := strings.Cut(piece, stop)
+				seq.responses <- piece
+
+				s.lc.KvCacheSeqRm(i, 0, -1)
+				close(seq.responses)
+				seq.samplingCtx.Free()
+				sofar[i] = []string{}
+				s.seqs[i] = nil
+				continue
+			}
+
+			if overlap(sequence, seq.stop) {
+				slog.Info("overlap", "sequence", sequence)
+				// partial stop, don't send
+				continue
+			}
+
+			slog.Info("sending", "sofar", sofar[i])
+
+			sofar[i] = append(sofar[i], piece)
+
+			for _, p := range sofar[i] {
+				seq.responses <- p
+			}
+
+			sofar[i] = []string{}
 		}
 
 		batch.Clear()
@@ -191,6 +269,7 @@ type CompletionRequest struct {
 	Prompt  string   `json:"prompt"`
 	Images  []string `json:"images"`
 	Grammar string   `json:"grammar"`
+	Stop    []string `json:"stop"`
 	api.Options
 }
 
@@ -228,7 +307,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
-	seq := s.NewSequence(req.Prompt, &samplingParams, false)
+	seq := s.NewSequence(req.Prompt, req.Stop, &samplingParams, false)
 
 	// TODO (jmorganca): add to sequence queue instead of
 	// failing if a slot isn't available
@@ -279,7 +358,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 
 	w.Header().Set("Content-Type", "application/json")
 
-	seq := s.NewSequence(req.Prompt, nil, true)
+	seq := s.NewSequence(req.Prompt, nil, nil, true)
 
 	s.mu.Lock()
 	for i, sq := range s.seqs {
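
The stop handling above is still WIP and carries a TODO about adding tests. As a rough sketch of what those tests could look like, the snippet below exercises the new contains and overlap helpers directly. It is an illustration under stated assumptions, not part of the patch: the test names, the sample stop strings ("</s>", "User:"), and the package clause are made up and would need to match runner.go.

package main

import "testing"

// Sketch: a completed stop sequence should be detected along with which
// stop string matched, while unrelated text should not match.
func TestContainsStop(t *testing.T) {
	stops := []string{"</s>", "User:"} // hypothetical stop strings

	if ok, stop := contains("hello world</s>", stops); !ok || stop != "</s>" {
		t.Fatalf("expected stop %q, got ok=%v stop=%q", "</s>", ok, stop)
	}

	if ok, _ := contains("hello wor", stops); ok {
		t.Fatal("did not expect a completed stop sequence")
	}
}

// Sketch: text ending in a prefix of a stop string ("Use" vs "User:") is a
// partial match and should be held back rather than streamed.
func TestOverlapStop(t *testing.T) {
	stops := []string{"User:"} // hypothetical stop string

	if !overlap("hello Use", stops) {
		t.Fatal("expected a partial stop sequence at the end of the text")
	}

	if overlap("hello there", stops) {
		t.Fatal("did not expect a partial stop sequence")
	}
}

In the run loop, overlap is what keeps such a partial match buffered in sofar[i] instead of being streamed to the client, and contains is what ends the sequence and truncates the final piece with strings.Cut once a stop string fully appears.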