From 72be8e27c41dd7ff2fb1e9f9794b136255f212d2 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sun, 26 May 2024 23:01:05 -0700 Subject: [PATCH] sampling --- llama/llama.go | 64 ++++++++++++++++++++++++++++++++++++++++++ llama/runner/runner.go | 55 ++++++++++++++++++++++++++++-------- llama/sampling_ext.cpp | 45 +++++++++++++++++++++++++++++ llama/sampling_ext.h | 47 +++++++++++++++++++++++++++++++ 4 files changed, 200 insertions(+), 11 deletions(-) create mode 100644 llama/sampling_ext.cpp create mode 100644 llama/sampling_ext.h diff --git a/llama/llama.go b/llama/llama.go index 27cf7516..b6962f9a 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -28,6 +28,7 @@ package llama // #include "llama.h" // #include "clip.h" // #include "llava.h" +// #include "sampling_ext.h" import "C" import ( "fmt" @@ -244,6 +245,7 @@ func Quantize(infile, outfile string, ftype llm.FileType) error { return nil } +// llava type ClipContext struct { c *C.struct_clip_ctx } @@ -270,3 +272,65 @@ func NewLlavaImageEmbed(clipContext *ClipContext, data []byte) *LlavaImageEmbed func LlavaEvalImageEmbed(llamaContext *Context, embed *LlavaImageEmbed, nBatch int, nPast *int) { C.llava_eval_image_embed(llamaContext.c, embed.c, C.int(nBatch), (*C.int)(unsafe.Pointer(nPast))) } + +// sampling +// TODO: this is a temporary wrapper to allow calling C++ code from CGo +type SamplingContext struct { + c *C.struct_llama_sampling_context +} + +type SamplingParams struct { + TopK int + TopP float32 + TfsZ float32 + TypicalP float32 + Temp float32 + PenaltyRepeat float32 + PenaltyFreq float32 + PenaltyPresent float32 + Mirostat int + MirostatTau float32 + MirostatEta float32 + PenalizeNl bool + Seed uint32 +} + +func NewSamplingContext(params SamplingParams) *SamplingContext { + var cparams C.struct_llama_sampling_cparams + cparams.top_k = C.int32_t(params.TopK) + cparams.top_p = C.float(params.TopP) + cparams.tfs_z = C.float(params.TfsZ) + cparams.typical_p = C.float(params.TypicalP) + cparams.temp = 
C.float(params.Temp) + cparams.penalty_repeat = C.float(params.PenaltyRepeat) + cparams.penalty_freq = C.float(params.PenaltyFreq) + cparams.penalty_present = C.float(params.PenaltyPresent) + cparams.mirostat = C.int32_t(params.Mirostat) + cparams.mirostat_tau = C.float(params.MirostatTau) + cparams.mirostat_eta = C.float(params.MirostatEta) + cparams.penalize_nl = C.bool(params.PenalizeNl) + cparams.seed = C.uint32_t(params.Seed) + return &SamplingContext{c: C.llama_sampling_cinit(&cparams)} +} + +func (s *SamplingContext) Free() { + C.llama_sampling_cfree(s.c) +} + +func (s *SamplingContext) Reset() { + C.llama_sampling_creset(s.c) +} + +func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int { + // TODO (jmorganca): handle nil for all args + if ctxConfig == nil { + return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx))) + } + + return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx))) + +} + +func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) { + C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar)) +} diff --git a/llama/runner/runner.go b/llama/runner/runner.go index e75ec671..0e2d251d 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -24,6 +24,8 @@ type Sequence struct { tokens []int responses chan string + + samplingCtx *llama.SamplingContext } // prompt returns true if the prompt is still being processed @@ -31,15 +33,41 @@ func (s *Sequence) prompt() bool { return s.nPast < len(s.tokens)-1 } -func (s *Server) NewSequence(text string, w http.ResponseWriter) *Sequence { - tokens, err := s.lc.Model().Tokenize(text, 2048, true, true) +func DefaultParams() llama.SamplingParams { + return llama.SamplingParams{} +} + +func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence { + var samplingParams llama.SamplingParams + samplingParams.TopK = r.TopK + samplingParams.TopP = r.TopP + samplingParams.TfsZ = r.TFSZ + 
samplingParams.TypicalP = r.TypicalP + samplingParams.Temp = r.Temperature + samplingParams.PenaltyRepeat = r.RepeatPenalty + samplingParams.PenaltyFreq = r.FrequencyPenalty + samplingParams.PenaltyPresent = r.PresencePenalty + samplingParams.Mirostat = r.Mirostat + samplingParams.MirostatTau = r.MirostatTau + samplingParams.MirostatEta = r.MirostatEta + samplingParams.PenalizeNl = r.PenalizeNewline + samplingParams.Seed = uint32(r.Seed) + + tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true) if err != nil { panic(err) } + sc := llama.NewSamplingContext(samplingParams) + + for _, t := range tokens { + sc.Accept(s.lc, t, false) + } + return &Sequence{ - tokens: tokens, - responses: make(chan string, 1), + tokens: tokens, + responses: make(chan string, 1), + samplingCtx: sc, } } @@ -80,7 +108,6 @@ func (s *Server) run(ctx context.Context) { slog.Info("Processing batch", "seqs", len(s.seqs)) s.mu.Lock() for s.allNil() { - fmt.Println("wait") s.cond.Wait() // Wait until an item is added } s.mu.Unlock() @@ -133,8 +160,16 @@ func (s *Server) run(ctx context.Context) { // sample a token // TODO: sample based on the sequence fmt.Println("Sampling token", i, ibatch[i]) - logits := s.lc.GetLogitsIth(ibatch[i]) - token := s.lc.SampleTokenGreedy(logits) + fmt.Println("calling sample", s.lc, nil, ibatch[i]) + token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i]) + seq.samplingCtx.Accept(s.lc, token, true) + + // logits := s.lc.GetLogitsIth(ibatch[i]) + // token := s.lc.SampleTokenGreedy(logits) + fmt.Println("sampled", token, s.model.TokenToPiece(token)) + + seq.responses <- s.model.TokenToPiece(token) + seq.tokens = []int{token} // if it's an end of sequence token, break // TODO: just end this sequence @@ -145,9 +180,6 @@ func (s *Server) run(ctx context.Context) { s.seqs[i] = nil continue } - - seq.responses <- s.model.TokenToPiece(token) - seq.tokens = []int{token} } batch.Clear() @@ -168,6 +200,7 @@ type Response struct { func (s *Server) handler(w 
http.ResponseWriter, r *http.Request) { var request Request + request.Options = api.DefaultOptions() if err := json.NewDecoder(r.Body).Decode(&request); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return @@ -178,7 +211,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Transfer-Encoding", "chunked") w.WriteHeader(http.StatusOK) - seq := s.NewSequence(request.Prompt, w) + seq := s.NewSequence(request, w) s.mu.Lock() for i, sq := range s.seqs { diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp new file mode 100644 index 00000000..93ef0a93 --- /dev/null +++ b/llama/sampling_ext.cpp @@ -0,0 +1,45 @@ +// TODO: this is a temporary wrapper to allow calling C++ code from CGo +#include "sampling.h" +#include "sampling_ext.h" + +struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params) { + llama_sampling_params sparams; + sparams.top_k = params->top_k; + sparams.top_p = params->top_p; + sparams.tfs_z = params->tfs_z; + sparams.typical_p = params->typical_p; + sparams.temp = params->temp; + sparams.penalty_repeat = params->penalty_repeat; + sparams.penalty_freq = params->penalty_freq; + sparams.penalty_present = params->penalty_present; + sparams.mirostat = params->mirostat; + sparams.mirostat_tau = params->mirostat_tau; + sparams.mirostat_eta = params->mirostat_eta; + sparams.penalize_nl = params->penalize_nl; + sparams.seed = params->seed; + return llama_sampling_init(sparams); +} + +void llama_sampling_cfree(struct llama_sampling_context * ctx){ + llama_sampling_free(ctx); +} + +void llama_sampling_creset(struct llama_sampling_context * ctx){ + llama_sampling_reset(ctx); +} + +llama_token llama_sampling_csample( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx) { + return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx); +} + +void llama_sampling_caccept( + struct 
llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + llama_token id, + bool apply_grammar) { + llama_sampling_accept(ctx_sampling, ctx_main, id, apply_grammar); +} diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h new file mode 100644 index 00000000..9e110e35 --- /dev/null +++ b/llama/sampling_ext.h @@ -0,0 +1,47 @@ +// TODO: this is a temporary wrapper to allow calling C++ code from CGo +#ifndef LLAMA_SAMPLING_EXT_H +#define LLAMA_SAMPLING_EXT_H + +#include "llama.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct llama_sampling_cparams { + int32_t top_k; + float top_p; + float tfs_z; + float typical_p; + float temp; + float penalty_repeat; + float penalty_freq; + float penalty_present; + int32_t mirostat; + float mirostat_tau; + float mirostat_eta; + bool penalize_nl; + uint32_t seed; +}; + +struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params); +void llama_sampling_cfree(struct llama_sampling_context * ctx); +void llama_sampling_creset(struct llama_sampling_context * ctx); + +llama_token llama_sampling_csample( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx); + +void llama_sampling_caccept( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + llama_token id, + bool apply_grammar); + +#ifdef __cplusplus +} +#endif + +#endif // LLAMA_SAMPLING_EXT_H