diff --git a/api/types.go b/api/types.go
index e5291a02..ca64bf25 100644
--- a/api/types.go
+++ b/api/types.go
@@ -242,6 +242,7 @@ type Runner struct {
 	UseMMap   *bool `json:"use_mmap,omitempty"`
 	UseMLock  bool  `json:"use_mlock,omitempty"`
 	NumThread int   `json:"num_thread,omitempty"`
+	Reranking bool  `json:"reranking,omitempty"`
 }
 
 // EmbedRequest is the request passed to [Client.Embed].
@@ -293,6 +294,22 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }
 
+type RerankRequest struct {
+	Model     string                 `json:"model"`
+	Query     string                 `json:"query"`
+	TopN      int                    `json:"top_n"`     // return top N documents
+	Documents []string               `json:"documents"` // list of documents to rerank
+	KeepAlive *Duration              `json:"keep_alive,omitempty"`
+	Options   map[string]interface{} `json:"options,omitempty"`
+}
+
+type RerankResponse struct {
+	Results []struct {
+		Document       string  `json:"document"`
+		RelevanceScore float32 `json:"relevance_score"`
+	} `json:"results"`
+}
+
 // CreateRequest is the request passed to [Client.Create].
 type CreateRequest struct {
 	Model string `json:"model"`
diff --git a/llama/llama.go b/llama/llama.go
index dbb02768..5d95a8f7 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -140,7 +140,7 @@ type ContextParams struct {
 	c C.struct_llama_context_params
 }
 
-func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
+func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, reranking bool) ContextParams {
 	params := C.llama_context_default_params()
 	params.n_ctx = C.uint(numCtx)
 	params.n_batch = C.uint(batchSize)
@@ -149,6 +149,9 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
 	params.n_threads_batch = params.n_threads
 	params.embeddings = C.bool(true)
 	params.flash_attn = C.bool(flashAttention)
+	if reranking {
+		params.pooling_type = C.LLAMA_POOLING_TYPE_RANK
+	}
 	return ContextParams{c: params}
 }
 
@@ -214,6 +217,18 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 }
 
+func (c *Context) GetTokenBOS() C.llama_token {
+	return C.llama_token_bos(c.Model().c)
+}
+
+func (c *Context) GetTokenEOS() C.llama_token {
+	return C.llama_token_eos(c.Model().c)
+}
+
+func (c *Context) GetTokenSEP() C.llama_token {
+	return C.llama_token_sep(c.Model().c)
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
@@ -298,6 +313,10 @@ func (m *Model) AddBOSToken() bool {
 	return bool(C.llama_add_bos_token(m.c))
 }
 
+func (m *Model) AddEOSToken() bool {
+	return bool(C.llama_add_eos_token(m.c))
+}
+
 func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
 	cLoraPath := C.CString(loraPath)
 	defer C.free(unsafe.Pointer(cLoraPath))
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index cff7d148..a9acdcba 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -753,6 +753,77 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+type RerankRequest struct {
+	Model       string   `json:"model"`
+	Query       string   `json:"query"`
+	TopN        int      `json:"top_n"`     // return top N documents
+	Documents   []string `json:"documents"` // list of documents to rerank
+	CachePrompt bool     `json:"cache_prompt"`
+}
+
+type RerankResponse struct {
+	Results []struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	} `json:"results"`
+}
+
+func (s *Server) rerank(w http.ResponseWriter, r *http.Request) {
+	var req RerankRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, fmt.Sprintf("bad rerank request: %s", err), http.StatusBadRequest)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+
+	var rsp RerankResponse
+	rsp.Results = make([]struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	}, len(req.Documents))
+
+	for i, doc := range req.Documents {
+		// reranking prompt format: [BOS]query[EOS][SEP]doc[EOS]
+		p := ""
+		if !s.model.AddBOSToken() {
+			p += s.model.TokenToPiece(int(s.lc.GetTokenBOS()))
+		}
+		p += req.Query + s.model.TokenToPiece(int(s.lc.GetTokenEOS())) + s.model.TokenToPiece(int(s.lc.GetTokenSEP())) + doc
+		if !s.model.AddEOSToken() {
+			p += s.model.TokenToPiece(int(s.lc.GetTokenEOS()))
+		}
+		seq, err := s.NewSequence(p, nil, NewSequenceParams{embedding: true})
+		if err != nil {
+			http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
+			return
+		}
+
+		s.mu.Lock()
+		for j, sq := range s.seqs {
+			if sq == nil {
+				seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+				if err != nil {
+					s.mu.Unlock()
+					http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
+					return
+				}
+				s.seqs[j] = seq
+				s.cond.Signal()
+				break
+			}
+		}
+		s.mu.Unlock()
+
+		score := <-seq.embedding
+		rsp.Results[i].Index = i
+		rsp.Results[i].RelevanceScore = score[0]
+	}
+
+	if err := json.NewEncoder(w).Encode(&rsp); err != nil {
+		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+	}
+}
+
 func (s *Server) loadModel(
 	params llama.ModelParams,
 	mpath string,
@@ -762,6 +833,7 @@ func (s *Server) loadModel(
 	flashAttention bool,
 	threads int,
 	multiUserCache bool,
+	reranking bool,
 ) {
 	llama.BackendInit()
@@ -771,7 +843,7 @@
 		panic(err)
 	}
 
-	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
+	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, reranking)
 	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
 	if err != nil {
 		panic(err)
 	}
