diff --git a/llama/llama.go b/llama/llama.go
index b6962f9a..ae051b5b 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -293,6 +293,7 @@ type SamplingParams struct {
 	MirostatEta float32
 	PenalizeNl  bool
 	Seed        uint32
+	Grammar     string
 }
 
 func NewSamplingContext(params SamplingParams) *SamplingContext {
@@ -310,6 +311,11 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
 	cparams.mirostat_eta = C.float(params.MirostatEta)
 	cparams.penalize_nl = C.bool(params.PenalizeNl)
 	cparams.seed = C.uint32_t(params.Seed)
+
+	grammar := C.CString(params.Grammar)
+	defer C.free(unsafe.Pointer(grammar))
+
+	cparams.grammar = grammar
 	return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
 }
 
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 0e2d251d..a732dae7 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -33,10 +33,6 @@ func (s *Sequence) prompt() bool {
 	return s.nPast < len(s.tokens)-1
 }
 
-func DefaultParams() llama.SamplingParams {
-	return llama.SamplingParams{}
-}
-
 func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
 	var samplingParams llama.SamplingParams
 	samplingParams.TopK = r.TopK
@@ -52,6 +48,7 @@ func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
 	samplingParams.MirostatEta = r.MirostatEta
 	samplingParams.PenalizeNl = r.PenalizeNewline
 	samplingParams.Seed = uint32(r.Seed)
+	samplingParams.Grammar = r.Grammar
 
 	tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
 	if err != nil {
@@ -112,8 +109,6 @@ func (s *Server) run(ctx context.Context) {
 		}
 		s.mu.Unlock()
 
-		fmt.Println("seqs", s.seqs, len(s.seqs))
-
 		// prepare the batch
 		ibatch := make([]int, s.parallel)
 		for i, seq := range s.seqs {
@@ -158,15 +153,10 @@ func (s *Server) run(ctx context.Context) {
 			}
 
 			// sample a token
-			// TODO: sample based on the sequence
-			fmt.Println("Sampling token", i, ibatch[i])
-			fmt.Println("calling sample", s.lc, nil, ibatch[i])
-			token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
-			seq.samplingCtx.Accept(s.lc, token, true)
-			// logits := s.lc.GetLogitsIth(ibatch[i])
 			// token := s.lc.SampleTokenGreedy(logits)
-			fmt.Println("sampled", token, s.model.TokenToPiece(token))
+			token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
+			seq.samplingCtx.Accept(s.lc, token, true)
 
 			seq.responses <- s.model.TokenToPiece(token)
 			seq.tokens = []int{token}
@@ -177,6 +167,7 @@ func (s *Server) run(ctx context.Context) {
 				// TODO: end the sequence instead of quitting the pool
 				s.lc.KvCacheSeqRm(i, 0, -1)
 				close(seq.responses)
+				seq.samplingCtx.Free()
 				s.seqs[i] = nil
 				continue
 			}
@@ -188,8 +179,9 @@ func (s *Server) run(ctx context.Context) {
 }
 
 type Request struct {
-	Prompt string   `json:"prompt"`
-	Images []string `json:"images"`
+	Prompt  string   `json:"prompt"`
+	Images  []string `json:"images"`
+	Grammar string   `json:"grammar"`
 	api.Options
 }
 
diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp
index 93ef0a93..db6d9efc 100644
--- a/llama/sampling_ext.cpp
+++ b/llama/sampling_ext.cpp
@@ -17,6 +17,7 @@ struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparam
     sparams.mirostat_eta = params->mirostat_eta;
     sparams.penalize_nl = params->penalize_nl;
     sparams.seed = params->seed;
+    sparams.grammar = std::string(params->grammar);
     return llama_sampling_init(sparams);
 }
 
diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h
index 9e110e35..ac791ddf 100644
--- a/llama/sampling_ext.h
+++ b/llama/sampling_ext.h
@@ -22,6 +22,7 @@ struct llama_sampling_cparams {
    float mirostat_eta;
    bool penalize_nl;
    uint32_t seed;
+   char* grammar;
};
 
struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params);
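
Below the patch, for illustration only and not part of the change set: a minimal Go sketch of how the new Grammar field might be exercised directly through the llama package. SamplingParams, Grammar, NewSamplingContext and Free are taken from the diff above; the import path, the GBNF rule and the parameter values are assumptions.

package main

import "github.com/ollama/ollama/llama"

func main() {
	// Hypothetical usage: constrain sampling to "yes" or "no" with a GBNF
	// grammar rule, as accepted by llama.cpp's sampler. Values are illustrative.
	params := llama.SamplingParams{
		Seed:    42,
		Grammar: "root ::= \"yes\" | \"no\"",
	}

	sc := llama.NewSamplingContext(params)
	defer sc.Free() // the runner frees the context when a sequence finishes

	// Sample and Accept would be driven by the runner's decode loop, e.g.
	// token := sc.Sample(lc, nil, batchIndex); sc.Accept(lc, token, true)
	_ = sc
}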