diff --git a/llm/server.go b/llm/server.go index 50b2ab60..ed7e5012 100644 --- a/llm/server.go +++ b/llm/server.go @@ -344,9 +344,12 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr } } - m, err := loadModel(model, true) - if err != nil { - return nil, fmt.Errorf("unable to load model for tokenization %w", err) + var m *loadedModel + if envconfig.NewRunners() { + m, err = loadModel(model, true) + if err != nil { + return nil, fmt.Errorf("unable to load model for tokenization %w", err) + } } s := &llmServer{ port: port, @@ -960,7 +963,51 @@ type TokenizeResponse struct { } func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) { - return tokenize(s.model, content) + if envconfig.NewRunners() { + return tokenize(s.model, content) + } + + // Make sure the server is ready + status, err := s.getServerStatus(ctx) + if err != nil { + return nil, err + } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable { + return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) + } + + data, err := json.Marshal(TokenizeRequest{Content: content}) + if err != nil { + return nil, fmt.Errorf("marshaling encode data: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("encode request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("do encode request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read encode request: %w", err) + } + + if resp.StatusCode >= 400 { + log.Printf("llm encode error: %s", body) + return nil, fmt.Errorf("%s", body) + } + + var encoded TokenizeResponse + if err := json.Unmarshal(body, &encoded); err != nil { + return nil, fmt.Errorf("unmarshal encode response: %w", err) + } + + return encoded.Tokens, nil } type DetokenizeRequest struct { @@ -972,7 +1019,50 @@ type DetokenizeResponse struct { } func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) { - return detokenize(s.model, tokens), nil + if envconfig.NewRunners() { + return detokenize(s.model, tokens), nil + } + // Make sure the server is ready + status, err := s.getServerStatus(ctx) + if err != nil { + return "", err + } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable { + return "", fmt.Errorf("unexpected server status: %s", status.ToString()) + } + + data, err := json.Marshal(DetokenizeRequest{Tokens: tokens}) + if err != nil { + return "", fmt.Errorf("marshaling decode data: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data)) + if err != nil { + return "", fmt.Errorf("decode request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("do decode request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read decode request: %w", err) + } + + if resp.StatusCode >= 400 { + log.Printf("llm decode error: %s", body) + return "", fmt.Errorf("%s", body) + } + + var decoded DetokenizeResponse + if err := json.Unmarshal(body, &decoded); err != nil { + return "", fmt.Errorf("unmarshal encode response: %w", err) + } + + return decoded.Content, nil } func (s *llmServer) Close() error {