@@ -819,6 +891,7 @@ func main() {
 	tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
 	multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
 	requirements := flag.Bool("requirements", false, "print json requirement information")
+	reranking := flag.Bool("reranking", false, "enable reranking (default: disabled)")
 	flag.Parse()
 
 	if *requirements {
@@ -874,7 +947,7 @@
 	}
 
 	server.ready.Add(1)
-	go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
+	go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *reranking)
 
 	server.cond = sync.NewCond(&server.mu)
 
@@ -893,6 +966,7 @@
 	mux.HandleFunc("/embedding", server.embeddings)
 	mux.HandleFunc("/completion", server.completion)
 	mux.HandleFunc("/health", server.health)
+	mux.HandleFunc("/rerank", server.rerank)
 
 	httpServer := http.Server{
 		Handler: mux,
diff --git a/llm/server.go b/llm/server.go
index 96815826..fed20270 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -38,6 +38,7 @@ type LlamaServer interface {
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
 	Embedding(ctx context.Context, input string) ([]float32, error)
+	Rerank(ctx context.Context, req RerankRequest, fn func(RerankResponse)) error
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -188,6 +189,10 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		"--batch-size", strconv.Itoa(opts.NumBatch),
 	}
+	if opts.Reranking {
+		params = append(params, "--reranking")
+	}
+
 	if opts.NumGPU >= 0 {
 		params = append(params, "--n-gpu-layers", strconv.Itoa(opts.NumGPU))
 	}
@@ -911,6 +916,69 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
 	return e.Embedding, nil
 }
 
+type RerankRequest struct {
+	Model     string   `json:"model"`
+	Query     string   `json:"query"`
+	TopN      int      `json:"top_n"`     // return top N documents
+	Documents []string `json:"documents"` // list of documents to rerank
+}
+type RerankResponse struct {
+	Results []struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	} `json:"results"`
+}
+
+func (s *llmServer) Rerank(ctx context.Context, req RerankRequest, fn func(RerankResponse)) error {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return err
+	}
+	defer s.sem.Release(1)
+
+	status, err := s.getServerStatusRetry(ctx)
+	if err != nil {
+		return err
+	} else if status != ServerStatusReady {
+		return fmt.Errorf("unexpected server status: %s", status.ToString())
+	}
+
+	data, err := json.Marshal(req)
+	if err != nil {
+		return fmt.Errorf("error marshaling rerank data: %w", err)
+	}
+
+	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/rerank", s.port), bytes.NewBuffer(data))
+	if err != nil {
+		return fmt.Errorf("error creating rerank request: %w", err)
+	}
+	r.Header.Set("Content-Type", "application/json")
+
+	resp, err := http.DefaultClient.Do(r)
+	if err != nil {
+		return fmt.Errorf("do rerank request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("error reading rerank response: %w", err)
+	}
+
+	if resp.StatusCode >= 400 {
+		log.Printf("llm rerank error: %s", body)
+		return fmt.Errorf("%s", body)
+	}
+
+	var rr RerankResponse
+	if err := json.Unmarshal(body, &rr); err != nil {
+		return fmt.Errorf("unmarshal rerank response: %w", err)
+	}
+	fn(rr)
+
+	return nil
+}
+
 type TokenizeRequest struct {
 	Content string `json:"content"`
 }
diff --git a/server/routes.go b/server/routes.go
index c5fd3293..4d19a827 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -18,6 +18,7 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
+	"sort"
 	"strings"
 	"syscall"
 	"time"
@@ -348,6 +349,79 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func (s *Server) RerankHandler(c *gin.Context) {
+	var req api.RerankRequest
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	} else if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	switch {
+	case len(req.Documents) == 0:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Documents cannot be empty"})
+		return
+	case req.Query == "":
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Query cannot be empty"})
+		return
+	case req.TopN < 0:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "TopN cannot be negative"})
"TopN cannot be negative"}) + return + } + + if req.Options == nil { + req.Options = make(map[string]any) + } + req.Options["reranking"] = true + r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) + if err != nil { + handleScheduleError(c, req.Model, err) + return + } + + llmreq := llm.RerankRequest{ + Model: req.Model, + Query: req.Query, + TopN: req.TopN, + Documents: req.Documents, + } + err = r.Rerank(c.Request.Context(), llmreq, func(rr llm.RerankResponse) { + sort.SliceStable(rr.Results, func(i, j int) bool { + return rr.Results[i].RelevanceScore > rr.Results[j].RelevanceScore + }) + + var topn int + if req.TopN == 0 { + topn = len(rr.Results) // if TopN is unset, return all results + } else { + topn = min(len(rr.Results), req.TopN) + } + topResults := rr.Results[:topn] + + rsp := api.RerankResponse{ + Results: make([]struct { + Document string `json:"document"` + RelevanceScore float32 `json:"relevance_score"` + }, topn), + } + + for i, result := range topResults { + rsp.Results[i].Document = req.Documents[result.Index] + rsp.Results[i].RelevanceScore = result.RelevanceScore + } + + c.JSON(http.StatusOK, rsp) + }) + + if err != nil { + slog.Info(fmt.Sprintf("rerank failed: %v", err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to rerank"}) + return + } +} + func (s *Server) EmbedHandler(c *gin.Context) { checkpointStart := time.Now() var req api.EmbedRequest @@ -1158,6 +1232,7 @@ func (s *Server) GenerateRoutes() http.Handler { r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) r.GET("/api/ps", s.PsHandler) + r.POST("/api/rerank", s.RerankHandler) // Compatibility endpoints r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)