From 0077e22d524edbad949002216f2ba6206aacb1b5 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 9 Oct 2024 16:12:23 -0700 Subject: [PATCH] runner.go: Handle truncation of tokens for stop sequences When a single token contains both text to be returned and a stop sequence, this causes an out of bounds error when we update the cache to match our text. This is because we currently assume that removing the stop sequence will consume at least one token. This also inverts the logic to deal with positive numbers, rather than a value to be subtracted, which is easier to reason about. Fixes #7153 --- llama/runner/runner.go | 25 +++++++++++++---- llama/runner/stop.go | 10 ++++--- llama/runner/stop_test.go | 58 +++++++++++++++++++++++---------------- 3 files changed, 60 insertions(+), 33 deletions(-) diff --git a/llama/runner/runner.go b/llama/runner/runner.go index ffbea9e9..bf799d37 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -451,14 +451,27 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) sequence := strings.Join(seq.pendingResponses, "") if ok, stop := findStop(sequence, seq.stop); ok { - slog.Debug("hit stop token", "stop", seq.stop) + slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) - trimCacheLen := len(seq.pendingResponses) - 1 - seq.pendingResponses = truncateStop(seq.pendingResponses, stop) - trimCacheLen -= len(seq.pendingResponses) + var tokenTruncated bool + origLen := len(seq.pendingResponses) + seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop) + newLen := len(seq.pendingResponses) + + // Update the cache based on the tokens that will be returned: + // - We have 1 token more than is currently in the cache because + // the last one generated wasn't submitted to Decode + // - Remove any stop sequences that we stripped out + // - If truncateStop removed a portion of a token, drop that + // - As defense-in-depth, if truncateStop didn't find a 
stop token + // remove the extra one that we added to the cache len + tokenLen := len(seq.cache.Inputs) + 1 + tokenLen -= origLen - newLen + if tokenTruncated || origLen == newLen { + tokenLen-- + } + seq.cache.Inputs = seq.cache.Inputs[:tokenLen] - // remove any tokens from the cache that we don't actually return - seq.cache.Inputs = seq.cache.Inputs[:len(seq.cache.Inputs)-trimCacheLen] s.removeSequence(i, "stop") continue } diff --git a/llama/runner/stop.go b/llama/runner/stop.go index ece06c21..c05f5e3d 100644 --- a/llama/runner/stop.go +++ b/llama/runner/stop.go @@ -28,13 +28,13 @@ func containsStopSuffix(sequence string, stops []string) bool { // truncateStop removes the provided stop string from pieces, // returning the partial pieces with stop removed, including truncating -// the last piece if required -func truncateStop(pieces []string, stop string) []string { +// the last piece if required (and signalling if this was the case) +func truncateStop(pieces []string, stop string) ([]string, bool) { joined := strings.Join(pieces, "") index := strings.Index(joined, stop) if index == -1 { - return pieces + return pieces, false } joined = joined[:index] @@ -46,6 +46,7 @@ func truncateStop(pieces []string, stop string) []string { } var result []string + tokenTruncated := false start := 0 for _, length := range lengths { if start >= len(joined) { @@ -55,12 +56,13 @@ func truncateStop(pieces []string, stop string) []string { end := start + length if end > len(joined) { end = len(joined) + tokenTruncated = true } result = append(result, joined[start:end]) start = end } - return result + return result, tokenTruncated } func incompleteUnicode(token string) bool { diff --git a/llama/runner/stop_test.go b/llama/runner/stop_test.go index 14553987..51b35fde 100644 --- a/llama/runner/stop_test.go +++ b/llama/runner/stop_test.go @@ -7,42 +7,54 @@ import ( func TestTruncateStop(t *testing.T) { tests := []struct { - name string - pieces []string - stop string - expected 
[]string + name string + pieces []string + stop string + expected []string + expectedTrunc bool }{ { - name: "Single word", - pieces: []string{"hello", "world"}, - stop: "world", - expected: []string{"hello"}, + name: "Single word", + pieces: []string{"hello", "world"}, + stop: "world", + expected: []string{"hello"}, + expectedTrunc: false, }, { - name: "Partial", - pieces: []string{"hello", "wor"}, - stop: "or", - expected: []string{"hello", "w"}, + name: "Partial", + pieces: []string{"hello", "wor"}, + stop: "or", + expected: []string{"hello", "w"}, + expectedTrunc: true, }, { - name: "Suffix", - pieces: []string{"Hello", " there", "!"}, - stop: "!", - expected: []string{"Hello", " there"}, + name: "Suffix", + pieces: []string{"Hello", " there", "!"}, + stop: "!", + expected: []string{"Hello", " there"}, + expectedTrunc: false, }, { - name: "Middle", - pieces: []string{"hello", " wor"}, - stop: "llo w", - expected: []string{"he"}, + name: "Suffix partial", + pieces: []string{"Hello", " the", "re!"}, + stop: "there!", + expected: []string{"Hello", " "}, + expectedTrunc: true, + }, + { + name: "Middle", + pieces: []string{"hello", " wor"}, + stop: "llo w", + expected: []string{"he"}, + expectedTrunc: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := truncateStop(tt.pieces, tt.stop) - if !reflect.DeepEqual(result, tt.expected) { - t.Errorf("truncateStop(%v, %s): have %v; want %v", tt.pieces, tt.stop, result, tt.expected) + result, resultTrunc := truncateStop(tt.pieces, tt.stop) + if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { + t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) } }) }