forked from third-party-mirrors/ollama
sampling
This commit is contained in:
parent
d12db0568e
commit
72be8e27c4
@ -28,6 +28,7 @@ package llama
|
||||
// #include "llama.h"
|
||||
// #include "clip.h"
|
||||
// #include "llava.h"
|
||||
// #include "sampling_ext.h"
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
@ -244,6 +245,7 @@ func Quantize(infile, outfile string, ftype llm.FileType) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// llava
|
||||
type ClipContext struct {
|
||||
c *C.struct_clip_ctx
|
||||
}
|
||||
@ -270,3 +272,65 @@ func NewLlavaImageEmbed(clipContext *ClipContext, data []byte) *LlavaImageEmbed
|
||||
func LlavaEvalImageEmbed(llamaContext *Context, embed *LlavaImageEmbed, nBatch int, nPast *int) {
|
||||
C.llava_eval_image_embed(llamaContext.c, embed.c, C.int(nBatch), (*C.int)(unsafe.Pointer(nPast)))
|
||||
}
|
||||
|
||||
// sampling
|
||||
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
||||
type SamplingContext struct {
|
||||
c *C.struct_llama_sampling_context
|
||||
}
|
||||
|
||||
// SamplingParams is the Go-side description of the sampler settings accepted
// by NewSamplingContext. Each field corresponds to the like-named member of
// C.struct_llama_sampling_cparams, shown in the trailing comment.
type SamplingParams struct {
	TopK           int     // top_k
	TopP           float32 // top_p
	TfsZ           float32 // tfs_z
	TypicalP       float32 // typical_p
	Temp           float32 // temp
	PenaltyRepeat  float32 // penalty_repeat
	PenaltyFreq    float32 // penalty_freq
	PenaltyPresent float32 // penalty_present
	Mirostat       int     // mirostat
	MirostatTau    float32 // mirostat_tau
	MirostatEta    float32 // mirostat_eta
	PenalizeNl     bool    // penalize_nl
	Seed           uint32  // seed
}
|
||||
|
||||
func NewSamplingContext(params SamplingParams) *SamplingContext {
|
||||
var cparams C.struct_llama_sampling_cparams
|
||||
cparams.top_k = C.int32_t(params.TopK)
|
||||
cparams.top_p = C.float(params.TopP)
|
||||
cparams.tfs_z = C.float(params.TfsZ)
|
||||
cparams.typical_p = C.float(params.TypicalP)
|
||||
cparams.temp = C.float(params.Temp)
|
||||
cparams.penalty_repeat = C.float(params.PenaltyRepeat)
|
||||
cparams.penalty_freq = C.float(params.PenaltyFreq)
|
||||
cparams.penalty_present = C.float(params.PenaltyFreq)
|
||||
cparams.mirostat = C.int32_t(params.Mirostat)
|
||||
cparams.mirostat_tau = C.float(params.MirostatTau)
|
||||
cparams.mirostat_eta = C.float(params.MirostatEta)
|
||||
cparams.penalize_nl = C.bool(params.PenalizeNl)
|
||||
cparams.seed = C.uint32_t(params.Seed)
|
||||
return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
|
||||
}
|
||||
|
||||
func (s *SamplingContext) Free() {
|
||||
C.llama_sampling_cfree(s.c)
|
||||
}
|
||||
|
||||
func (s *SamplingContext) Reset() {
|
||||
C.llama_sampling_creset(s.c)
|
||||
}
|
||||
|
||||
func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int {
|
||||
// TODO (jmorganca): handle nil for all args
|
||||
if ctxConfig == nil {
|
||||
return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx)))
|
||||
}
|
||||
|
||||
return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx)))
|
||||
|
||||
}
|
||||
|
||||
func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) {
|
||||
C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar))
|
||||
}
|
||||
|
@ -24,6 +24,8 @@ type Sequence struct {
|
||||
tokens []int
|
||||
|
||||
responses chan string
|
||||
|
||||
samplingCtx *llama.SamplingContext
|
||||
}
|
||||
|
||||
// prompt returns true if the prompt is still being processed
|
||||
@ -31,15 +33,41 @@ func (s *Sequence) prompt() bool {
|
||||
return s.nPast < len(s.tokens)-1
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(text string, w http.ResponseWriter) *Sequence {
|
||||
tokens, err := s.lc.Model().Tokenize(text, 2048, true, true)
|
||||
func DefaultParams() llama.SamplingParams {
|
||||
return llama.SamplingParams{}
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
|
||||
var samplingParams llama.SamplingParams
|
||||
samplingParams.TopK = r.TopK
|
||||
samplingParams.TopP = r.TopP
|
||||
samplingParams.TfsZ = r.TFSZ
|
||||
samplingParams.TypicalP = r.TypicalP
|
||||
samplingParams.Temp = r.Temperature
|
||||
samplingParams.PenaltyRepeat = r.RepeatPenalty
|
||||
samplingParams.PenaltyFreq = r.FrequencyPenalty
|
||||
samplingParams.PenaltyPresent = r.PresencePenalty
|
||||
samplingParams.Mirostat = r.Mirostat
|
||||
samplingParams.MirostatTau = r.MirostatTau
|
||||
samplingParams.MirostatEta = r.MirostatEta
|
||||
samplingParams.PenalizeNl = r.PenalizeNewline
|
||||
samplingParams.Seed = uint32(r.Seed)
|
||||
|
||||
tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sc := llama.NewSamplingContext(samplingParams)
|
||||
|
||||
for _, t := range tokens {
|
||||
sc.Accept(s.lc, t, false)
|
||||
}
|
||||
|
||||
return &Sequence{
|
||||
tokens: tokens,
|
||||
responses: make(chan string, 1),
|
||||
tokens: tokens,
|
||||
responses: make(chan string, 1),
|
||||
samplingCtx: sc,
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,7 +108,6 @@ func (s *Server) run(ctx context.Context) {
|
||||
slog.Info("Processing batch", "seqs", len(s.seqs))
|
||||
s.mu.Lock()
|
||||
for s.allNil() {
|
||||
fmt.Println("wait")
|
||||
s.cond.Wait() // Wait until an item is added
|
||||
}
|
||||
s.mu.Unlock()
|
||||
@ -133,8 +160,16 @@ func (s *Server) run(ctx context.Context) {
|
||||
// sample a token
|
||||
// TODO: sample based on the sequence
|
||||
fmt.Println("Sampling token", i, ibatch[i])
|
||||
logits := s.lc.GetLogitsIth(ibatch[i])
|
||||
token := s.lc.SampleTokenGreedy(logits)
|
||||
fmt.Println("calling sample", s.lc, nil, ibatch[i])
|
||||
token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
|
||||
seq.samplingCtx.Accept(s.lc, token, true)
|
||||
|
||||
// logits := s.lc.GetLogitsIth(ibatch[i])
|
||||
// token := s.lc.SampleTokenGreedy(logits)
|
||||
fmt.Println("sampled", token, s.model.TokenToPiece(token))
|
||||
|
||||
seq.responses <- s.model.TokenToPiece(token)
|
||||
seq.tokens = []int{token}
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
// TODO: just end this sequence
|
||||
@ -145,9 +180,6 @@ func (s *Server) run(ctx context.Context) {
|
||||
s.seqs[i] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
seq.responses <- s.model.TokenToPiece(token)
|
||||
seq.tokens = []int{token}
|
||||
}
|
||||
|
||||
batch.Clear()
|
||||
@ -168,6 +200,7 @@ type Response struct {
|
||||
|
||||
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
|
||||
var request Request
|
||||
request.Options = api.DefaultOptions()
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
@ -178,7 +211,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
seq := s.NewSequence(request.Prompt, w)
|
||||
seq := s.NewSequence(request, w)
|
||||
|
||||
s.mu.Lock()
|
||||
for i, sq := range s.seqs {
|
||||
|
45
llama/sampling_ext.cpp
vendored
Normal file
45
llama/sampling_ext.cpp
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
||||
#include "sampling.h"
|
||||
#include "sampling_ext.h"
|
||||
|
||||
struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params) {
|
||||
llama_sampling_params sparams;
|
||||
sparams.top_k = params->top_k;
|
||||
sparams.top_p = params->top_p;
|
||||
sparams.tfs_z = params->tfs_z;
|
||||
sparams.typical_p = params->typical_p;
|
||||
sparams.temp = params->temp;
|
||||
sparams.penalty_repeat = params->penalty_repeat;
|
||||
sparams.penalty_freq = params->penalty_freq;
|
||||
sparams.penalty_present = params->penalty_present;
|
||||
sparams.mirostat = params->mirostat;
|
||||
sparams.mirostat_tau = params->mirostat_tau;
|
||||
sparams.mirostat_eta = params->mirostat_eta;
|
||||
sparams.penalize_nl = params->penalize_nl;
|
||||
sparams.seed = params->seed;
|
||||
return llama_sampling_init(sparams);
|
||||
}
|
||||
|
||||
// C-linkage forwarder to llama_sampling_free.
void llama_sampling_cfree(struct llama_sampling_context *ctx) {
    llama_sampling_free(ctx);
}
|
||||
|
||||
// C-linkage forwarder to llama_sampling_reset.
void llama_sampling_creset(struct llama_sampling_context *ctx) {
    llama_sampling_reset(ctx);
}
|
||||
|
||||
// C-linkage forwarder to llama_sampling_sample; all arguments are passed
// through unchanged and the sampled token is returned.
llama_token llama_sampling_csample(struct llama_sampling_context *ctx_sampling,
                                   struct llama_context *ctx_main,
                                   struct llama_context *ctx_cfg,
                                   int idx) {
    const llama_token sampled = llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx);
    return sampled;
}
|
||||
|
||||
void llama_sampling_caccept(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
llama_token id,
|
||||
bool apply_grammar) {
|
||||
llama_sampling_accept(ctx_sampling, ctx_main, id, apply_grammar);
|
||||
}
|
47
llama/sampling_ext.h
vendored
Normal file
47
llama/sampling_ext.h
vendored
Normal file
@ -0,0 +1,47 @@
|
||||
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
||||
#ifndef LLAMA_SAMPLING_EXT_H
|
||||
#define LLAMA_SAMPLING_EXT_H
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct llama_sampling_cparams {
|
||||
int32_t top_k;
|
||||
float top_p;
|
||||
float tfs_z;
|
||||
float typical_p;
|
||||
float temp;
|
||||
float penalty_repeat;
|
||||
float penalty_freq;
|
||||
float penalty_present;
|
||||
int32_t mirostat;
|
||||
float mirostat_tau;
|
||||
float mirostat_eta;
|
||||
bool penalize_nl;
|
||||
uint32_t seed;
|
||||
};
|
||||
|
||||
struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params);
|
||||
void llama_sampling_cfree(struct llama_sampling_context * ctx);
|
||||
void llama_sampling_creset(struct llama_sampling_context * ctx);
|
||||
|
||||
llama_token llama_sampling_csample(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
struct llama_context * ctx_cfg,
|
||||
int idx);
|
||||
|
||||
void llama_sampling_caccept(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
llama_token id,
|
||||
bool apply_grammar);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // LLAMA_SAMPLING_EXT_H
|
Loading…
x
Reference in New Issue
Block a user