Compare commits

...

1 Commits

Author SHA1 Message Date
Bruce MacDonald
e94e5b1771 fix: retry on concurrent request failure 2023-12-08 17:33:41 -08:00

View File

@@ -535,6 +535,8 @@ type prediction struct {
} }
const maxBufferSize = 512 * format.KiloByte const maxBufferSize = 512 * format.KiloByte
const maxRetries = 3
const retryDelay = 1 * time.Second
type PredictOpts struct { type PredictOpts struct {
Model string Model string
@@ -557,6 +559,11 @@ type PredictResult struct {
EvalDuration time.Duration EvalDuration time.Duration
} }
// isRetryable checks if the line matches a condition that can be retried
func isRetryable(line []byte) bool {
return bytes.Contains(line, []byte("slot unavailable"))
}
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error { func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
request := map[string]any{ request := map[string]any{
"prompt": predict.Prompt, "prompt": predict.Prompt,
@@ -585,6 +592,9 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
request["grammar"] = jsonGrammar request["grammar"] = jsonGrammar
} }
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
for retries := 0; retries < maxRetries; retries++ {
// Handling JSON marshaling with special characters unescaped. // Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer) enc := json.NewEncoder(buffer)
@@ -594,7 +604,9 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
return fmt.Errorf("failed to marshal data: %v", err) return fmt.Errorf("failed to marshal data: %v", err)
} }
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port) if retries > 0 {
time.Sleep(retryDelay) // wait before retrying
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer) req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil { if err != nil {
return fmt.Errorf("error creating POST request: %v", err) return fmt.Errorf("error creating POST request: %v", err)
@@ -620,6 +632,8 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
// increase the buffer size to avoid running out of space // increase the buffer size to avoid running out of space
buf := make([]byte, 0, maxBufferSize) buf := make([]byte, 0, maxBufferSize)
scanner.Buffer(buf, maxBufferSize) scanner.Buffer(buf, maxBufferSize)
retryNeeded := false
for scanner.Scan() { for scanner.Scan() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -631,7 +645,16 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
continue continue
} }
if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok { if isRetryable(line) {
retryNeeded = true
break
}
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
return fmt.Errorf("error parsing llm response stream: %s", line)
}
var p prediction var p prediction
if err := json.Unmarshal(evt, &p); err != nil { if err := json.Unmarshal(evt, &p); err != nil {
return fmt.Errorf("error unmarshaling llm prediction response: %v", err) return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
@@ -661,7 +684,6 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
} }
} }
} }
}
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "unexpected EOF") { if strings.Contains(err.Error(), "unexpected EOF") {
@@ -675,8 +697,15 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
return fmt.Errorf("error reading llm response: %v", err) return fmt.Errorf("error reading llm response: %v", err)
} }
if !retryNeeded {
// success
return nil return nil
} }
}
// should never reach here ideally
return fmt.Errorf("max retries exceeded")
}
type TokenizeRequest struct { type TokenizeRequest struct {
Content string `json:"content"` Content string `json:"content"`