num predict

This commit is contained in:
jmorganca 2024-05-28 23:38:44 -07:00
parent 43efc893d7
commit 7d0a452938
2 changed files with 11 additions and 4 deletions

View File

@ -10,9 +10,9 @@ Supported:
- [x] Windows CUDA
- [x] Windows ROCm
- [x] Linux CUDA
- [ ] Linux ROCm
- [x] Linux ROCm
- [x] Llava
- [ ] Parallel Requests
- [x] Parallel Requests
Extra build steps are required for CUDA and ROCm on Windows since `nvcc` and `hipcc` both require using MSVC as the host compiler. For these, small DLLs are created:

View File

@ -23,6 +23,9 @@ type Sequence struct {
// number of tokens evaluated
nPast int
// number of tokens predicted so far
numPredicted int
// tokens left to evaluate
tokens []int
@ -47,6 +50,7 @@ type Sequence struct {
}
// prompt reports whether the sequence is still working through its prompt.
// It is true while nPast (tokens evaluated so far) is below the index of the
// last entry in tokens, i.e. len(tokens)-1.
// TODO (jmorganca): clean up this logic
func (s *Sequence) prompt() bool {
	lastIdx := len(s.tokens) - 1
	stillInPrompt := s.nPast < lastIdx
	return stillInPrompt
}
@ -203,8 +207,8 @@ func (s *Server) run(ctx context.Context) {
continue
}
// we've reached the context limit
if seq.nPast > s.numCtx {
// if past the num predict limit
if seq.numPredicted > seq.numPredict || seq.nPast > s.numCtx {
seq.doneReason = "limit"
close(seq.responses)
s.lc.KvCacheSeqRm(i, 0, -1)
@ -269,6 +273,9 @@ func (s *Server) run(ctx context.Context) {
seq.samplingCtx.Accept(s.lc, token, true)
piece := s.model.TokenToPiece(token)
seq.numPredicted++
slog.Info("sampled", "piece", piece)
// if it's an end of sequence token, break