diff --git a/docs/modelfile.md b/docs/modelfile.md
index 0ee434eb..0f59905b 100644
--- a/docs/modelfile.md
+++ b/docs/modelfile.md
@@ -123,7 +123,7 @@ PARAMETER
 | repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
 | repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
 | temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
-| stop | Sets the stop tokens to use. | string | stop "AI assistant:" |
+| stop | Sets the stop sequences to use. | string | stop "AI assistant:" |
 | tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
 | top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
 | top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
diff --git a/llm/llama.go b/llm/llama.go
index ce697b33..353ec47f 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -334,20 +334,18 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
 		b.WriteString(llm.Decode(int(token)))
 
-		if err := llm.checkStopConditions(b); err != nil {
-			if errors.Is(err, io.EOF) {
-				break
-			} else if errors.Is(err, errNeedMoreData) {
-				continue
-			}
-
-			return err
+		stop, endsWithStopPrefix := handleStopSequences(&b, llm.Stop)
+		if endsWithStopPrefix {
+			continue
 		}
 
 		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
 			fn(api.GenerateResponse{Response: b.String()})
 			b.Reset()
 		}
 
+		if stop {
+			break
+		}
 	}
 
 	embd := make([]int, len(llm.embd))
@@ -370,16 +368,31 @@
 	return nil
 }
 
-func (llm *llama) checkStopConditions(b bytes.Buffer) error {
-	for _, stopCondition := range llm.Stop {
-		if stopCondition == strings.TrimSpace(b.String()) {
-			return io.EOF
-		} else if strings.HasPrefix(stopCondition, strings.TrimSpace(b.String())) {
-			return errNeedMoreData
+// handleStopSequences checks whether b contains any of the stop sequences, or ends with a prefix of
+// any stop sequence (and therefore might contain data that should not ultimately be returned to the
+// client).
+//
+// If b contains a stop sequence, it modifies b to remove the stop sequence and all subsequent data.
+func handleStopSequences(b *bytes.Buffer, stopSequences []string) (stop bool, endsWithStopPrefix bool) {
+	s := b.String()
+	for _, seq := range stopSequences {
+		// Check for an exact or substring match.
+		if i := strings.Index(s, seq); i != -1 {
+			b.Truncate(i)
+			return true, false
+		}
+
+		// Check if b ends with a prefix of the stop sequence.
+		if len(seq) > 1 {
+			for i := 1; i < len(seq); i++ {
+				if strings.HasSuffix(s, seq[:i]) {
+					return false, true
+				}
+			}
 		}
 	}
-	return nil
+	return false, false
 }
 
 func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
diff --git a/llm/llama_test.go b/llm/llama_test.go
new file mode 100644
index 00000000..536edb92
--- /dev/null
+++ b/llm/llama_test.go
@@ -0,0 +1,79 @@
+package llm
+
+import (
+	"bytes"
+	"testing"
+)
+
+func TestHandleStopSequences(t *testing.T) {
+	tests := map[string]struct {
+		b                      string
+		stop                   []string
+		wantB                  string
+		wantStop               bool
+		wantEndsWithStopPrefix bool
+	}{
+		"not present": {
+			b:                      "abc",
+			stop:                   []string{"x"},
+			wantStop:               false,
+			wantEndsWithStopPrefix: false,
+		},
+		"exact": {
+			b:                      "abc",
+			stop:                   []string{"abc"},
+			wantStop:               true,
+			wantEndsWithStopPrefix: false,
+		},
+		"substring": {
+			b:                      "abc",
+			stop:                   []string{"b"},
+			wantB:                  "a",
+			wantStop:               true,
+			wantEndsWithStopPrefix: false,
+		},
+		"prefix 1": {
+			b:                      "abc",
+			stop:                   []string{"abcd"},
+			wantStop:               false,
+			wantEndsWithStopPrefix: true,
+		},
+		"prefix 2": {
+			b:                      "abc",
+			stop:                   []string{"bcd"},
+			wantStop:               false,
+			wantEndsWithStopPrefix: true,
+		},
+		"prefix 3": {
+			b:                      "abc",
+			stop:                   []string{"cd"},
+			wantStop:               false,
+			wantEndsWithStopPrefix: true,
+		},
+		"no prefix": {
+			b:                      "abc",
+			stop:                   []string{"bx"},
+			wantStop:               false,
+			wantEndsWithStopPrefix: false,
+		},
+	}
+	for name, test := range tests {
+		t.Run(name, func(t *testing.T) {
+			var b bytes.Buffer
+			b.WriteString(test.b)
+			stop, endsWithStopPrefix := handleStopSequences(&b, test.stop)
+			if test.wantB != "" {
+				gotB := b.String()
+				if gotB != test.wantB {
+					t.Errorf("got b %q, want %q", gotB, test.wantB)
+				}
+			}
+			if stop != test.wantStop {
+				t.Errorf("got stop %v, want %v", stop, test.wantStop)
+			}
+			if endsWithStopPrefix != test.wantEndsWithStopPrefix {
+				t.Errorf("got endsWithStopPrefix %v, want %v", endsWithStopPrefix, test.wantEndsWithStopPrefix)
+			}
+		})
+	}
+}
diff --git a/server/images.go b/server/images.go
index 9f457ecb..ff06e04b 100644
--- a/server/images.go
+++ b/server/images.go
@@ -430,7 +430,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 			layer.MediaType = mediaType
 			layers = append(layers, layer)
 		default:
-			// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
+			// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
			params[c.Name] = append(params[c.Name], c.Args)
 		}
 	}
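
For reviewers, the intended streaming behavior may be easier to see end to end. The sketch below is not part of the patch: the chunk values and the "USER:" stop sequence are made up for illustration, and it omits the UTF-8 validity guard that the real Predict loop applies before flushing. It reproduces handleStopSequences verbatim (the real function is unexported in package llm) and drives it with simulated decoded-token chunks, including a stop sequence that arrives split across two chunks.

package main

import (
	"bytes"
	"fmt"
	"strings"
)

// handleStopSequences, copied from the patch so this example compiles
// on its own.
func handleStopSequences(b *bytes.Buffer, stopSequences []string) (stop bool, endsWithStopPrefix bool) {
	s := b.String()
	for _, seq := range stopSequences {
		// Check for an exact or substring match.
		if i := strings.Index(s, seq); i != -1 {
			b.Truncate(i)
			return true, false
		}

		// Check if b ends with a prefix of the stop sequence.
		if len(seq) > 1 {
			for i := 1; i < len(seq); i++ {
				if strings.HasSuffix(s, seq[:i]) {
					return false, true
				}
			}
		}
	}
	return false, false
}

func main() {
	// Hypothetical decoded token chunks; note how the stop sequence
	// "USER:" arrives split across the "US" and "ER:" chunks.
	chunks := []string{"Sure", ", here", " you go.", "US", "ER:", " ignored"}

	var b bytes.Buffer
	for _, chunk := range chunks {
		b.WriteString(chunk)

		stop, endsWithStopPrefix := handleStopSequences(&b, []string{"USER:"})
		if endsWithStopPrefix {
			// b may end mid-way through a stop sequence, so hold the
			// data back until more tokens arrive.
			continue
		}

		// Safe to flush: b cannot contain any part of a stop sequence.
		fmt.Print(b.String())
		b.Reset()

		if stop {
			break
		}
	}
	fmt.Println()
	// Prints "Sure, here you go." -- the stop sequence and everything
	// after it are never emitted to the client.
}

This is also why the new code matches with strings.Index rather than comparing whole buffers: a stop sequence can begin in the middle of a flushed-and-buffered run of tokens, and truncating at the match index drops exactly the unwanted tail.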