runner.go: Shift context window when KV cache space is exceeded

Currently, once the KV cache is full, text generation stops. Instead,
we should shift out the oldest context so that new generation can
continue based on more recent context.

This uses the same algorithm from llama.cpp that Ollama currently uses
via the server.cpp code. Other strategies exist there, but they are
never enabled through Ollama, so this restores parity.

The algorithm is:
 - Retain a configurable number of tokens at the beginning (for things
   like beginning of sequence tokens)
 - Drop the oldest half of the remaining tokens
 - Shift the remaining newer tokens down so they sit right after the
   retained prefix, leaving free space at the end of the cache (see the
   sketch after this list)
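A minimal sketch of that shift as a standalone helper (illustrative only, not code from this commit; it assumes the Context bindings added below and the repository's github.com/ollama/ollama/llama import path):

package main

import "github.com/ollama/ollama/llama"

// shift keeps the first numKeep tokens, drops the oldest half of the
// rest, and slides the surviving newer tokens down so that generation
// can continue in the freed space. It returns the new cache length.
func shift(lc *llama.Context, seqId, nPast, numKeep int) int {
    numLeft := nPast - numKeep // tokens eligible for eviction
    numDiscard := numLeft / 2  // discard the oldest half of those

    // Evict positions [numKeep, numKeep+numDiscard) for this sequence...
    lc.KvCacheSeqRm(seqId, numKeep, numKeep+numDiscard)
    // ...then move the remaining positions down by numDiscard.
    lc.KvCacheSeqAdd(seqId, numKeep+numDiscard, nPast, -numDiscard)

    return nPast - numDiscard
}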
Authored by Jesse Gross on 2024-08-14 10:35:49 -07:00; committed by jmorganca
parent 5a441d227a
commit 69cc5795a7
2 changed files with 79 additions and 15 deletions


@@ -157,6 +157,10 @@ func (c *Context) SampleTokenGreedy(logits []float32) int {
}))
}
func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}
func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}
@@ -191,6 +195,16 @@ func (m *Model) TokenIsEog(token int) bool {
return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}
func (m *Model) ShouldAddBOSToken() bool {
addBos := int(C.llama_add_bos_token(m.c))
if addBos != -1 {
return addBos != 0
} else {
return C.llama_vocab_type(m.c) == C.LLAMA_VOCAB_TYPE_SPM
}
}
func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath string, threads int) error {
cLoraPath := C.CString(loraPath)
defer C.free(unsafe.Pointer(cLoraPath))

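The two cache bindings added above map directly onto llama.cpp primitives: llama_kv_cache_seq_rm evicts the cached entries for a range of positions in one sequence, and llama_kv_cache_seq_add applies a delta to the positions in a range, so a negative delta slides the surviving entries down into the space freed by the removal. ShouldAddBOSToken handles the case where llama_add_bos_token returns -1 (the model does not say), falling back to assuming a BOS token for SPM-style vocabularies.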

@@ -49,6 +49,9 @@ type Sequence struct {
// stop sequences
stop []string
// number of tokens to keep at the beginning when shifting context window
numKeep int
// true if an embedding is to be returned instead of text generation
embeddingOnly bool
@@ -61,22 +64,38 @@ type Sequence struct {
n_prompt_tokens int
}
func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
type NewSequenceParams struct {
numPredict int
stop []string
numKeep int
samplingParams *llama.SamplingParams
embedding bool
}
func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence {
tokens, err := s.lc.Model().Tokenize(prompt, true, true)
if err != nil {
panic(err)
}
// truncate to last n tokens
// TODO: this shouldn't happen and will severely impact generation
// quality. instead we should ensure to cut prompt in the API.
if params.numKeep < 0 {
params.numKeep = len(tokens)
}
// Subtracting 4 ensures that at least 1 token can be discarded during shift
params.numKeep = min(params.numKeep, s.numCtx-4)
params.numKeep += s.bosToken
// truncate to fit in context window
if len(tokens) > s.numCtx {
tokens = tokens[:s.numCtx]
slog.Warn("truncating input prompt", "limit", s.numCtx, "prompt", len(tokens), "numKeep", params.numKeep)
newTokens := tokens[:params.numKeep]
newTokens = append(newTokens, tokens[len(tokens)-s.numCtx+params.numKeep:]...)
tokens = newTokens
}
var sc *llama.SamplingContext
if params != nil {
sc = llama.NewSamplingContext(*params)
if params.samplingParams != nil {
sc = llama.NewSamplingContext(*params.samplingParams)
for _, t := range tokens {
sc.Accept(s.lc, t, false)
}
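Note that prompt truncation now protects the same prefix the shift does: a prompt longer than numCtx is reduced to the first numKeep tokens plus the newest numCtx-numKeep tokens, rather than simply keeping the first numCtx tokens as before. As a hypothetical example, with numCtx = 8 and numKeep = 2, a 12-token prompt keeps tokens 0-1 and 6-11.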
@@ -85,12 +104,13 @@ func (s *Server) NewSequence(prompt string, numPredict int, stop []string, param
return &Sequence{
tokens: tokens,
n_prompt_tokens: len(tokens),
numPredict: numPredict,
numPredict: params.numPredict,
responses: make(chan string, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: embedding,
stop: stop,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
}
}
@@ -111,6 +131,9 @@ type Server struct {
// context window size
numCtx int
// does this model require a beginning of sequence token?
bosToken int
mu sync.Mutex
cond *sync.Cond
@@ -129,6 +152,21 @@ func (s *Server) allNil() bool {
return true
}
func (s *Server) shiftContext(seqIndex int) {
seq := s.seqs[seqIndex]
numLeft := seq.nPast - seq.numKeep
numDiscard := numLeft / 2
slog.Debug("context limit hit - shifting", "limit", s.numCtx, "nPast", seq.nPast,
"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
s.lc.KvCacheSeqRm(seqIndex, seq.numKeep, seq.numKeep+numDiscard)
s.lc.KvCacheSeqAdd(seqIndex, seq.numKeep+numDiscard, seq.nPast, -numDiscard)
seq.nPast -= numDiscard
}
func (s *Server) run(ctx context.Context) {
// TODO - should this be n_ctx / parallel like the old server.cpp setup?
batch := llama.NewBatch(s.batchSize, 0, s.parallel)
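As a concrete illustration of shiftContext with hypothetical numbers: with numCtx = 2048 and numKeep = 5, hitting the limit at nPast = 2048 gives numLeft = 2043 and numDiscard = 1021, so the cache shrinks to 1027 tokens and decoding continues with roughly half the window free.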
@@ -155,10 +193,8 @@ func (s *Server) run(ctx context.Context) {
continue
}
hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict
// if past the num predict limit
if hitLimit || seq.nPast > s.numCtx {
if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
seq.doneReason = "limit"
close(seq.responses)
s.lc.KvCacheSeqRm(i, 0, -1)
@@ -166,6 +202,10 @@ func (s *Server) run(ctx context.Context) {
continue
}
if seq.nPast+len(seq.tokens) > s.numCtx {
s.shiftContext(i)
}
if seq.t_start_process_prompt.IsZero() {
seq.t_start_process_prompt = time.Now()
}
@@ -350,7 +390,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar
seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false)
seq := s.NewSequence(req.Prompt, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: req.NumKeep,
samplingParams: &samplingParams,
embedding: false,
})
// TODO (jmorganca): add to sequence queue instead of
// failing if a slot isn't available
@@ -428,7 +474,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
embeddings := make([][]float32, len(req.Content))
var processed int
for i, content := range req.Content {
seqs[i] = s.NewSequence(content, 0, nil, nil, true)
seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
}
// TODO - refactor to go routines to add seq's and drain the responses
@@ -563,6 +609,10 @@ func main() {
ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
server.lc = llama.NewContextWithModel(server.model, ctxParams)
if server.model.ShouldAddBOSToken() {
server.bosToken = 1
}
if *ppath != "" {
server.cc = llama.NewClipContext(*ppath)
}
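Because NewSequence adds server.bosToken to params.numKeep, a model that expects a beginning of sequence token keeps it inside the protected prefix, so it survives both the initial prompt truncation and every subsequent context shift.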