This commit is contained in:
jmorganca 2024-05-26 23:01:05 -07:00
parent d12db0568e
commit 72be8e27c4
4 changed files with 200 additions and 11 deletions

View File

@ -28,6 +28,7 @@ package llama
// #include "llama.h"
// #include "clip.h"
// #include "llava.h"
// #include "sampling_ext.h"
import "C"
import (
"fmt"
@ -244,6 +245,7 @@ func Quantize(infile, outfile string, ftype llm.FileType) error {
return nil
}
// llava
// ClipContext wraps the C clip_ctx used by llava for image encoding.
type ClipContext struct {
c *C.struct_clip_ctx
}
@ -270,3 +272,65 @@ func NewLlavaImageEmbed(clipContext *ClipContext, data []byte) *LlavaImageEmbed
// LlavaEvalImageEmbed evaluates an image embedding against the llama context
// in batches of nBatch tokens. nPast is passed through to C and updated in
// place to reflect the positions consumed (per llava.h — TODO confirm).
func LlavaEvalImageEmbed(llamaContext *Context, embed *LlavaImageEmbed, nBatch int, nPast *int) {
C.llava_eval_image_embed(llamaContext.c, embed.c, C.int(nBatch), (*C.int)(unsafe.Pointer(nPast)))
}
// sampling

// SamplingContext wraps the C++ llama_sampling_context allocated by
// NewSamplingContext; release it with Free.
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
c *C.struct_llama_sampling_context
}
// SamplingParams mirrors struct llama_sampling_cparams (sampling_ext.h); the
// fields are copied one-for-one into C by NewSamplingContext. Field semantics
// below follow llama.cpp's sampling parameters — TODO confirm against the
// vendored sampling.h.
type SamplingParams struct {
TopK int // keep only the k most probable tokens
TopP float32 // nucleus (top-p) cumulative-probability cutoff
TfsZ float32 // tail-free sampling z parameter
TypicalP float32 // locally typical sampling p parameter
Temp float32 // softmax temperature
PenaltyRepeat float32 // repetition penalty
PenaltyFreq float32 // frequency penalty
PenaltyPresent float32 // presence penalty
Mirostat int // mirostat mode (0 = disabled — presumably; verify)
MirostatTau float32 // mirostat target entropy
MirostatEta float32 // mirostat learning rate
PenalizeNl bool // whether newline tokens are penalized
Seed uint32 // RNG seed for sampling
}
// NewSamplingContext copies params into a C llama_sampling_cparams and
// allocates a C++ sampling context from it. The caller owns the result and
// must release it with Free.
func NewSamplingContext(params SamplingParams) *SamplingContext {
	var cparams C.struct_llama_sampling_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.tfs_z = C.float(params.TfsZ)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	// Fix: was assigning params.PenaltyFreq here, silently dropping the
	// presence penalty.
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.penalize_nl = C.bool(params.PenalizeNl)
	cparams.seed = C.uint32_t(params.Seed)

	return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
}
// Free releases the underlying C sampling context. The SamplingContext must
// not be used after Free returns.
func (s *SamplingContext) Free() {
C.llama_sampling_cfree(s.c)
}
// Reset clears the accumulated sampling state (previously accepted tokens,
// grammar state — per llama_sampling_reset) so the context can be reused.
func (s *SamplingContext) Reset() {
C.llama_sampling_creset(s.c)
}
// Sample draws the next token id at batch index idx using this context's
// sampling parameters. ctxConfig is an optional secondary context and may be
// nil, in which case NULL is passed through to C.
func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int {
	// TODO (jmorganca): handle nil for all args
	if ctxConfig != nil {
		return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx)))
	}
	return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx)))
}
// Accept records token id in the sampling state so later samples account for
// it (repetition penalties, etc.); applyGrammar additionally advances grammar
// matching when a grammar is in use (mirrors llama_sampling_accept).
func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) {
C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar))
}

