runner.go: Shift context window when KV cache space is exceeded
Currently, once the KV cache is full, text generation stops. Instead, we should shift out the oldest context so that new generation can continue based on more recent context.

This uses the algorithm from llama.cpp that is currently used by Ollama with the server.cpp code. llama.cpp has other strategies, but they are never enabled through Ollama, so this restores parity. The algorithm is:
- Retain a configurable number of tokens at the beginning (for things like beginning of sequence tokens)
- Drop the oldest half of the remaining tokens
- Shift the remaining, newest tokens back to fill the space that was freed
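
A minimal, self-contained sketch of that policy applied to a plain Go slice. This is illustrative only: the real change operates on KV cache positions, and the names and numbers below are not from this commit.

    package main

    import "fmt"

    // shift keeps the first numKeep tokens, drops the oldest half of the
    // rest, and retains the most recent tokens (toy model of the policy).
    func shift(tokens []int, numKeep int) []int {
        numLeft := len(tokens) - numKeep
        numDiscard := numLeft / 2

        out := append([]int{}, tokens[:numKeep]...)
        return append(out, tokens[numKeep+numDiscard:]...)
    }

    func main() {
        // illustrative token IDs and numKeep, not values from this commit
        tokens := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
        fmt.Println(shift(tokens, 2)) // [0 1 6 7 8 9]
    }

Running it prints the two retained prefix tokens followed by the newest four; the oldest half of the shiftable region is gone.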
parent 5a441d227a
commit 69cc5795a7
@@ -157,6 +157,10 @@ func (c *Context) SampleTokenGreedy(logits []float32) int {
 	}))
 }
 
+func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
+	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
+}
+
 func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
 	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
 }
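
A note on the two wrappers above: in llama.cpp, llama_kv_cache_seq_rm drops a sequence's entries whose positions fall in the half-open range [p0, p1), and llama_kv_cache_seq_add adds delta to the positions in [p0, p1). The following is a toy, self-contained illustration of those semantics on a fake cache; the entry type and helper functions are hypothetical and not part of this change.

    package main

    import "fmt"

    // entry models one cached token at a given sequence position (toy model).
    type entry struct{ pos, token int }

    // seqRm drops entries whose position is in [p0, p1).
    func seqRm(cache []entry, p0, p1 int) []entry {
        out := cache[:0]
        for _, e := range cache {
            if e.pos < p0 || e.pos >= p1 {
                out = append(out, e)
            }
        }
        return out
    }

    // seqAdd adds delta to every position in [p0, p1).
    func seqAdd(cache []entry, p0, p1, delta int) {
        for i := range cache {
            if cache[i].pos >= p0 && cache[i].pos < p1 {
                cache[i].pos += delta
            }
        }
    }

    func main() {
        // assumed example cache: positions 0..5 holding arbitrary token IDs
        cache := []entry{{0, 100}, {1, 101}, {2, 102}, {3, 103}, {4, 104}, {5, 105}}
        cache = seqRm(cache, 1, 3) // drop positions 1 and 2
        seqAdd(cache, 3, 6, -2)    // slide positions 3..5 back to 1..3
        fmt.Println(cache)         // [{0 100} {1 103} {2 104} {3 105}]
    }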
@@ -191,6 +195,16 @@ func (m *Model) TokenIsEog(token int) bool {
 	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
 }
 
+func (m *Model) ShouldAddBOSToken() bool {
+	addBos := int(C.llama_add_bos_token(m.c))
+
+	if addBos != -1 {
+		return addBos != 0
+	} else {
+		return C.llama_vocab_type(m.c) == C.LLAMA_VOCAB_TYPE_SPM
+	}
+}
+
 func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath string, threads int) error {
 	cLoraPath := C.CString(loraPath)
 	defer C.free(unsafe.Pointer(cLoraPath))
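
ShouldAddBOSToken mirrors llama.cpp's fallback: llama_add_bos_token reports 1 or 0 when the model metadata specifies the flag and -1 when it is unknown, in which case SPM-style vocabularies default to adding BOS. A small standalone sketch of that decision; the helper name and inputs are assumptions for illustration.

    package main

    import "fmt"

    // shouldAddBOS: use the metadata flag when it is known (0 or 1),
    // otherwise fall back to "SPM vocabularies add BOS".
    func shouldAddBOS(addBos int, vocabIsSPM bool) bool {
        if addBos != -1 {
            return addBos != 0
        }
        return vocabIsSPM
    }

    func main() {
        // assumed example inputs
        fmt.Println(shouldAddBOS(-1, true)) // metadata unknown, SPM vocab -> true
        fmt.Println(shouldAddBOS(0, true))  // metadata says "no BOS"      -> false
    }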
@@ -49,6 +49,9 @@ type Sequence struct {
 	// stop sequences
 	stop []string
 
+	// number of tokens to keep at the beginning when shifting context window
+	numKeep int
+
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
@@ -61,22 +64,38 @@ type Sequence struct {
 	n_prompt_tokens int
 }
 
-func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
+type NewSequenceParams struct {
+	numPredict     int
+	stop           []string
+	numKeep        int
+	samplingParams *llama.SamplingParams
+	embedding      bool
+}
+
+func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
 	if err != nil {
 		panic(err)
 	}
 
-	// truncate to last n tokens
-	// TODO: this shouldn't happen and will severely impact generation
-	// quality. instead we should ensure to cut prompt in the API.
+	if params.numKeep < 0 {
+		params.numKeep = len(tokens)
+	}
+	// Subtracting 4 ensures that at least 1 token can be discarded during shift
+	params.numKeep = min(params.numKeep, s.numCtx-4)
+	params.numKeep += s.bosToken
+
+	// truncate to fit in context window
 	if len(tokens) > s.numCtx {
-		tokens = tokens[:s.numCtx]
+		slog.Warn("truncating input prompt", "limit", s.numCtx, "prompt", len(tokens), "numKeep", params.numKeep)
+		newTokens := tokens[:params.numKeep]
+		newTokens = append(newTokens, tokens[len(tokens)-s.numCtx+params.numKeep:]...)
+		tokens = newTokens
 	}
 
 	var sc *llama.SamplingContext
-	if params != nil {
-		sc = llama.NewSamplingContext(*params)
+	if params.samplingParams != nil {
+		sc = llama.NewSamplingContext(*params.samplingParams)
 		for _, t := range tokens {
 			sc.Accept(s.lc, t, false)
 		}
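
In the truncation branch above, numKeep is first clamped to numCtx-4 (so at least one token can be discarded on a later shift) and bumped by bosToken (so the BOS token is never shifted out); an over-long prompt then keeps its first numKeep tokens plus the newest tokens from its tail. A worked example of that slice math with assumed numbers (numCtx=8, numKeep=2):

    package main

    import "fmt"

    func main() {
        numCtx, numKeep := 8, 2                                // assumed values
        tokens := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} // 12 tokens, too long

        if len(tokens) > numCtx {
            // keep the first numKeep tokens plus the newest numCtx-numKeep tokens
            newTokens := append([]int{}, tokens[:numKeep]...)
            newTokens = append(newTokens, tokens[len(tokens)-numCtx+numKeep:]...)
            tokens = newTokens
        }
        fmt.Println(tokens) // [0 1 6 7 8 9 10 11] -- exactly numCtx tokens
    }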
@@ -85,12 +104,13 @@ func (s *Server) NewSequence(prompt string, numPredict int, stop []string, param
 	return &Sequence{
 		tokens:          tokens,
 		n_prompt_tokens: len(tokens),
-		numPredict:      numPredict,
+		numPredict:      params.numPredict,
 		responses:       make(chan string, 1),
 		embedding:       make(chan []float32, 1),
 		samplingCtx:     sc,
-		embeddingOnly:   embedding,
-		stop:            stop,
+		embeddingOnly:   params.embedding,
+		stop:            params.stop,
+		numKeep:         params.numKeep,
 	}
 }
 
@@ -111,6 +131,9 @@ type Server struct {
 	// context window size
 	numCtx int
 
+	// does this model require a beginning of sequence token?
+	bosToken int
+
 	mu sync.Mutex
 
 	cond *sync.Cond
@@ -129,6 +152,21 @@ func (s *Server) allNil() bool {
 	return true
 }
 
+func (s *Server) shiftContext(seqIndex int) {
+	seq := s.seqs[seqIndex]
+
+	numLeft := seq.nPast - seq.numKeep
+	numDiscard := numLeft / 2
+
+	slog.Debug("context limit hit - shifting", "limit", s.numCtx, "nPast", seq.nPast,
+		"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
+
+	s.lc.KvCacheSeqRm(seqIndex, seq.numKeep, seq.numKeep+numDiscard)
+	s.lc.KvCacheSeqAdd(seqIndex, seq.numKeep+numDiscard, seq.nPast, -numDiscard)
+
+	seq.nPast -= numDiscard
+}
+
 func (s *Server) run(ctx context.Context) {
 	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
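
Worked numbers for shiftContext, using assumed values (numCtx=4096, numKeep=5, cache full at nPast=4096): half of the non-retained positions are removed, the newest half slides back, and roughly numCtx/2 positions become free before the next shift is needed.

    package main

    import "fmt"

    func main() {
        // assumed values: a full cache with a small retained prefix
        numCtx, numKeep := 4096, 5
        nPast := numCtx

        numLeft := nPast - numKeep // 4091
        numDiscard := numLeft / 2  // 2045

        fmt.Println("KvCacheSeqRm range:", numKeep, "to", numKeep+numDiscard)             // [5, 2050)
        fmt.Println("KvCacheSeqAdd range:", numKeep+numDiscard, "to", nPast, "delta", -numDiscard) // [2050, 4096) shifted back
        fmt.Println("new nPast:", nPast-numDiscard)                                       // 2051
    }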
@@ -155,10 +193,8 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
-				hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict
-
 				// if past the num predict limit
-				if hitLimit || seq.nPast > s.numCtx {
+				if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
 					seq.doneReason = "limit"
 					close(seq.responses)
 					s.lc.KvCacheSeqRm(i, 0, -1)
@@ -166,6 +202,10 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
+				if seq.nPast+len(seq.tokens) > s.numCtx {
+					s.shiftContext(i)
+				}
+
 				if seq.t_start_process_prompt.IsZero() {
 					seq.t_start_process_prompt = time.Now()
 				}
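
The trigger above shifts before decoding whenever the sequence's pending tokens would no longer fit in the cache; in the generation path only a single new token is pending, so one shift is enough to make room. A tiny sketch of the check with assumed numbers:

    package main

    import "fmt"

    func main() {
        numCtx := 2048             // assumed context size
        nPast, pending := 2040, 16 // assumed cached positions and queued tokens

        if nPast+pending > numCtx {
            fmt.Println("would overflow by", nPast+pending-numCtx, "tokens; shift first")
        }
    }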
@@ -350,7 +390,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
-	seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false)
+	seq := s.NewSequence(req.Prompt, NewSequenceParams{
+		numPredict:     req.NumPredict,
+		stop:           req.Stop,
+		numKeep:        req.NumKeep,
+		samplingParams: &samplingParams,
+		embedding:      false,
+	})
 
 	// TODO (jmorganca): add to sequence queue instead of
 	// failing if a slot isn't available
@@ -428,7 +474,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	embeddings := make([][]float32, len(req.Content))
 	var processed int
 	for i, content := range req.Content {
-		seqs[i] = s.NewSequence(content, 0, nil, nil, true)
+		seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
 	}
 
 	// TODO - refactor to go routines to add seq's and drain the responses
@@ -563,6 +609,10 @@ func main() {
 	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
 	server.lc = llama.NewContextWithModel(server.model, ctxParams)
 
+	if server.model.ShouldAddBOSToken() {
+		server.bosToken = 1
+	}
+
 	if *ppath != "" {
 		server.cc = llama.NewClipContext(*ppath)
 	}