From 8f9e6ec7d1739d58153abcc64ceb303e21df236a Mon Sep 17 00:00:00 2001
From: Liu Yuan <namei.unix@gmail.com>
Date: Tue, 29 Oct 2024 22:17:44 +0800
Subject: [PATCH 1/3] go runner: add rerank support

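The runner gains a /rerank endpoint that scores each document against
the query. A sketch of the request/response shape (values below are
illustrative, not real model output):

    POST /rerank
    {"query": "what is a panda?",
     "documents": ["hi", "it is a bear", "the giant panda is a bear native to China"]}

    {"results": [{"index": 0, "relevance_score": 0.01},
                 {"index": 1, "relevance_score": 0.43},
                 {"index": 2, "relevance_score": 0.99}]}
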
Co-authored-by: Craig Hughes
Signed-off-by: Liu Yuan <namei.unix@gmail.com>
---
 llama/llama.go         | 21 +++++++++++-
 llama/runner/runner.go | 78 ++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 96 insertions(+), 3 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index a092ea12..572da1e1 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -138,7 +138,7 @@ type ContextParams struct {
 	c C.struct_llama_context_params
 }
 
-func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
+func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, reranking bool) ContextParams {
 	params := C.llama_context_default_params()
 	params.n_ctx = C.uint(numCtx)
 	params.n_batch = C.uint(batchSize)
@@ -147,6 +147,9 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
 	params.n_threads_batch = params.n_threads
 	params.embeddings = C.bool(true)
 	params.flash_attn = C.bool(flashAttention)
+	if reranking {
+		params.pooling_type = C.LLAMA_POOLING_TYPE_RANK
+	}
 	return ContextParams{c: params}
 }
 
@@ -212,6 +215,18 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 }
 
+func (c *Context) GetTokenBOS() C.llama_token {
+	return C.llama_token_bos(c.Model().c)
+}
+
+func (c *Context) GetTokenEOS() C.llama_token {
+	return C.llama_token_eos(c.Model().c)
+}
+
+func (c *Context) GetTokenSEP() C.llama_token {
+	return C.llama_token_sep(c.Model().c)
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
@@ -296,6 +311,10 @@ func (m *Model) AddBOSToken() bool {
 	return bool(C.llama_add_bos_token(m.c))
 }
 
+func (m *Model) AddEOSToken() bool {
+	return bool(C.llama_add_eos_token(m.c))
+}
+
 func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
 	cLoraPath := C.CString(loraPath)
 	defer C.free(unsafe.Pointer(cLoraPath))
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 0a37dee0..b478759a 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -771,6 +771,77 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+type RerankRequest struct {
+	Model       string   `json:"model"`
+	Query       string   `json:"query"`
+	TopN        int      `json:"top_n"`     // return top N documents
+	Documents   []string `json:"documents"` // list of documents to rerank
+	CachePrompt bool     `json:"cache_prompt"`
+}
+
+type RerankResponse struct {
+	Results []struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	} `json:"results"`
+}
+
+func (s *Server) rerank(w http.ResponseWriter, r *http.Request) {
+	var req RerankRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, fmt.Sprintf("bad rereank request: %s", err), http.StatusBadRequest)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+
+	var rsp RerankResponse
+	rsp.Results = make([]struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	}, len(req.Documents))
+
+	for i, doc := range req.Documents {
+		// reranking prompt format: [BOS]query[EOS][SEP]doc[EOS]
+		p := ""
+		if !s.model.AddBOSToken() {
+			p += s.model.TokenToPiece(int(s.lc.GetTokenBOS()))
+		}
+		p += req.Query + s.model.TokenToPiece(int(s.lc.GetTokenEOS())) + s.model.TokenToPiece(int(s.lc.GetTokenSEP())) + doc
+		if !s.model.AddEOSToken() {
+			p += s.model.TokenToPiece(int(s.lc.GetTokenEOS()))
+		}
+		seq, err := s.NewSequence(p, nil, NewSequenceParams{embedding: true})
+		if err != nil {
+			http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
+			return
+		}
+
+		s.mu.Lock()
+		for j, sq := range s.seqs {
+			if sq == nil {
+				seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+				if err != nil {
+					s.mu.Unlock()
+					http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
+					return
+				}
+				s.seqs[j] = seq
+				s.cond.Signal()
+				break
+			}
+		}
+		s.mu.Unlock()
+
+		score := <-seq.embedding
+		rsp.Results[i].Index = i
+		rsp.Results[i].RelevanceScore = score[0]
+	}
+
+	if err := json.NewEncoder(w).Encode(&rsp); err != nil {
+		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+	}
+}
+
 func (s *Server) loadModel(
 	params llama.ModelParams,
 	mpath string,
@@ -780,6 +851,7 @@ func (s *Server) loadModel(
 	flashAttention bool,
 	threads int,
 	multiUserCache bool,
+	reranking bool,
 ) {
 	llama.BackendInit()
 
@@ -789,7 +861,7 @@ func (s *Server) loadModel(
 		panic(err)
 	}
 
-	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
+	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, reranking)
 	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
 	if err != nil {
 		panic(err)
@@ -838,6 +910,7 @@ func main() {
 	tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
 	multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
 	requirements := flag.Bool("requirements", false, "print json requirement information")
+	reranking := flag.Bool("reranking", false, "enable reranking (default: disabled)")
 
 	flag.Parse()
 	if *requirements {
@@ -893,7 +966,7 @@ func main() {
 	}
 
 	server.ready.Add(1)
-	go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
+	go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *reranking)
 
 	server.cond = sync.NewCond(&server.mu)
 
@@ -912,6 +985,7 @@ func main() {
 	mux.HandleFunc("/embedding", server.embeddings)
 	mux.HandleFunc("/completion", server.completion)
 	mux.HandleFunc("/health", server.health)
+	mux.HandleFunc("/rerank", server.rerank)
 
 	httpServer := http.Server{
 		Handler: mux,

From 67818b5093646720cac5ae748eb00692ecad4c0b Mon Sep 17 00:00:00 2001
From: Liu Yuan <namei.unix@gmail.com>
Date: Tue, 15 Oct 2024 18:46:35 +0800
Subject: [PATCH 2/3] route: add rerank support

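The user-facing endpoint sorts results by relevance_score, maps indices
back to the original documents, and honors top_n. Example exchange
(model name and scores are illustrative):

    POST /api/rerank
    {"model": "bge-reranker-v2-m3",
     "query": "what is a panda?",
     "top_n": 2,
     "documents": ["hi", "it is a bear", "the giant panda is a bear native to China"]}

    {"results": [{"document": "the giant panda is a bear native to China",
                  "relevance_score": 0.99},
                 {"document": "it is a bear", "relevance_score": 0.43}]}
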
Signed-off-by: Liu Yuan <namei.unix@gmail.com>
---
 api/types.go     | 16 +++++++++++
 llm/server.go    | 64 +++++++++++++++++++++++++++++++++++++++++++
 server/routes.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 151 insertions(+)

diff --git a/api/types.go b/api/types.go
index d09ad06c..afdeb10f 100644
--- a/api/types.go
+++ b/api/types.go
@@ -293,6 +293,22 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }
 
+type RerankRequest struct {
+	Model     string                 `json:"model"`
+	Query     string                 `json:"query"`
+	TopN      int                    `json:"top_n"`     // return top N documents
+	Documents []string               `json:"documents"` // list of documents to rerank
+	KeepAlive *Duration              `json:"keep_alive,omitempty"`
+	Options   map[string]interface{} `json:"options,omitempty"`
+}
+
+type RerankResponse struct {
+	Results []struct {
+		Document       string  `json:"document"`
+		RelevanceScore float32 `json:"relevance_score"`
+	} `json:"results"`
+}
+
 // CreateRequest is the request passed to [Client.Create].
 type CreateRequest struct {
 	Model     string `json:"model"`
diff --git a/llm/server.go b/llm/server.go
index 5ca6aa32..0c732ea9 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -38,6 +38,7 @@ type LlamaServer interface {
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
 	Embedding(ctx context.Context, input string) ([]float32, error)
+	Rerank(ctx context.Context, req RerankRequest, fn func(RerankResponse)) error
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -911,6 +912,69 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
 	return e.Embedding, nil
 }
 
+type RerankRequest struct {
+	Model     string   `json:"model"`
+	Query     string   `json:"query"`
+	TopN      int      `json:"top_n"`     // return top N documents
+	Documents []string `json:"documents"` // list of documents to rerank
+}
+type RerankResponse struct {
+	Results []struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	} `json:"results"`
+}
+
+func (s *llmServer) Rerank(ctx context.Context, req RerankRequest, fn func(RerankResponse)) error {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return err
+	}
+	defer s.sem.Release(1)
+
+	status, err := s.getServerStatusRetry(ctx)
+	if err != nil {
+		return err
+	} else if status != ServerStatusReady {
+		return fmt.Errorf("unexpected server status: %s", status.ToString())
+	}
+
+	data, err := json.Marshal(req)
+	if err != nil {
+		return fmt.Errorf("error marshaling rerank data: %w", err)
+	}
+
+	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/rerank", s.port), bytes.NewBuffer(data))
+	if err != nil {
+		return fmt.Errorf("error creating rerank request: %w", err)
+	}
+	r.Header.Set("Content-Type", "application/json")
+
+	resp, err := http.DefaultClient.Do(r)
+	if err != nil {
+		return fmt.Errorf("do rerank request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("error reading rerank response: %w", err)
+	}
+
+	if resp.StatusCode >= 400 {
+		log.Printf("llm rerank error: %s", body)
+		return fmt.Errorf("%s", body)
+	}
+
+	var rr RerankResponse
+	if err := json.Unmarshal(body, &rr); err != nil {
+		return fmt.Errorf("unmarshal tokenize response: %w", err)
+	}
+	fn(rr)
+
+	return nil
+}
+
 type TokenizeRequest struct {
 	Content string `json:"content"`
 }
diff --git a/server/routes.go b/server/routes.go
index c5fd3293..33cdbbda 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -18,6 +18,7 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
+	"sort"
 	"strings"
 	"syscall"
 	"time"
@@ -348,6 +349,75 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func (s *Server) RerankHandler(c *gin.Context) {
+	var req api.RerankRequest
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	} else if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	switch {
+	case len(req.Documents) == 0:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Documents cannot be empty"})
+		return
+	case req.Query == "":
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Query cannot be empty"})
+		return
+	case req.TopN < 0:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "TopN cannot be negative"})
+		return
+	}
+
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	llmreq := llm.RerankRequest{
+		Model:     req.Model,
+		Query:     req.Query,
+		TopN:      req.TopN,
+		Documents: req.Documents,
+	}
+	err = r.Rerank(c.Request.Context(), llmreq, func(rr llm.RerankResponse) {
+		sort.SliceStable(rr.Results, func(i, j int) bool {
+			return rr.Results[i].RelevanceScore > rr.Results[j].RelevanceScore
+		})
+
+		var topn int
+		if req.TopN == 0 {
+			topn = len(rr.Results) // if TopN is unset, return all results
+		} else {
+			topn = min(len(rr.Results), req.TopN)
+		}
+		topResults := rr.Results[:topn]
+
+		rsp := api.RerankResponse{
+			Results: make([]struct {
+				Document       string  `json:"document"`
+				RelevanceScore float32 `json:"relevance_score"`
+			}, topn),
+		}
+
+		for i, result := range topResults {
+			rsp.Results[i].Document = req.Documents[result.Index]
+			rsp.Results[i].RelevanceScore = result.RelevanceScore
+		}
+
+		c.JSON(http.StatusOK, rsp)
+	})
+
+	if err != nil {
+		slog.Info(fmt.Sprintf("rerank failed: %v", err))
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to rerank"})
+		return
+	}
+}
+
 func (s *Server) EmbedHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 	var req api.EmbedRequest
@@ -1158,6 +1228,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
 	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
 	r.GET("/api/ps", s.PsHandler)
+	r.POST("/api/rerank", s.RerankHandler)
 
 	// Compatibility endpoints
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)

From bfae776f34250386c0e05ae4e18126d42cfb940e Mon Sep 17 00:00:00 2001
From: Liu Yuan <namei.unix@gmail.com>
Date: Thu, 31 Oct 2024 22:48:27 +0800
Subject: [PATCH 3/3] enable --reranking flag for the rerank handler when
 starting the server

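RerankHandler now forces the option, so clients do not need to request
it explicitly. A sketch of how the setting flows through the stack (all
pieces introduced in this series):

    POST /api/rerank
      -> RerankHandler sets req.Options["reranking"] = true
      -> opts.Reranking causes NewLlamaServer to append --reranking
      -> the runner passes reranking=true to NewContextParams, which
         sets LLAMA_POOLING_TYPE_RANK on the llama.cpp context
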
Signed-off-by: Liu Yuan <namei.unix@gmail.com>
---
 api/types.go     | 1 +
 llm/server.go    | 4 ++++
 server/routes.go | 4 ++++
 3 files changed, 9 insertions(+)

diff --git a/api/types.go b/api/types.go
index afdeb10f..f40819a3 100644
--- a/api/types.go
+++ b/api/types.go
@@ -242,6 +242,7 @@ type Runner struct {
 	UseMMap   *bool `json:"use_mmap,omitempty"`
 	UseMLock  bool  `json:"use_mlock,omitempty"`
 	NumThread int   `json:"num_thread,omitempty"`
+	Reranking bool  `json:"reranking,omitempty"`
 }
 
 // EmbedRequest is the request passed to [Client.Embed].
diff --git a/llm/server.go b/llm/server.go
index 0c732ea9..c5a1da64 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -189,6 +189,10 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		"--batch-size", strconv.Itoa(opts.NumBatch),
 	}
 
+	if opts.Reranking {
+		params = append(params, "--reranking")
+	}
+
 	if opts.NumGPU >= 0 {
 		params = append(params, "--n-gpu-layers", strconv.Itoa(opts.NumGPU))
 	}
diff --git a/server/routes.go b/server/routes.go
index 33cdbbda..4d19a827 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -371,6 +371,10 @@ func (s *Server) RerankHandler(c *gin.Context) {
 		return
 	}
 
+	if req.Options == nil {
+		req.Options = make(map[string]any)
+	}
+	req.Options["reranking"] = true
 	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)