View File

@ -24,6 +24,8 @@ type Sequence struct {
tokens []int
responses chan string
samplingCtx *llama.SamplingContext
}
// prompt returns true if the prompt is still being processed
@ -31,15 +33,41 @@ func (s *Sequence) prompt() bool {
return s.nPast < len(s.tokens)-1
}
func (s *Server) NewSequence(text string, w http.ResponseWriter) *Sequence {
tokens, err := s.lc.Model().Tokenize(text, 2048, true, true)
// DefaultParams returns a zero-valued llama.SamplingParams.
// NOTE(review): all-zero params disable/zero every sampling knob; confirm
// callers intend this rather than llama.cpp's library defaults.
func DefaultParams() llama.SamplingParams {
return llama.SamplingParams{}
}
// NewSequence tokenizes the request prompt and builds a Sequence whose
// sampling context is configured from the request's sampling options.
// Note: the rendered diff showed duplicated tokens/responses keys in the
// returned composite literal (old and new lines merged); this is the
// reconstructed new-side literal with each field exactly once.
func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
	var samplingParams llama.SamplingParams
	samplingParams.TopK = r.TopK
	samplingParams.TopP = r.TopP
	samplingParams.TfsZ = r.TFSZ
	samplingParams.TypicalP = r.TypicalP
	samplingParams.Temp = r.Temperature
	samplingParams.PenaltyRepeat = r.RepeatPenalty
	samplingParams.PenaltyFreq = r.FrequencyPenalty
	samplingParams.PenaltyPresent = r.PresencePenalty
	samplingParams.Mirostat = r.Mirostat
	samplingParams.MirostatTau = r.MirostatTau
	samplingParams.MirostatEta = r.MirostatEta
	samplingParams.PenalizeNl = r.PenalizeNewline
	samplingParams.Seed = uint32(r.Seed)

	tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
	if err != nil {
		// TODO(review): a bad prompt should surface an HTTP error to the
		// client instead of crashing the server.
		panic(err)
	}

	// Feed the prompt tokens into the sampling context so its state (e.g.
	// repetition penalties) reflects the prompt before generation starts.
	sc := llama.NewSamplingContext(samplingParams)
	for _, t := range tokens {
		sc.Accept(s.lc, t, false)
	}

	return &Sequence{
		tokens:      tokens,
		responses:   make(chan string, 1),
		samplingCtx: sc,
	}
}
@ -80,7 +108,6 @@ func (s *Server) run(ctx context.Context) {
slog.Info("Processing batch", "seqs", len(s.seqs))
s.mu.Lock()
for s.allNil() {
fmt.Println("wait")
s.cond.Wait() // Wait until an item is added
}
s.mu.Unlock()
@ -133,8 +160,16 @@ func (s *Server) run(ctx context.Context) {
// sample a token
// TODO: sample based on the sequence
fmt.Println("Sampling token", i, ibatch[i])
logits := s.lc.GetLogitsIth(ibatch[i])
token := s.lc.SampleTokenGreedy(logits)
fmt.Println("calling sample", s.lc, nil, ibatch[i])
token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
seq.samplingCtx.Accept(s.lc, token, true)
// logits := s.lc.GetLogitsIth(ibatch[i])
// token := s.lc.SampleTokenGreedy(logits)
fmt.Println("sampled", token, s.model.TokenToPiece(token))
seq.responses <- s.model.TokenToPiece(token)
seq.tokens = []int{token}
// if it's an end of sequence token, break
// TODO: just end this sequence
@ -145,9 +180,6 @@ func (s *Server) run(ctx context.Context) {
s.seqs[i] = nil
continue
}
seq.responses <- s.model.TokenToPiece(token)
seq.tokens = []int{token}
}
batch.Clear()
@ -168,6 +200,7 @@ type Response struct {
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
var request Request
request.Options = api.DefaultOptions()
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
@ -178,7 +211,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Transfer-Encoding", "chunked")
w.WriteHeader(http.StatusOK)
seq := s.NewSequence(request.Prompt, w)
seq := s.NewSequence(request, w)
s.mu.Lock()
for i, sq := range s.seqs {

45
llama/sampling_ext.cpp vendored Normal file
View File

@ -0,0 +1,45 @@
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
#include "sampling.h"
#include "sampling_ext.h"
// llama_sampling_cinit copies the plain-C parameter struct into a C++
// llama_sampling_params and allocates a sampling context from it. Fields of
// llama_sampling_params not listed here keep their C++ default values.
// Ownership of the returned context passes to the caller; release it with
// llama_sampling_cfree.
struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params) {
llama_sampling_params sparams;
sparams.top_k = params->top_k;
sparams.top_p = params->top_p;
sparams.tfs_z = params->tfs_z;
sparams.typical_p = params->typical_p;
sparams.temp = params->temp;
sparams.penalty_repeat = params->penalty_repeat;
sparams.penalty_freq = params->penalty_freq;
sparams.penalty_present = params->penalty_present;
sparams.mirostat = params->mirostat;
sparams.mirostat_tau = params->mirostat_tau;
sparams.mirostat_eta = params->mirostat_eta;
sparams.penalize_nl = params->penalize_nl;
sparams.seed = params->seed;
return llama_sampling_init(sparams);
}
// llama_sampling_cfree releases a context allocated by llama_sampling_cinit.
void llama_sampling_cfree(struct llama_sampling_context * ctx){
llama_sampling_free(ctx);
}
// llama_sampling_creset clears the context's accumulated sampling state so it
// can be reused without reallocation.
void llama_sampling_creset(struct llama_sampling_context * ctx){
llama_sampling_reset(ctx);
}
// llama_sampling_csample is a thin extern-"C" shim over llama_sampling_sample
// so CGo can call it. ctx_cfg is optional and may be NULL (the Go wrapper
// passes nil when no secondary context is supplied).
llama_token llama_sampling_csample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx) {
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx);
}
// llama_sampling_caccept records token id in the sampling state; when
// apply_grammar is true the grammar state is advanced as well (thin shim
// over llama_sampling_accept for CGo).
void llama_sampling_caccept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar) {
llama_sampling_accept(ctx_sampling, ctx_main, id, apply_grammar);
}

