diff --git a/llm/llama.go b/llm/llama.go index b2f1571f..89049e48 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -535,6 +535,8 @@ type prediction struct { } const maxBufferSize = 512 * format.KiloByte +const maxRetries = 3 +const retryDelay = 1 * time.Second type PredictOpts struct { Model string @@ -557,6 +559,11 @@ type PredictResult struct { EvalDuration time.Duration } +// IsRetryable checks if the line matches a condition that can be retried +func isRetryable(line []byte) bool { + return bytes.Contains(line, []byte("slot unavailable")) +} + func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error { request := map[string]any{ "prompt": predict.Prompt, @@ -585,53 +592,69 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred request["grammar"] = jsonGrammar } - // Handling JSON marshaling with special characters unescaped. - buffer := &bytes.Buffer{} - enc := json.NewEncoder(buffer) - enc.SetEscapeHTML(false) - - if err := enc.Encode(request); err != nil { - return fmt.Errorf("failed to marshal data: %v", err) - } - endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer) - if err != nil { - return fmt.Errorf("error creating POST request: %v", err) - } - req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - return fmt.Errorf("POST predict: %v", err) - } - defer resp.Body.Close() + for retries := 0; retries < maxRetries; retries++ { + // Handling JSON marshaling with special characters unescaped. + buffer := &bytes.Buffer{} + enc := json.NewEncoder(buffer) + enc.SetEscapeHTML(false) - if resp.StatusCode >= 400 { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed reading llm error response: %w", err) + if err := enc.Encode(request); err != nil { + return fmt.Errorf("failed to marshal data: %v", err) } - log.Printf("llm predict error: %s", bodyBytes) - return fmt.Errorf("%s", bodyBytes) - } - scanner := bufio.NewScanner(resp.Body) - // increase the buffer size to avoid running out of space - buf := make([]byte, 0, maxBufferSize) - scanner.Buffer(buf, maxBufferSize) - for scanner.Scan() { - select { - case <-ctx.Done(): - // This handles the request cancellation - return ctx.Err() - default: - line := scanner.Bytes() - if len(line) == 0 { - continue + if retries > 0 { + time.Sleep(retryDelay) // wait before retrying + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer) + if err != nil { + return fmt.Errorf("error creating POST request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("POST predict: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed reading llm error response: %w", err) } + log.Printf("llm predict error: %s", bodyBytes) + return fmt.Errorf("%s", bodyBytes) + } + + scanner := bufio.NewScanner(resp.Body) + // increase the buffer size to avoid running out of space + buf := make([]byte, 0, maxBufferSize) + scanner.Buffer(buf, maxBufferSize) + + retryNeeded := false + for scanner.Scan() { + select { + case <-ctx.Done(): + // This handles the request cancellation + return ctx.Err() + default: + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + if isRetryable(line) { + retryNeeded = true + break + } + + evt, ok := bytes.CutPrefix(line, []byte("data: ")) + if !ok { + return fmt.Errorf("error parsing llm response stream: %s", line) + } - if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok { var p prediction if err := json.Unmarshal(evt, &p); err != nil { return fmt.Errorf("error unmarshaling llm prediction response: %v", err) @@ -661,21 +684,27 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred } } } - } - if err := scanner.Err(); err != nil { - if strings.Contains(err.Error(), "unexpected EOF") { - // this means the llama runner subprocess crashed - llm.Close() - if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" { - return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg) + if err := scanner.Err(); err != nil { + if strings.Contains(err.Error(), "unexpected EOF") { + // this means the llama runner subprocess crashed + llm.Close() + if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" { + return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg) + } + return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model") } - return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model") + return fmt.Errorf("error reading llm response: %v", err) + } + + if !retryNeeded { + // success + return nil } - return fmt.Errorf("error reading llm response: %v", err) } - return nil + // should never reach here ideally + return fmt.Errorf("max retries exceeded") } type TokenizeRequest struct {