From 76718ead40505ba7d583e0b598f7626a51ff8daa Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 21 Aug 2024 16:13:54 -0700 Subject: [PATCH] runner.go: Support MinP parameter MinP is a user-facing parameter that is exposed through the APIs but is not currently plumbed through. --- llama/llama.go | 2 ++ llama/runner/runner.go | 1 + llama/sampling_ext.cpp | 1 + llama/sampling_ext.h | 1 + 4 files changed, 5 insertions(+) diff --git a/llama/llama.go b/llama/llama.go index afd17a24..1315fbe2 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -387,6 +387,7 @@ type SamplingContext struct { type SamplingParams struct { TopK int TopP float32 + MinP float32 TfsZ float32 TypicalP float32 Temp float32 @@ -406,6 +407,7 @@ func NewSamplingContext(params SamplingParams) *SamplingContext { var cparams C.struct_llama_sampling_cparams cparams.top_k = C.int32_t(params.TopK) cparams.top_p = C.float(params.TopP) + cparams.min_p = C.float(params.MinP) cparams.tfs_z = C.float(params.TfsZ) cparams.typical_p = C.float(params.TypicalP) cparams.temp = C.float(params.Temp) diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 3e54229f..07fa5f06 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -434,6 +434,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { var samplingParams llama.SamplingParams samplingParams.TopK = req.TopK samplingParams.TopP = req.TopP + samplingParams.MinP = req.MinP samplingParams.TfsZ = req.TFSZ samplingParams.TypicalP = req.TypicalP samplingParams.Temp = req.Temperature diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp index 4b7fcc97..da92cedf 100644 --- a/llama/sampling_ext.cpp +++ b/llama/sampling_ext.cpp @@ -7,6 +7,7 @@ struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparam llama_sampling_params sparams; sparams.top_k = params->top_k; sparams.top_p = params->top_p; + sparams.min_p = params->min_p; sparams.tfs_z = params->tfs_z; sparams.typical_p = 
params->typical_p; sparams.temp = params->temp; diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h index 16dd2398..588ed5c1 100644 --- a/llama/sampling_ext.h +++ b/llama/sampling_ext.h @@ -13,6 +13,7 @@ extern "C" { int32_t top_k; float top_p; + float min_p; float tfs_z; float typical_p; float temp;