47
llama/sampling_ext.h vendored Normal file
View File

@ -0,0 +1,47 @@
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
// sampling_ext.h: plain-C facade over llama.cpp's C++ sampling API so the
// functions can be declared to and called from Go via CGo (see llama.go).
#ifndef LLAMA_SAMPLING_EXT_H
#define LLAMA_SAMPLING_EXT_H
#include "llama.h"
#ifdef __cplusplus
extern "C" {
#endif
// C mirror of the sampling parameters consumed by llama_sampling_cinit;
// kept field-for-field in sync with Go's llama.SamplingParams.
struct llama_sampling_cparams {
int32_t top_k;
float top_p;
float tfs_z;
float typical_p;
float temp;
float penalty_repeat;
float penalty_freq;
float penalty_present;
int32_t mirostat;
float mirostat_tau;
float mirostat_eta;
bool penalize_nl;
uint32_t seed;
};
// Allocate a sampling context from params; caller frees with
// llama_sampling_cfree.
struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params);
void llama_sampling_cfree(struct llama_sampling_context * ctx);
// Clear accumulated sampling state so the context can be reused.
void llama_sampling_creset(struct llama_sampling_context * ctx);
// Sample the next token at batch index idx; ctx_cfg may be NULL.
llama_token llama_sampling_csample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx);
// Record a sampled (or prompt) token; apply_grammar also advances grammar
// matching state.
void llama_sampling_caccept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar);
#ifdef __cplusplus
}
#endif
#endif // LLAMA_SAMPLING_EXT_H