jmorganca 2024-05-26 23:14:44 -07:00
parent 72be8e27c4
commit c0b94376b2
4 changed files with 15 additions and 15 deletions

View File

@@ -293,6 +293,7 @@ type SamplingParams struct {
 MirostatEta float32
 PenalizeNl bool
 Seed uint32
+Grammar string
 }
 func NewSamplingContext(params SamplingParams) *SamplingContext {
@@ -310,6 +311,11 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
 cparams.mirostat_eta = C.float(params.MirostatEta)
 cparams.penalize_nl = C.bool(params.PenalizeNl)
 cparams.seed = C.uint32_t(params.Seed)
+grammar := C.CString(params.Grammar)
+defer C.free(unsafe.Pointer(grammar))
+cparams.grammar = grammar
 return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
 }
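
The Grammar string added above is handed through to llama.cpp's grammar-constrained sampling, which takes a GBNF grammar. A small illustrative grammar (an example only, not part of this commit) that limits the model to three fixed answers:

package main

import "fmt"

// An illustrative GBNF grammar: the sampler may only emit "yes", "no", or "maybe".
const answerGrammar = `
root   ::= answer
answer ::= "yes" | "no" | "maybe"
`

func main() {
	// In this repo the string would go into SamplingParams.Grammar (or the
	// "grammar" field of a runner request); here we just print it.
	fmt.Println(answerGrammar)
}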

View File

@@ -33,10 +33,6 @@ func (s *Sequence) prompt() bool {
 return s.nPast < len(s.tokens)-1
 }
-func DefaultParams() llama.SamplingParams {
-return llama.SamplingParams{}
-}
 func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
 var samplingParams llama.SamplingParams
 samplingParams.TopK = r.TopK
@@ -52,6 +48,7 @@ func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
 samplingParams.MirostatEta = r.MirostatEta
 samplingParams.PenalizeNl = r.PenalizeNewline
 samplingParams.Seed = uint32(r.Seed)
+samplingParams.Grammar = r.Grammar
 tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
 if err != nil {
@@ -112,8 +109,6 @@ func (s *Server) run(ctx context.Context) {
 }
 s.mu.Unlock()
-fmt.Println("seqs", s.seqs, len(s.seqs))
 // prepare the batch
 ibatch := make([]int, s.parallel)
 for i, seq := range s.seqs {
@@ -158,15 +153,10 @@ func (s *Server) run(ctx context.Context) {
 }
 // sample a token
 // TODO: sample based on the sequence
-fmt.Println("Sampling token", i, ibatch[i])
-fmt.Println("calling sample", s.lc, nil, ibatch[i])
-token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
-seq.samplingCtx.Accept(s.lc, token, true)
-// logits := s.lc.GetLogitsIth(ibatch[i])
-// token := s.lc.SampleTokenGreedy(logits)
-fmt.Println("sampled", token, s.model.TokenToPiece(token))
+token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
+seq.samplingCtx.Accept(s.lc, token, true)
 seq.responses <- s.model.TokenToPiece(token)
 seq.tokens = []int{token}
@@ -177,6 +167,7 @@ func (s *Server) run(ctx context.Context) {
 // TODO: end the sequence instead of quitting the pool
 s.lc.KvCacheSeqRm(i, 0, -1)
 close(seq.responses)
+seq.samplingCtx.Free()
 s.seqs[i] = nil
 continue
 }
@@ -188,8 +179,9 @@ func (s *Server) run(ctx context.Context) {
 }
 type Request struct {
-Prompt string   `json:"prompt"`
-Images []string `json:"images"`
+Prompt  string   `json:"prompt"`
+Images  []string `json:"images"`
+Grammar string   `json:"grammar"`
 api.Options
 }
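
With the new json tag in place, a runner request can carry a grammar next to the prompt. A self-contained decode sketch; the struct mirrors the fields shown above (api.Options omitted) and the request body is hypothetical:

package main

import (
	"encoding/json"
	"fmt"
)

// request mirrors the Request fields shown in the diff; api.Options is omitted here.
type request struct {
	Prompt  string   `json:"prompt"`
	Images  []string `json:"images"`
	Grammar string   `json:"grammar"`
}

func main() {
	// Hypothetical request body: the grammar constrains output to "yes" or "no".
	body := `{
	  "prompt": "Is the sky blue? Answer yes or no.",
	  "grammar": "root ::= \"yes\" | \"no\""
	}`

	var r request
	if err := json.Unmarshal([]byte(body), &r); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", r)
}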

View File

@@ -17,6 +17,7 @@ struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams
 sparams.mirostat_eta = params->mirostat_eta;
 sparams.penalize_nl = params->penalize_nl;
 sparams.seed = params->seed;
+sparams.grammar = std::string(params->grammar);
 return llama_sampling_init(sparams);
 }
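
Worth noting: llama_sampling_cinit copies params->grammar into a std::string before returning, so the Go side's deferred C.free on its C.CString is safe once NewSamplingContext comes back. A standalone cgo sketch of that callee-copies pattern, using a hypothetical take_copy helper in place of the std::string copy:

package main

/*
#include <stdlib.h>
#include <string.h>

// take_copy stands in for a callee that keeps its own copy of the input,
// the way llama_sampling_cinit copies the grammar into a std::string.
static char *take_copy(const char *s) {
	return strdup(s);
}
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	g := C.CString(`root ::= "yes" | "no"`)
	kept := C.take_copy(g)    // callee makes its own copy
	C.free(unsafe.Pointer(g)) // safe: the callee no longer needs our buffer
	defer C.free(unsafe.Pointer(kept))

	fmt.Println(C.GoString(kept))
}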

View File

@@ -22,6 +22,7 @@ struct llama_sampling_cparams {
 float mirostat_eta;
 bool penalize_nl;
 uint32_t seed;
+char* grammar;
 };
 struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params);
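
The char* grammar field declared here is what the Go code assigns through cgo (cparams.grammar = grammar above). A minimal standalone sketch of setting a C struct's string field from Go; the cparams struct below is a stand-in, not the repo's:

package main

/*
#include <stdlib.h>

struct cparams {
	const char *grammar;
};
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	var p C.struct_cparams

	g := C.CString(`root ::= "yes" | "no"`)
	defer C.free(unsafe.Pointer(g))

	p.grammar = g // cgo exposes C struct fields directly
	fmt.Println(C.GoString(p.grammar))
}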