route: add rerank support

Signed-off-by: Liu Yuan <namei.unix@gmail.com>
This commit is contained in:
Liu Yuan 2024-10-15 18:46:35 +08:00
parent 8f9e6ec7d1
commit 67818b5093
3 changed files with 151 additions and 0 deletions

View File

@ -293,6 +293,22 @@ type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
}
// RerankRequest is the request passed to [Client.Rerank].
type RerankRequest struct {
	// Model is the model name to rerank with.
	Model string `json:"model"`

	// Query is the text the documents are ranked against.
	Query string `json:"query"`

	// TopN is the number of most relevant documents to return.
	// When zero, all documents are returned.
	TopN int `json:"top_n"`

	// Documents is the list of documents to rerank.
	Documents []string `json:"documents"`

	// KeepAlive controls how long the model will stay loaded following
	// this request.
	KeepAlive *Duration `json:"keep_alive,omitempty"`

	// Options lists model parameters overriding the model defaults.
	Options map[string]any `json:"options,omitempty"`
}
// RerankResponse is the response from [Client.Rerank]: the reranked
// documents together with their relevance scores.
type RerankResponse struct {
	Results []struct {
		Document       string  `json:"document"`
		RelevanceScore float32 `json:"relevance_score"`
	} `json:"results"`
}
// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct {
Model string `json:"model"`

View File

@ -38,6 +38,7 @@ type LlamaServer interface {
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, input string) ([]float32, error)
Rerank(ctx context.Context, req RerankRequest, fn func(RerankResponse)) error
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@ -911,6 +912,69 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
return e.Embedding, nil
}
// RerankRequest is the request body posted to the runner's /rerank endpoint.
type RerankRequest struct {
	Model     string   `json:"model"`
	Query     string   `json:"query"`     // text the documents are ranked against
	TopN      int      `json:"top_n"`     // return top N documents
	Documents []string `json:"documents"` // list of documents to rerank
}
// RerankResponse is the response from the runner's /rerank endpoint.
// Each result refers back to an input document by its index in the
// request's Documents slice.
type RerankResponse struct {
	Results []struct {
		Index          int     `json:"index"`           // position of the document in the request
		RelevanceScore float32 `json:"relevance_score"` // higher means more relevant to the query
	} `json:"results"`
}
// Rerank scores req.Documents against req.Query by forwarding the request
// to the runner's /rerank endpoint and invokes fn with the decoded response.
// It blocks until a runner slot is available and returns any transport,
// status, or decoding error.
func (s *llmServer) Rerank(ctx context.Context, req RerankRequest, fn func(RerankResponse)) error {
	// Bound concurrent requests to the runner.
	if err := s.sem.Acquire(ctx, 1); err != nil {
		slog.Error("Failed to acquire semaphore", "error", err)
		return err
	}
	defer s.sem.Release(1)

	status, err := s.getServerStatusRetry(ctx)
	if err != nil {
		return err
	} else if status != ServerStatusReady {
		return fmt.Errorf("unexpected server status: %s", status.ToString())
	}

	data, err := json.Marshal(req)
	if err != nil {
		return fmt.Errorf("error marshaling rerank data: %w", err)
	}

	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/rerank", s.port), bytes.NewBuffer(data))
	if err != nil {
		return fmt.Errorf("error creating rerank request: %w", err)
	}
	r.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(r)
	if err != nil {
		return fmt.Errorf("do rerank request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("error reading rerank response: %w", err)
	}

	if resp.StatusCode >= 400 {
		// Use structured logging, consistent with the rest of this method.
		slog.Error("llm rerank error", "response", string(body))
		return fmt.Errorf("%s", body)
	}

	var rr RerankResponse
	if err := json.Unmarshal(body, &rr); err != nil {
		// was: "unmarshal tokenize response" — copy-paste from Tokenize
		return fmt.Errorf("unmarshal rerank response: %w", err)
	}

	fn(rr)
	return nil
}
// TokenizeRequest carries the text content to tokenize
// (presumably posted to the runner's tokenize endpoint — handler not
// visible in this chunk).
type TokenizeRequest struct {
	Content string `json:"content"`
}

View File

@ -18,6 +18,7 @@ import (
"os/signal"
"path/filepath"
"slices"
"sort"
"strings"
"syscall"
"time"
@ -348,6 +349,75 @@ func (s *Server) GenerateHandler(c *gin.Context) {
streamResponse(c, ch)
}
// RerankHandler handles POST /api/rerank. It validates the request,
// schedules a runner for the model, scores the documents against the
// query, and responds with the documents sorted by descending relevance,
// truncated to TopN (all documents when TopN is zero).
func (s *Server) RerankHandler(c *gin.Context) {
	var req api.RerankRequest
	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	switch {
	case len(req.Documents) == 0:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Documents cannot be empty"})
		return
	case req.Query == "":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Query cannot be empty"})
		return
	case req.TopN < 0:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "TopN cannot be negative"})
		return
	}

	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
	if err != nil {
		handleScheduleError(c, req.Model, err)
		return
	}

	llmreq := llm.RerankRequest{
		Model:     req.Model,
		Query:     req.Query,
		TopN:      req.TopN,
		Documents: req.Documents,
	}
	err = r.Rerank(c.Request.Context(), llmreq, func(rr llm.RerankResponse) {
		// Most relevant documents first.
		sort.SliceStable(rr.Results, func(i, j int) bool {
			return rr.Results[i].RelevanceScore > rr.Results[j].RelevanceScore
		})

		var topn int
		if req.TopN == 0 {
			topn = len(rr.Results) // if TopN is unset, return all results
		} else {
			topn = min(len(rr.Results), req.TopN)
		}
		topResults := rr.Results[:topn]

		rsp := api.RerankResponse{
			Results: make([]struct {
				Document       string  `json:"document"`
				RelevanceScore float32 `json:"relevance_score"`
			}, topn),
		}
		for i, result := range topResults {
			// The runner reports documents by index; guard the bounds so a
			// malformed response cannot panic the handler.
			if result.Index < 0 || result.Index >= len(req.Documents) {
				slog.Error("rerank result index out of range", "index", result.Index, "documents", len(req.Documents))
				c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to rerank"})
				return
			}
			rsp.Results[i].Document = req.Documents[result.Index]
			rsp.Results[i].RelevanceScore = result.RelevanceScore
		}
		c.JSON(http.StatusOK, rsp)
	})
	if err != nil {
		// Log at Error level (was Info) with structured fields.
		slog.Error("rerank failed", "error", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to rerank"})
		return
	}
}
func (s *Server) EmbedHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.EmbedRequest
@ -1158,6 +1228,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.GET("/api/ps", s.PsHandler)
r.POST("/api/rerank", s.RerankHandler)
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)