Compare commits
1 Commits
main
...
brucemacd/
Author | SHA1 | Date | |
---|---|---|---|
|
e94e5b1771 |
129
llm/llama.go
129
llm/llama.go
@ -535,6 +535,8 @@ type prediction struct {
|
||||
}
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
const maxRetries = 3
|
||||
const retryDelay = 1 * time.Second
|
||||
|
||||
type PredictOpts struct {
|
||||
Model string
|
||||
@ -557,6 +559,11 @@ type PredictResult struct {
|
||||
EvalDuration time.Duration
|
||||
}
|
||||
|
||||
// IsRetryable checks if the line matches a condition that can be retried
|
||||
func isRetryable(line []byte) bool {
|
||||
return bytes.Contains(line, []byte("slot unavailable"))
|
||||
}
|
||||
|
||||
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
||||
request := map[string]any{
|
||||
"prompt": predict.Prompt,
|
||||
@ -585,53 +592,69 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
|
||||
request["grammar"] = jsonGrammar
|
||||
}
|
||||
|
||||
// Handling JSON marshaling with special characters unescaped.
|
||||
buffer := &bytes.Buffer{}
|
||||
enc := json.NewEncoder(buffer)
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
if err := enc.Encode(request); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %v", err)
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating POST request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("POST predict: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for retries := 0; retries < maxRetries; retries++ {
|
||||
// Handling JSON marshaling with special characters unescaped.
|
||||
buffer := &bytes.Buffer{}
|
||||
enc := json.NewEncoder(buffer)
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading llm error response: %w", err)
|
||||
if err := enc.Encode(request); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %v", err)
|
||||
}
|
||||
log.Printf("llm predict error: %s", bodyBytes)
|
||||
return fmt.Errorf("%s", bodyBytes)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
// increase the buffer size to avoid running out of space
|
||||
buf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(buf, maxBufferSize)
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// This handles the request cancellation
|
||||
return ctx.Err()
|
||||
default:
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
if retries > 0 {
|
||||
time.Sleep(retryDelay) // wait before retrying
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating POST request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("POST predict: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading llm error response: %w", err)
|
||||
}
|
||||
log.Printf("llm predict error: %s", bodyBytes)
|
||||
return fmt.Errorf("%s", bodyBytes)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
// increase the buffer size to avoid running out of space
|
||||
buf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(buf, maxBufferSize)
|
||||
|
||||
retryNeeded := false
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// This handles the request cancellation
|
||||
return ctx.Err()
|
||||
default:
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if isRetryable(line) {
|
||||
retryNeeded = true
|
||||
break
|
||||
}
|
||||
|
||||
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing llm response stream: %s", line)
|
||||
}
|
||||
|
||||
if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
|
||||
var p prediction
|
||||
if err := json.Unmarshal(evt, &p); err != nil {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||
@ -661,21 +684,27 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "unexpected EOF") {
|
||||
// this means the llama runner subprocess crashed
|
||||
llm.Close()
|
||||
if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
|
||||
return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "unexpected EOF") {
|
||||
// this means the llama runner subprocess crashed
|
||||
llm.Close()
|
||||
if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
|
||||
return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
|
||||
}
|
||||
return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
|
||||
}
|
||||
return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
|
||||
return fmt.Errorf("error reading llm response: %v", err)
|
||||
}
|
||||
|
||||
if !retryNeeded {
|
||||
// success
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("error reading llm response: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// should never reach here ideally
|
||||
return fmt.Errorf("max retries exceeded")
|
||||
}
|
||||
|
||||
type TokenizeRequest struct {
|
||||
|
Loading…
x
Reference in New Issue
Block a user