forked from third-party-mirrors/ollama
num predict
parent 43efc893d7
commit 7d0a452938
```diff
@@ -10,9 +10,9 @@ Supported:
 - [x] Windows CUDA
 - [x] Windows ROCm
 - [x] Linux CUDA
-- [ ] Linux ROCm
+- [x] Linux ROCm
 - [x] Llava
-- [ ] Parallel Requests
+- [x] Parallel Requests
 
 Extra build steps are required for CUDA and ROCm on Windows, since `nvcc` and `hipcc` both require MSVC as the host compiler. For these, small DLLs are created:
```
```diff
@@ -23,6 +23,9 @@ type Sequence struct {
 	// number of tokens evaluated
 	nPast int
 
+	// number of tokens predicted so far
+	numPredicted int
+
 	// tokens left to evaluate
 	tokens []int
```
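For orientation, here is how the touched slice of `Sequence` reads after this hunk. The sketch is assembled only from the diff context above; everything outside it is elided. Note that `numPredict`, the budget consulted later in `run()`, is declared outside the hunks shown in this commit.

```go
type Sequence struct {
	// number of tokens evaluated
	nPast int

	// number of tokens predicted so far (added by this commit)
	numPredicted int

	// tokens left to evaluate
	tokens []int

	// numPredict, the prediction budget checked in run(),
	// is declared outside the hunks shown here.
}
```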
```diff
@@ -47,6 +50,7 @@ type Sequence struct {
 }
 
+// prompt returns true if the prompt is still being processed
 // TODO (jmorganca): clean up this logic
 func (s *Sequence) prompt() bool {
 	return s.nPast < len(s.tokens)-1
 }
```
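A runnable toy illustration of that predicate, under the assumption that `tokens` holds the full tokenized prompt; the token values are arbitrary placeholders:

```go
package main

import "fmt"

type Sequence struct {
	nPast  int   // number of tokens evaluated
	tokens []int // tokens left to evaluate
}

// prompt returns true if the prompt is still being processed.
func (s *Sequence) prompt() bool {
	return s.nPast < len(s.tokens)-1
}

func main() {
	s := &Sequence{tokens: []int{101, 42, 7, 13}} // a 4-token prompt
	for ; s.nPast < len(s.tokens); s.nPast++ {
		fmt.Printf("nPast=%d prompt=%v\n", s.nPast, s.prompt())
	}
	// prompt() flips to false once nPast reaches len(tokens)-1,
	// i.e. when only the final prompt token remains to be evaluated.
}
```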
```diff
@@ -203,8 +207,8 @@ func (s *Server) run(ctx context.Context) {
 				continue
 			}
 
-			// we've reached the context limit
-			if seq.nPast > s.numCtx {
+			// if past the num predict limit
+			if seq.numPredicted > seq.numPredict || seq.nPast > s.numCtx {
 				seq.doneReason = "limit"
 				close(seq.responses)
 				s.lc.KvCacheSeqRm(i, 0, -1)
```
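The behavioral change here: a sequence now stops not only when it outgrows the context window but also once it exceeds its `numPredict` budget. Pulled out of the `Server`/`Sequence` receivers into a standalone predicate (a sketch; the field names follow the diff):

```go
// limitReached mirrors the updated condition in (*Server).run: true once
// more than numPredict tokens have been predicted, or once the sequence
// position has moved past the context window.
func limitReached(numPredicted, numPredict, nPast, numCtx int) bool {
	return numPredicted > numPredict || nPast > numCtx
}
```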
```diff
@@ -269,6 +273,9 @@ func (s *Server) run(ctx context.Context) {
 
 			seq.samplingCtx.Accept(s.lc, token, true)
 			piece := s.model.TokenToPiece(token)
+
+			seq.numPredicted++
+
 			slog.Info("sampled", "piece", piece)
 
 			// if it's an end of sequence token, break
```
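Collapsing the two halves of the change into one self-contained toy: in the real runner the limit check lives near the top of `run()` and the increment sits next to sampling, but the interaction is the same. The budget value below is arbitrary:

```go
package main

import "fmt"

func main() {
	numPredicted := 0
	numPredict := 3 // hypothetical per-request budget

	for token := 0; ; token++ {
		numPredicted++ // the bookkeeping this commit adds after each sample
		fmt.Printf("sampled token %d (numPredicted=%d)\n", token, numPredicted)
		if numPredicted > numPredict { // the widened stop condition
			fmt.Println(`done: "limit"`)
			break
		}
	}
}
```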