Compare commits
1 Commits
main
...
brucemacd/
Author | SHA1 | Date | |
---|---|---|---|
|
e94e5b1771 |
35
llm/llama.go
35
llm/llama.go
@ -535,6 +535,8 @@ type prediction struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const maxBufferSize = 512 * format.KiloByte
|
const maxBufferSize = 512 * format.KiloByte
|
||||||
|
const maxRetries = 3 // bound on the Predict request loop: it runs while retries < maxRetries, so this counts total attempts
|
||||||
|
const retryDelay = 1 * time.Second // Predict sleeps this long before every attempt after the first
|
||||||
|
|
||||||
type PredictOpts struct {
|
type PredictOpts struct {
|
||||||
Model string
|
Model string
|
||||||
@ -557,6 +559,11 @@ type PredictResult struct {
|
|||||||
EvalDuration time.Duration
|
EvalDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isRetryable reports whether line contains the text "slot unavailable"; Predict uses it to flag a response line as needing a retry.
|
||||||
|
func isRetryable(line []byte) bool {
|
||||||
|
return bytes.Contains(line, []byte("slot unavailable"))
|
||||||
|
}
|
||||||
|
|
||||||
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
||||||
request := map[string]any{
|
request := map[string]any{
|
||||||
"prompt": predict.Prompt,
|
"prompt": predict.Prompt,
|
||||||
@ -585,6 +592,9 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
|
|||||||
request["grammar"] = jsonGrammar
|
request["grammar"] = jsonGrammar
|
||||||
}
|
}
|
||||||
|
|
||||||
|
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
|
||||||
|
|
||||||
|
for retries := 0; retries < maxRetries; retries++ {
|
||||||
// Handling JSON marshaling with special characters unescaped.
|
// Handling JSON marshaling with special characters unescaped.
|
||||||
buffer := &bytes.Buffer{}
|
buffer := &bytes.Buffer{}
|
||||||
enc := json.NewEncoder(buffer)
|
enc := json.NewEncoder(buffer)
|
||||||
@ -594,7 +604,9 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
|
|||||||
return fmt.Errorf("failed to marshal data: %v", err)
|
return fmt.Errorf("failed to marshal data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
|
if retries > 0 {
|
||||||
|
time.Sleep(retryDelay) // wait before retrying
|
||||||
|
}
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating POST request: %v", err)
|
return fmt.Errorf("error creating POST request: %v", err)
|
||||||
@ -620,6 +632,8 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
|
|||||||
// increase the buffer size to avoid running out of space
|
// increase the buffer size to avoid running out of space
|
||||||
buf := make([]byte, 0, maxBufferSize)
|
buf := make([]byte, 0, maxBufferSize)
|
||||||
scanner.Buffer(buf, maxBufferSize)
|
scanner.Buffer(buf, maxBufferSize)
|
||||||
|
|
||||||
|
retryNeeded := false
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@ -631,7 +645,16 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
|
if isRetryable(line) {
|
||||||
|
retryNeeded = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("error parsing llm response stream: %s", line)
|
||||||
|
}
|
||||||
|
|
||||||
var p prediction
|
var p prediction
|
||||||
if err := json.Unmarshal(evt, &p); err != nil {
|
if err := json.Unmarshal(evt, &p); err != nil {
|
||||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||||
@ -661,7 +684,6 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
if strings.Contains(err.Error(), "unexpected EOF") {
|
if strings.Contains(err.Error(), "unexpected EOF") {
|
||||||
@ -675,8 +697,15 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
|
|||||||
return fmt.Errorf("error reading llm response: %v", err)
|
return fmt.Errorf("error reading llm response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !retryNeeded {
|
||||||
|
// success
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reached when every one of the maxRetries attempts hit a retryable condition
|
||||||
|
return fmt.Errorf("max retries exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
type TokenizeRequest struct {
|
type TokenizeRequest struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
Loading…
x
Reference in New Issue
Block a user