grammar
This commit is contained in:
parent
72be8e27c4
commit
c0b94376b2
@ -293,6 +293,7 @@ type SamplingParams struct {
|
||||
MirostatEta float32
|
||||
PenalizeNl bool
|
||||
Seed uint32
|
||||
Grammar string
|
||||
}
|
||||
|
||||
func NewSamplingContext(params SamplingParams) *SamplingContext {
|
||||
@ -310,6 +311,11 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
|
||||
cparams.mirostat_eta = C.float(params.MirostatEta)
|
||||
cparams.penalize_nl = C.bool(params.PenalizeNl)
|
||||
cparams.seed = C.uint32_t(params.Seed)
|
||||
|
||||
grammar := C.CString(params.Grammar)
|
||||
defer C.free(unsafe.Pointer(grammar))
|
||||
|
||||
cparams.grammar = grammar
|
||||
return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
|
||||
}
|
||||
|
||||
|
@ -33,10 +33,6 @@ func (s *Sequence) prompt() bool {
|
||||
return s.nPast < len(s.tokens)-1
|
||||
}
|
||||
|
||||
func DefaultParams() llama.SamplingParams {
|
||||
return llama.SamplingParams{}
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
|
||||
var samplingParams llama.SamplingParams
|
||||
samplingParams.TopK = r.TopK
|
||||
@ -52,6 +48,7 @@ func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
|
||||
samplingParams.MirostatEta = r.MirostatEta
|
||||
samplingParams.PenalizeNl = r.PenalizeNewline
|
||||
samplingParams.Seed = uint32(r.Seed)
|
||||
samplingParams.Grammar = r.Grammar
|
||||
|
||||
tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
|
||||
if err != nil {
|
||||
@ -112,8 +109,6 @@ func (s *Server) run(ctx context.Context) {
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
fmt.Println("seqs", s.seqs, len(s.seqs))
|
||||
|
||||
// prepare the batch
|
||||
ibatch := make([]int, s.parallel)
|
||||
for i, seq := range s.seqs {
|
||||
@ -158,15 +153,10 @@ func (s *Server) run(ctx context.Context) {
|
||||
}
|
||||
|
||||
// sample a token
|
||||
// TODO: sample based on the sequence
|
||||
fmt.Println("Sampling token", i, ibatch[i])
|
||||
fmt.Println("calling sample", s.lc, nil, ibatch[i])
|
||||
token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
|
||||
seq.samplingCtx.Accept(s.lc, token, true)
|
||||
|
||||
// logits := s.lc.GetLogitsIth(ibatch[i])
|
||||
// token := s.lc.SampleTokenGreedy(logits)
|
||||
fmt.Println("sampled", token, s.model.TokenToPiece(token))
|
||||
token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
|
||||
seq.samplingCtx.Accept(s.lc, token, true)
|
||||
|
||||
seq.responses <- s.model.TokenToPiece(token)
|
||||
seq.tokens = []int{token}
|
||||
@ -177,6 +167,7 @@ func (s *Server) run(ctx context.Context) {
|
||||
// TODO: end the sequence instead of quitting the pool
|
||||
s.lc.KvCacheSeqRm(i, 0, -1)
|
||||
close(seq.responses)
|
||||
seq.samplingCtx.Free()
|
||||
s.seqs[i] = nil
|
||||
continue
|
||||
}
|
||||
@ -188,8 +179,9 @@ func (s *Server) run(ctx context.Context) {
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Images []string `json:"images"`
|
||||
Prompt string `json:"prompt"`
|
||||
Images []string `json:"images"`
|
||||
Grammar string `json:"grammar"`
|
||||
|
||||
api.Options
|
||||
}
|
||||
|
1
llama/sampling_ext.cpp
vendored
1
llama/sampling_ext.cpp
vendored
@ -17,6 +17,7 @@ struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparam
|
||||
sparams.mirostat_eta = params->mirostat_eta;
|
||||
sparams.penalize_nl = params->penalize_nl;
|
||||
sparams.seed = params->seed;
|
||||
sparams.grammar = std::string(params->grammar);
|
||||
return llama_sampling_init(sparams);
|
||||
}
|
||||
|
||||
|
1
llama/sampling_ext.h
vendored
1
llama/sampling_ext.h
vendored
@ -22,6 +22,7 @@ struct llama_sampling_cparams {
|
||||
float mirostat_eta;
|
||||
bool penalize_nl;
|
||||
uint32_t seed;
|
||||
char* grammar;
|
||||
};
|
||||
|
||||
struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params);
|
||||
|
Loading…
x
Reference in New Issue
Block a user