From 8f9e6ec7d1739d58153abcc64ceb303e21df236a Mon Sep 17 00:00:00 2001
From: Liu Yuan
Date: Tue, 29 Oct 2024 22:17:44 +0800
Subject: [PATCH] go runner: add rerank support

Co-authored-by: Craig Hughes
Signed-off-by: Liu Yuan
---
 llama/llama.go         | 21 +++++++++++-
 llama/runner/runner.go | 78 ++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 96 insertions(+), 3 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index a092ea12..572da1e1 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -138,7 +138,7 @@ type ContextParams struct {
 	c C.struct_llama_context_params
 }
 
-func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
+func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, reranking bool) ContextParams {
 	params := C.llama_context_default_params()
 	params.n_ctx = C.uint(numCtx)
 	params.n_batch = C.uint(batchSize)
@@ -147,6 +147,9 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
 	params.n_threads_batch = params.n_threads
 	params.embeddings = C.bool(true)
 	params.flash_attn = C.bool(flashAttention)
+	if reranking {
+		params.pooling_type = C.LLAMA_POOLING_TYPE_RANK
+	}
 	return ContextParams{c: params}
 }
 
@@ -212,6 +215,18 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 }
 
+func (c *Context) GetTokenBOS() C.llama_token {
+	return C.llama_token_bos(c.Model().c)
+}
+
+func (c *Context) GetTokenEOS() C.llama_token {
+	return C.llama_token_eos(c.Model().c)
+}
+
+func (c *Context) GetTokenSEP() C.llama_token {
+	return C.llama_token_sep(c.Model().c)
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
@@ -296,6 +311,10 @@ func (m *Model) AddBOSToken() bool {
 	return bool(C.llama_add_bos_token(m.c))
 }
 
+func (m *Model) AddEOSToken() bool {
+	return bool(C.llama_add_eos_token(m.c))
+}
+
 func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
 	cLoraPath := C.CString(loraPath)
 	defer C.free(unsafe.Pointer(cLoraPath))
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 0a37dee0..b478759a 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -771,6 +771,77 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+type RerankRequest struct {
+	Model       string   `json:"model"`
+	Query       string   `json:"query"`
+	TopN        int      `json:"top_n"`     // return top N documents
+	Documents   []string `json:"documents"` // list of documents to rerank
+	CachePrompt bool     `json:"cache_prompt"`
+}
+
+type RerankResponse struct {
+	Results []struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	} `json:"results"`
+}
+
+func (s *Server) rerank(w http.ResponseWriter, r *http.Request) {
+	var req RerankRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, fmt.Sprintf("bad rerank request: %s", err), http.StatusBadRequest)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+
+	var rsp RerankResponse
+	rsp.Results = make([]struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	}, len(req.Documents))
+
+	for i, doc := range req.Documents {
+		// reranking prompt format: [BOS]query[EOS][SEP]doc[EOS]
+		p := ""
+		if !s.model.AddBOSToken() {
+			p += s.model.TokenToPiece(int(s.lc.GetTokenBOS()))
+		}
+		p += req.Query + s.model.TokenToPiece(int(s.lc.GetTokenEOS())) + s.model.TokenToPiece(int(s.lc.GetTokenSEP())) + doc
+		if !s.model.AddEOSToken() {
+			p += s.model.TokenToPiece(int(s.lc.GetTokenEOS()))
+		}
+		seq, err := s.NewSequence(p, nil, NewSequenceParams{embedding: true})
+		if err != nil {
+			http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
+			return
+		}
+
+		s.mu.Lock()
+		for j, sq := range s.seqs {
+			if sq == nil {
+				seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+				if err != nil {
+					s.mu.Unlock()
+					http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
+					return
+				}
+				s.seqs[j] = seq
+				s.cond.Signal()
+				break
+			}
+		}
+		s.mu.Unlock()
+
+		score := <-seq.embedding
+		rsp.Results[i].Index = i
+		rsp.Results[i].RelevanceScore = score[0]
+	}
+
+	if err := json.NewEncoder(w).Encode(&rsp); err != nil {
+		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+	}
+}
+
 func (s *Server) loadModel(
 	params llama.ModelParams,
 	mpath string,
@@ -780,6 +851,7 @@ func (s *Server) loadModel(
 	flashAttention bool,
 	threads int,
 	multiUserCache bool,
+	reranking bool,
 ) {
 	llama.BackendInit()
 
@@ -789,7 +861,7 @@ func (s *Server) loadModel(
 		panic(err)
 	}
-	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
+	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, reranking)
 	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
 	if err != nil {
 		panic(err)
 	}
@@ -838,6 +910,7 @@ func main() {
 	tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
 	multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
 	requirements := flag.Bool("requirements", false, "print json requirement information")
+	reranking := flag.Bool("reranking", false, "enable reranking (default: disabled)")
 	flag.Parse()
 
 	if *requirements {
@@ -893,7 +966,7 @@ func main() {
 	}
 
 	server.ready.Add(1)
-	go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
+	go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *reranking)
 
 	server.cond = sync.NewCond(&server.mu)
 
@@ -912,6 +985,7 @@ func main() {
 	mux.HandleFunc("/embedding", server.embeddings)
 	mux.HandleFunc("/completion", server.completion)
 	mux.HandleFunc("/health", server.health)
+	mux.HandleFunc("/rerank", server.rerank)
 
 	httpServer := http.Server{
 		Handler: mux,
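
For reference, once the runner is started with -reranking and a reranking-capable
model, the new endpoint can be exercised with a small client along the following
lines. This is a minimal sketch, not part of the patch: the listen address and
port are assumptions (use whatever -port the runner was given), and only the
request and response shapes are taken from RerankRequest and RerankResponse
above.

    // rerank_client.go: sketch of a client for the runner's /rerank endpoint.
    package main

    import (
        "bytes"
        "encoding/json"
        "fmt"
        "net/http"
    )

    func main() {
        // Request body matching RerankRequest; Model, TopN, and CachePrompt
        // are optional here and left at their zero values.
        body, err := json.Marshal(map[string]interface{}{
            "query": "What is the capital of France?",
            "documents": []string{
                "Paris is the capital and largest city of France.",
                "Berlin is the capital of Germany.",
            },
        })
        if err != nil {
            panic(err)
        }

        // Assumed address; substitute the runner's actual host and port.
        resp, err := http.Post("http://127.0.0.1:8080/rerank", "application/json", bytes.NewReader(body))
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        // Decode into the shape of RerankResponse.
        var out struct {
            Results []struct {
                Index          int     `json:"index"`
                RelevanceScore float32 `json:"relevance_score"`
            } `json:"results"`
        }
        if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
            panic(err)
        }

        // One score per input document; higher means more relevant.
        for _, r := range out.Results {
            fmt.Printf("doc %d: %.4f\n", r.Index, r.RelevanceScore)
        }
    }

Note that the handler scores documents one at a time: it blocks on each
sequence's embedding channel, and with LLAMA_POOLING_TYPE_RANK the first
element of that vector (score[0]) is the relevance score for the query/document
pair.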