diff --git a/llama/runner/runner.go b/llama/runner/runner.go index c2d81e23..97129fef 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -121,65 +121,6 @@ func (s *Server) allNil() bool { return true } -func findStop(sequence string, stops []string) (bool, string) { - for _, stop := range stops { - if strings.Contains(sequence, stop) { - return true, stop - } - } - - return false, "" -} - -func containsStopSuffix(sequence string, stops []string) bool { - for _, stop := range stops { - for i := 1; i <= len(stop); i++ { - if strings.HasSuffix(sequence, stop[:i]) { - return true - } - } - } - - return false -} - -// truncateStop removes the provided stop string from pieces, -// returning the partial pieces with stop removed, including truncating -// the last piece if required -func truncateStop(pieces []string, stop string) []string { - joined := strings.Join(pieces, "") - - index := strings.Index(joined, stop) - if index == -1 { - return pieces - } - - joined = joined[:index] - - // Split truncated string back into pieces of original lengths - lengths := make([]int, len(pieces)) - for i, piece := range pieces { - lengths[i] = len(piece) - } - - var result []string - start := 0 - for _, length := range lengths { - if start >= len(joined) { - break - } - - end := start + length - if end > len(joined) { - end = len(joined) - } - result = append(result, joined[start:end]) - start = end - } - - return result -} - func (s *Server) run(ctx context.Context) { batch := llama.NewBatch(s.batchSize, 0, s.parallel) defer batch.Free() diff --git a/llama/runner/stop.go b/llama/runner/stop.go new file mode 100644 index 00000000..b593a904 --- /dev/null +++ b/llama/runner/stop.go @@ -0,0 +1,64 @@ +package main + +import ( + "strings" +) + +func findStop(sequence string, stops []string) (bool, string) { + for _, stop := range stops { + if strings.Contains(sequence, stop) { + return true, stop + } + } + + return false, "" +} + +func containsStopSuffix(sequence string, stops []string) bool { + for _, stop := range stops { + for i := 1; i <= len(stop); i++ { + if strings.HasSuffix(sequence, stop[:i]) { + return true + } + } + } + + return false +} + +// truncateStop removes the provided stop string from pieces, +// returning the partial pieces with stop removed, including truncating +// the last piece if required +func truncateStop(pieces []string, stop string) []string { + joined := strings.Join(pieces, "") + + index := strings.Index(joined, stop) + if index == -1 { + return pieces + } + + joined = joined[:index] + + // Split truncated string back into pieces of original lengths + lengths := make([]int, len(pieces)) + for i, piece := range pieces { + lengths[i] = len(piece) + } + + var result []string + start := 0 + for _, length := range lengths { + if start >= len(joined) { + break + } + + end := start + length + if end > len(joined) { + end = len(joined) + } + result = append(result, joined[start:end]) + start = end + } + + return result +} diff --git a/llama/runner/runner_test.go b/llama/runner/stop_test.go similarity index 100% rename from llama/runner/runner_test.go rename to llama/runner/stop_test.go