diff --git a/llama/llama.go b/llama/llama.go index b169cf51..afd17a24 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -390,6 +390,7 @@ type SamplingParams struct { TfsZ float32 TypicalP float32 Temp float32 + RepeatLastN int PenaltyRepeat float32 PenaltyFreq float32 PenaltyPresent float32 @@ -408,6 +409,7 @@ func NewSamplingContext(params SamplingParams) *SamplingContext { cparams.tfs_z = C.float(params.TfsZ) cparams.typical_p = C.float(params.TypicalP) cparams.temp = C.float(params.Temp) + cparams.penalty_last_n = C.int32_t(params.RepeatLastN) cparams.penalty_repeat = C.float(params.PenaltyRepeat) cparams.penalty_freq = C.float(params.PenaltyFreq) cparams.penalty_present = C.float(params.PenaltyFreq) diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 1491de3e..75c79b41 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -402,6 +402,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { samplingParams.TfsZ = req.TFSZ samplingParams.TypicalP = req.TypicalP samplingParams.Temp = req.Temperature + samplingParams.RepeatLastN = req.RepeatLastN samplingParams.PenaltyRepeat = req.RepeatPenalty samplingParams.PenaltyFreq = req.FrequencyPenalty samplingParams.PenaltyPresent = req.PresencePenalty diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp index 9928a739..4b7fcc97 100644 --- a/llama/sampling_ext.cpp +++ b/llama/sampling_ext.cpp @@ -10,6 +10,7 @@ struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparam sparams.tfs_z = params->tfs_z; sparams.typical_p = params->typical_p; sparams.temp = params->temp; + sparams.penalty_last_n = params->penalty_last_n; sparams.penalty_repeat = params->penalty_repeat; sparams.penalty_freq = params->penalty_freq; sparams.penalty_present = params->penalty_present; diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h index baf1a0a9..16dd2398 100644 --- a/llama/sampling_ext.h +++ b/llama/sampling_ext.h @@ -16,6 +16,7 @@ extern "C" float tfs_z; float typical_p; float temp; + int32_t penalty_last_n; float penalty_repeat; float penalty_freq; float penalty_present;