diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 0d34febf..54210a49 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -94,7 +94,7 @@ func (s *Server) allNil() bool { return true } -func contains(sequence string, stops []string) (bool, string) { +func findStop(sequence string, stops []string) (bool, string) { for _, stop := range stops { if strings.Contains(sequence, stop) { return true, stop @@ -104,9 +104,9 @@ func contains(sequence string, stops []string) (bool, string) { return false, "" } -func overlap(sequence string, stops []string) bool { +func containsStopSuffix(sequence string, stops []string) bool { for _, stop := range stops { - for i := 1; i < len(stop); i++ { + for i := 1; i <= len(stop); i++ { if strings.HasSuffix(sequence, stop[:i]) { return true } @@ -116,13 +116,50 @@ func overlap(sequence string, stops []string) bool { return false } +// truncateStop removes the provided stop string from pieces, +// returning the partial pieces with stop removed, including truncating +// the last piece if required +func truncateStop(pieces []string, stop string) []string { + joined := strings.Join(pieces, "") + + index := strings.Index(joined, stop) + if index == -1 { + return pieces + } + + joined = joined[:index] + + // Split truncated string back into pieces of original lengths + lengths := make([]int, len(pieces)) + for i, piece := range pieces { + lengths[i] = len(piece) + } + + var result []string + start := 0 + for _, length := range lengths { + if start >= len(joined) { + break + } + + end := start + length + if end > len(joined) { + end = len(joined) + } + result = append(result, joined[start:end]) + start = end + } + + return result +} + func (s *Server) run(ctx context.Context) { batch := llama.NewBatch(512, 0, s.parallel) defer batch.Free() // build up stop sequences as we recognize them // TODO (jmorganca): simplify this - sofar := make([][]string, s.parallel) + pieces := make([][]string, s.parallel) for { select { @@ -214,50 +251,41 @@ func (s *Server) run(ctx context.Context) { close(seq.responses) seq.samplingCtx.Free() - sofar[i] = []string{} + pieces[i] = []string{} s.seqs[i] = nil continue } seq.tokens = []int{token} - // recognize stop sequences - // TODO (jmorganca): add tests around this - // TODO (jmorganca): send back parital piece - - sequence := strings.Join(append(sofar[i], piece), "") - if ok, stop := contains(sequence, seq.stop); ok { + pieces[i] = append(pieces[i], piece) + sequence := strings.Join(pieces[i], "") + if ok, stop := findStop(sequence, seq.stop); ok { slog.Info("hit stop token", "stop", seq.stop) - for _, p := range sofar[i] { + + truncated := truncateStop(pieces[i], stop) + + for _, p := range truncated { seq.responses <- p } - piece, _, _ := strings.Cut(piece, stop) - seq.responses <- piece - s.lc.KvCacheSeqRm(i, 0, -1) close(seq.responses) seq.samplingCtx.Free() - sofar[i] = []string{} + pieces[i] = []string{} s.seqs[i] = nil continue } - if overlap(sequence, seq.stop) { - slog.Info("overlap", "sequence", sequence) - // partial stop, don't send + if containsStopSuffix(sequence, seq.stop) { continue } - slog.Info("sending", "sofar", sofar[i]) - - sofar[i] = append(sofar[i], piece) - - for _, p := range sofar[i] { + for _, p := range pieces[i] { seq.responses <- p } - sofar[i] = []string{} + pieces[i] = []string{} } batch.Clear() diff --git a/llama/runner/runner_test.go b/llama/runner/runner_test.go new file mode 100644 index 00000000..d7fd48ab --- /dev/null +++ b/llama/runner/runner_test.go @@ -0,0 +1,49 @@ +package main + +import ( + "reflect" + "testing" +) + +func TestTruncateStop(t *testing.T) { + tests := []struct { + name string + pieces []string + stop string + expected []string + }{ + { + name: "Single word", + pieces: []string{"hello", "world"}, + stop: "world", + expected: []string{"hello"}, + }, + { + name: "Partial", + pieces: []string{"hello", "wor"}, + stop: "or", + expected: []string{"hello", "w"}, + }, + { + name: "Suffix", + pieces: []string{"Hello", " there", "!"}, + stop: "!", + expected: []string{"Hello", " there"}, + }, + { + name: "Middle", + pieces: []string{"hello", " wor"}, + stop: "llo w", + expected: []string{"he"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncateStop(tt.pieces, tt.stop) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("truncateStop(%v, %s): have %v; want %v", tt.pieces, tt.stop, result, tt.expected) + } + }) + } +}