truncate stop properly

This commit is contained in:
jmorganca 2024-05-27 23:09:56 -07:00
parent a379d68aa9
commit 72f3fe4b94
2 changed files with 102 additions and 25 deletions

View File

@ -94,7 +94,7 @@ func (s *Server) allNil() bool {
return true
}
func contains(sequence string, stops []string) (bool, string) {
func findStop(sequence string, stops []string) (bool, string) {
for _, stop := range stops {
if strings.Contains(sequence, stop) {
return true, stop
@ -104,9 +104,9 @@ func contains(sequence string, stops []string) (bool, string) {
return false, ""
}
func overlap(sequence string, stops []string) bool {
func containsStopSuffix(sequence string, stops []string) bool {
for _, stop := range stops {
for i := 1; i < len(stop); i++ {
for i := 1; i <= len(stop); i++ {
if strings.HasSuffix(sequence, stop[:i]) {
return true
}
@ -116,13 +116,50 @@ func overlap(sequence string, stops []string) bool {
return false
}
// truncateStop cuts the accumulated output off at the first occurrence of stop.
//
// The pieces are concatenated, everything from the first match of stop onward
// is discarded, and the surviving text is handed back split at the same
// boundaries the input pieces had. The piece in which the cut lands is
// shortened, and pieces that would begin at or after the cut are omitted, so
// the result is nil when nothing precedes stop. When stop does not occur in
// the concatenation, pieces is returned as-is.
func truncateStop(pieces []string, stop string) []string {
	whole := strings.Join(pieces, "")
	cut := strings.Index(whole, stop)
	if cut < 0 {
		return pieces
	}
	kept := whole[:cut]

	// Walk the original pieces, carving kept at the same offsets they
	// occupied in the concatenation.
	var out []string
	for pos, k := 0, 0; k < len(pieces) && pos < len(kept); k++ {
		next := pos + len(pieces[k])
		if next > len(kept) {
			// the cut falls inside this piece: keep only its head
			next = len(kept)
		}
		out = append(out, kept[pos:next])
		pos = next
	}
	return out
}
func (s *Server) run(ctx context.Context) {
batch := llama.NewBatch(512, 0, s.parallel)
defer batch.Free()
// build up stop sequences as we recognize them
// TODO (jmorganca): simplify this
sofar := make([][]string, s.parallel)
pieces := make([][]string, s.parallel)
for {
select {
@ -214,50 +251,41 @@ func (s *Server) run(ctx context.Context) {
close(seq.responses)
seq.samplingCtx.Free()
sofar[i] = []string{}
pieces[i] = []string{}
s.seqs[i] = nil
continue
}
seq.tokens = []int{token}
// recognize stop sequences
// TODO (jmorganca): add tests around this
// TODO (jmorganca): send back partial piece
sequence := strings.Join(append(sofar[i], piece), "")
if ok, stop := contains(sequence, seq.stop); ok {
pieces[i] = append(pieces[i], piece)
sequence := strings.Join(pieces[i], "")
if ok, stop := findStop(sequence, seq.stop); ok {
slog.Info("hit stop token", "stop", seq.stop)
for _, p := range sofar[i] {
truncated := truncateStop(pieces[i], stop)
for _, p := range truncated {
seq.responses <- p
}
piece, _, _ := strings.Cut(piece, stop)
seq.responses <- piece
s.lc.KvCacheSeqRm(i, 0, -1)
close(seq.responses)
seq.samplingCtx.Free()
sofar[i] = []string{}
pieces[i] = []string{}
s.seqs[i] = nil
continue
}
if overlap(sequence, seq.stop) {
slog.Info("overlap", "sequence", sequence)
// partial stop, don't send
if containsStopSuffix(sequence, seq.stop) {
continue
}
slog.Info("sending", "sofar", sofar[i])
sofar[i] = append(sofar[i], piece)
for _, p := range sofar[i] {
for _, p := range pieces[i] {
seq.responses <- p
}
sofar[i] = []string{}
pieces[i] = []string{}
}
batch.Clear()

View File

@ -0,0 +1,49 @@
package main
import (
"reflect"
"testing"
)
// TestTruncateStop checks that truncateStop drops everything from the first
// occurrence of the stop string onward while keeping the surviving text split
// at the original piece boundaries.
func TestTruncateStop(t *testing.T) {
	tests := []struct {
		name     string
		pieces   []string
		stop     string
		expected []string
	}{
		{
			name:     "Single word",
			pieces:   []string{"hello", "world"},
			stop:     "world",
			expected: []string{"hello"},
		},
		{
			name:     "Partial",
			pieces:   []string{"hello", "wor"},
			stop:     "or",
			expected: []string{"hello", "w"},
		},
		{
			name:     "Suffix",
			pieces:   []string{"Hello", " there", "!"},
			stop:     "!",
			expected: []string{"Hello", " there"},
		},
		{
			name:     "Middle",
			pieces:   []string{"hello", " wor"},
			stop:     "llo w",
			expected: []string{"he"},
		},
		{
			// stop straddles a piece boundary with text surviving before it
			name:     "Spans pieces",
			pieces:   []string{"a", "bc", "de", "f"},
			stop:     "cd",
			expected: []string{"a", "b"},
		},
		{
			// stop absent: the input is returned untouched
			name:     "No match",
			pieces:   []string{"hello", "world"},
			stop:     "xyz",
			expected: []string{"hello", "world"},
		},
		{
			// nothing precedes the stop, so no pieces survive; the result is
			// a nil slice, which DeepEqual distinguishes from []string{}
			name:     "Stop at start",
			pieces:   []string{"stop", " rest"},
			stop:     "stop",
			expected: nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := truncateStop(tt.pieces, tt.stop)
			if !reflect.DeepEqual(result, tt.expected) {
				// %q makes whitespace-only or empty stops visible in failures
				t.Errorf("truncateStop(%v, %q): have %v; want %v", tt.pieces, tt.stop, result, tt.expected)
			}
		})
	}
}