diff --git a/llm/server.go b/llm/server.go
index 44bada08..db1b0e23 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -338,7 +338,7 @@ type ServerStatus int

 const ( // iota is reset to 0
 	ServerStatusReady ServerStatus = iota
-	ServerStatusNoSlotsAvaialble
+	ServerStatusNoSlotsAvailable
 	ServerStatusLoadingModel
 	ServerStatusNotResponding
 	ServerStatusError
@@ -348,7 +348,7 @@ func (s ServerStatus) ToString() string {
 	switch s {
 	case ServerStatusReady:
 		return "llm server ready"
-	case ServerStatusNoSlotsAvaialble:
+	case ServerStatusNoSlotsAvailable:
 		return "llm busy - no slots available"
 	case ServerStatusLoadingModel:
 		return "llm server loading model"
@@ -405,7 +405,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	case "ok":
 		return ServerStatusReady, nil
 	case "no slot available":
-		return ServerStatusNoSlotsAvaialble, nil
+		return ServerStatusNoSlotsAvailable, nil
 	case "loading model":
 		return ServerStatusLoadingModel, nil
 	default:
@@ -413,6 +413,29 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	}
 }

+// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
+func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
+	var retries int
+	for {
+		status, err := s.getServerStatus(ctx)
+		if err != nil {
+			return status, err
+		}
+
+		if status == ServerStatusNoSlotsAvailable {
+			if retries >= 10 {
+				return status, fmt.Errorf("no slots available after %d retries", retries)
+			}
+
+			time.Sleep(5 * time.Millisecond)
+			retries++
+			continue
+		}
+
+		return status, nil
+	}
+}
+
 func (s *llmServer) Ping(ctx context.Context) error {
 	_, err := s.getServerStatus(ctx)
 	if err != nil {
@@ -510,7 +533,6 @@ ws ::= ([ \t\n] ws)?
 `

 const maxBufferSize = 512 * format.KiloByte
-const maxRetries = 3

 type ImageData struct {
 	Data []byte `json:"data"`
@@ -586,7 +608,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	}

 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
@@ -600,133 +622,113 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		}
 	}

-	retryDelay := 100 * time.Microsecond
-	for retries := 0; retries < maxRetries; retries++ {
-		if retries > 0 {
-			time.Sleep(retryDelay) // wait before retrying
-			retryDelay *= 2        // exponential backoff
-		}
+	// Handling JSON marshaling with special characters unescaped.
+	buffer := &bytes.Buffer{}
+	enc := json.NewEncoder(buffer)
+	enc.SetEscapeHTML(false)

-		// Handling JSON marshaling with special characters unescaped.
-		buffer := &bytes.Buffer{}
-		enc := json.NewEncoder(buffer)
-		enc.SetEscapeHTML(false)
+	if err := enc.Encode(request); err != nil {
+		return fmt.Errorf("failed to marshal data: %v", err)
+	}

-		if err := enc.Encode(request); err != nil {
-			return fmt.Errorf("failed to marshal data: %v", err)
-		}
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
+	serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+	if err != nil {
+		return fmt.Errorf("error creating POST request: %v", err)
+	}
+	serverReq.Header.Set("Content-Type", "application/json")

-		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
-		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+	res, err := http.DefaultClient.Do(serverReq)
+	if err != nil {
+		return fmt.Errorf("POST predict: %v", err)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode >= 400 {
+		bodyBytes, err := io.ReadAll(res.Body)
 		if err != nil {
-			return fmt.Errorf("error creating POST request: %v", err)
+			return fmt.Errorf("failed reading llm error response: %w", err)
 		}
-		req.Header.Set("Content-Type", "application/json")
+		log.Printf("llm predict error: %s", bodyBytes)
+		return fmt.Errorf("%s", bodyBytes)
+	}

-		resp, err := http.DefaultClient.Do(req)
-		if err != nil {
-			return fmt.Errorf("POST predict: %v", err)
-		}
-		defer resp.Body.Close()
+	scanner := bufio.NewScanner(res.Body)
+	buf := make([]byte, 0, maxBufferSize)
+	scanner.Buffer(buf, maxBufferSize)

-		if resp.StatusCode >= 400 {
-			bodyBytes, err := io.ReadAll(resp.Body)
-			if err != nil {
-				return fmt.Errorf("failed reading llm error response: %w", err)
+	// keep track of the last token generated, this is used to abort if the model starts looping
+	var lastToken string
+	var tokenRepeat int
+
+	for scanner.Scan() {
+		select {
+		case <-ctx.Done():
+			// This handles the request cancellation
+			return ctx.Err()
+		default:
+			line := scanner.Bytes()
+			if len(line) == 0 {
+				continue
 			}
-			log.Printf("llm predict error: %s", bodyBytes)
-			return fmt.Errorf("%s", bodyBytes)
-		}
-
-		scanner := bufio.NewScanner(resp.Body)
-		buf := make([]byte, 0, maxBufferSize)
-		scanner.Buffer(buf, maxBufferSize)
+			evt, ok := bytes.CutPrefix(line, []byte("data: "))
+			if !ok {
+				return fmt.Errorf("error parsing llm response stream: %s", line)
+			}

-		retryNeeded := false
-		// keep track of the last token generated, this is used to abort if the model starts looping
-		var lastToken string
-		var tokenRepeat int
+			var c completion
+			if err := json.Unmarshal(evt, &c); err != nil {
+				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+			}

-		for scanner.Scan() {
-			select {
-			case <-ctx.Done():
-				// This handles the request cancellation
-				return ctx.Err()
+			switch {
+			case strings.TrimSpace(c.Content) == lastToken:
+				tokenRepeat++
 			default:
-				line := scanner.Bytes()
-				if len(line) == 0 {
-					continue
-				}
-
-				// try again on slot unavailable
-				if bytes.Contains(line, []byte("slot unavailable")) {
-					retryNeeded = true
-					break
-				}
-
-				evt, ok := bytes.CutPrefix(line, []byte("data: "))
-				if !ok {
-					return fmt.Errorf("error parsing llm response stream: %s", line)
-				}
-
-				var c completion
-				if err := json.Unmarshal(evt, &c); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-				}
-
-				switch {
-				case strings.TrimSpace(c.Content) == lastToken:
-					tokenRepeat++
-				default:
-					lastToken = strings.TrimSpace(c.Content)
-					tokenRepeat = 0
-				}
-
-				// 30 picked as an arbitrary max token repeat limit, modify as needed
-				if tokenRepeat > 30 {
-					slog.Debug("prediction aborted, token repeat limit reached")
-					return ctx.Err()
-				}
-
-				if c.Content != "" {
-					fn(CompletionResponse{
-						Content: c.Content,
-					})
-				}
-
-				if c.Stop {
-					fn(CompletionResponse{
-						Done:               true,
-						PromptEvalCount:    c.Timings.PromptN,
-						PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-						EvalCount:          c.Timings.PredictedN,
-						EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-					})
-					return nil
-				}
+				lastToken = strings.TrimSpace(c.Content)
+				tokenRepeat = 0
 			}
-		}

-		if err := scanner.Err(); err != nil {
-			if strings.Contains(err.Error(), "unexpected EOF") {
-				s.Close()
-				msg := ""
-				if s.status != nil && s.status.LastErrMsg != "" {
-					msg = s.status.LastErrMsg
-				}
-
-				return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
+			// 30 picked as an arbitrary max token repeat limit, modify as needed
+			if tokenRepeat > 30 {
+				slog.Debug("prediction aborted, token repeat limit reached")
+				return ctx.Err()
 			}
-			return fmt.Errorf("error reading llm response: %v", err)
-		}
-
-		if !retryNeeded {
-			return nil // success
+
+			if c.Content != "" {
+				fn(CompletionResponse{
+					Content: c.Content,
+				})
+			}
+
+			if c.Stop {
+				fn(CompletionResponse{
+					Done:               true,
+					PromptEvalCount:    c.Timings.PromptN,
+					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
+					EvalCount:          c.Timings.PredictedN,
+					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+				})
+				return nil
+			}
 		}
 	}

-	// should never reach here ideally
-	return fmt.Errorf("max retries exceeded")
+	if err := scanner.Err(); err != nil {
+		if strings.Contains(err.Error(), "unexpected EOF") {
+			s.Close()
+			msg := ""
+			if s.status != nil && s.status.LastErrMsg != "" {
+				msg = s.status.LastErrMsg
+			}
+
+			return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
+		}
+
+		return fmt.Errorf("error reading llm response: %v", err)
+	}
+
+	return nil
 }

 type EmbeddingRequest struct {
@@ -743,8 +745,9 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, err
 	}
 	defer s.sem.Release(1)
+
 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
@@ -799,7 +802,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return nil, err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
@@ -851,7 +854,7 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return "", err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
 		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
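Reviewer note: `Completion` no longer retries inside the streaming loop. The removed `maxRetries`/`retryDelay` machinery, including the mid-stream scan for `slot unavailable` lines, is replaced by polling up front: `getServerStatusRetry` waits for a free slot before the POST is ever sent. The sketch below is a minimal, self-contained illustration of that bounded-poll pattern; `errBusy`, `checkStatus`, and `waitForSlot` are hypothetical stand-ins for `ServerStatusNoSlotsAvailable`, `getServerStatus`, and `getServerStatusRetry`, not names from this diff:

```go
package main

import (
	"errors"
	"fmt"
	"math/rand"
	"time"
)

// errBusy stands in for ServerStatusNoSlotsAvailable.
var errBusy = errors.New("no slot available")

// checkStatus stands in for getServerStatus: it reports busy until a slot frees up.
func checkStatus() error {
	if rand.Intn(3) != 0 {
		return errBusy
	}
	return nil
}

// waitForSlot mirrors the shape of getServerStatusRetry: poll the status,
// sleep 5ms between attempts, and give up after 10 retries.
func waitForSlot() error {
	for retries := 0; ; retries++ {
		err := checkStatus()
		if !errors.Is(err, errBusy) {
			return err // ready (nil) or a real error: stop polling either way
		}
		if retries >= 10 {
			return fmt.Errorf("no slots available after %d retries", retries)
		}
		time.Sleep(5 * time.Millisecond)
	}
}

func main() {
	fmt.Println(waitForSlot())
}
```

A fixed budget of 10 retries at 5ms bounds the added latency at roughly 50ms while still surfacing a clear error if the server stays saturated, and it lets the streaming loop in `Completion` shed all retry bookkeeping, which is what allows `maxRetries` and the `retryNeeded` flag to be deleted.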