From 7d0a45293804ceb19882c7be06517788032989be Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Tue, 28 May 2024 23:38:44 -0700
Subject: [PATCH] num predict

---
 llama/README.md        |  4 ++--
 llama/runner/runner.go | 11 +++++++++--
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/llama/README.md b/llama/README.md
index c3228b31..46179d48 100644
--- a/llama/README.md
+++ b/llama/README.md
@@ -10,9 +10,9 @@ Supported:
 - [x] Windows CUDA
 - [x] Windows ROCm
 - [x] Linux CUDA
-- [ ] Linux ROCm
+- [x] Linux ROCm
 - [x] Llava
-- [ ] Parallel Requests
+- [x] Parallel Requests
 
 Extra build steps are required for CUDA and ROCm on Windows since `nvcc` and `hipcc` both require using msvc as the host compiler. For these small dlls are created:
 
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 7692d1c4..c2d81e23 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -23,6 +23,9 @@ type Sequence struct {
 	// number of tokens evaluated
 	nPast int
 
+	// number of tokens predicted so far
+	numPredicted int
+
 	// tokens left to evaluate
 	tokens []int
 
@@ -47,6 +50,7 @@ type Sequence struct {
 }
 
 // prompt returns true if the prompt is still being processed
+// TODO (jmorganca): clean up this logic
 func (s *Sequence) prompt() bool {
 	return s.nPast < len(s.tokens)-1
 }
@@ -203,8 +207,8 @@ func (s *Server) run(ctx context.Context) {
 				continue
 			}
 
-			// we've reached the context limit
-			if seq.nPast > s.numCtx {
+			// if past the num predict limit
+			if seq.numPredicted > seq.numPredict || seq.nPast > s.numCtx {
 				seq.doneReason = "limit"
 				close(seq.responses)
 				s.lc.KvCacheSeqRm(i, 0, -1)
@@ -269,6 +273,9 @@ func (s *Server) run(ctx context.Context) {
 			seq.samplingCtx.Accept(s.lc, token, true)
 			piece := s.model.TokenToPiece(token)
+
+			seq.numPredicted++
+
 			slog.Info("sampled", "piece", piece)
 
 			// if it's an end of sequence token, break
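
For reviewers, here is a minimal, self-contained sketch of the stopping rule this patch adds to the runner loop. The names `numPredicted`, `numPredict`, `nPast`, and `numCtx` mirror the diff; the surrounding scaffolding is illustrative only and not part of the patch.

```go
package main

import "fmt"

// sequence carries just the counters relevant to the stopping rule.
// Field names mirror the patched Sequence struct; everything else is a stand-in.
type sequence struct {
	nPast        int // tokens evaluated so far (prompt + generated)
	numPredicted int // tokens predicted so far
	numPredict   int // per-request prediction budget
}

// reachedLimit reports whether decoding should stop for this sequence:
// either the prediction budget was exceeded or the context window overran.
func reachedLimit(seq sequence, numCtx int) bool {
	return seq.numPredicted > seq.numPredict || seq.nPast > numCtx
}

func main() {
	seq := sequence{nPast: 120, numPredicted: 33, numPredict: 32}
	fmt.Println(reachedLimit(seq, 2048)) // true: prediction budget exceeded
}
```

Because the comparison uses `>`, the sequence is marked done (with `doneReason = "limit"`) only once the predicted-token count exceeds `numPredict`, matching how the patch increments `numPredicted` after each sampled token.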