From 80d065658df96c9e4b03ffdcdb8eb0890355ca5d Mon Sep 17 00:00:00 2001
From: Yap Sok Ann
Date: Thu, 18 Jul 2024 08:57:33 +0700
Subject: [PATCH] Make llama.cpp's cache_prompt parameter configurable

Exposing cache_prompt lets callers disable llama.cpp's prompt caching,
which makes the output deterministic when the same seed and temperature
are set. The default remains true (caching enabled), matching the
previous behavior.

Fixes #5321
---
 api/types.go          | 2 ++
 cmd/interactive.go    | 1 +
 llm/server.go         | 2 +-
 server/routes_test.go | 3 ++-
 4 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/api/types.go b/api/types.go
index c7e9dce3..0c31387c 100644
--- a/api/types.go
+++ b/api/types.go
@@ -221,6 +221,7 @@ type Options struct {
 	MirostatEta     float32  `json:"mirostat_eta,omitempty"`
 	PenalizeNewline bool     `json:"penalize_newline,omitempty"`
 	Stop            []string `json:"stop,omitempty"`
+	CachePrompt     bool     `json:"cache_prompt,omitempty"`
 }
 
 // Runner options which must be set when the model is loaded into memory
@@ -594,6 +595,7 @@ func DefaultOptions() Options {
 		MirostatEta:     0.1,
 		PenalizeNewline: true,
 		Seed:            -1,
+		CachePrompt:     true,
 
 		Runner: Runner{
 			// options set when the model is loaded
diff --git a/cmd/interactive.go b/cmd/interactive.go
index adbc3e9f..2f0feb49 100644
--- a/cmd/interactive.go
+++ b/cmd/interactive.go
@@ -144,6 +144,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		fmt.Fprintln(os.Stderr, "  /set parameter repeat_last_n <int>   Set how far back to look for repetitions")
 		fmt.Fprintln(os.Stderr, "  /set parameter num_gpu <int>         The number of layers to send to the GPU")
 		fmt.Fprintln(os.Stderr, "  /set parameter stop <string> <string> ...   Set the stop parameters")
+		fmt.Fprintln(os.Stderr, "  /set parameter cache_prompt <bool>   Set the cache_prompt parameter of llama.cpp")
 		fmt.Fprintln(os.Stderr, "")
 	}
 
diff --git a/llm/server.go b/llm/server.go
index 36c0e0b5..7d0b10cb 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -734,7 +734,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		"seed":         req.Options.Seed,
 		"stop":         req.Options.Stop,
 		"image_data":   req.Images,
-		"cache_prompt": true,
+		"cache_prompt": req.Options.CachePrompt,
 	}
 
 	// Make sure the server is ready
diff --git a/server/routes_test.go b/server/routes_test.go
index 97786ba2..9237ed60 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -63,7 +63,7 @@ func Test_Routes(t *testing.T) {
 
 	fname := createTestFile(t, "ollama-model")
 
-	r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
+	r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar\nPARAMETER cache_prompt false", fname))
 	modelfile, err := parser.ParseFile(r)
 	require.NoError(t, err)
 	fn := func(resp api.ProgressResponse) {
@@ -246,6 +246,7 @@ func Test_Routes(t *testing.T) {
 			}
 			sort.Strings(params)
 			expectedParams := []string{
+				"cache_prompt false",
 				"seed 42",
 				"stop \"bar\"",
 				"stop \"foo\"",
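
With this patch applied, a client can turn prompt caching off per request
through the options map. The following is a usage sketch (not part of the
patch) using the ollama Go API client; the model name and prompt are
illustrative. With "cache_prompt": false, repeated calls with the same seed
and temperature should return identical output.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Connect to a local ollama server using OLLAMA_HOST or the default address.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "llama3", // illustrative; any locally pulled model works
		Prompt: "Why is the sky blue?",
		Options: map[string]interface{}{
			"seed":         42,
			"temperature":  0.8,
			"cache_prompt": false, // disable llama.cpp's prompt caching for determinism
		},
	}

	// Stream the response tokens to stdout; running this twice with the
	// same options should produce identical output.
	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}

The same option is also available declaratively, as the patch itself shows:
"PARAMETER cache_prompt false" in a Modelfile (exercised by the routes test
above), or "/set parameter cache_prompt false" in an interactive "ollama run"
session.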