go runner: add rerank support

Co-authored-by: Craig Hughes
Signed-off-by: Liu Yuan <namei.unix@gmail.com>
This commit is contained in:
Liu Yuan 2024-10-29 22:17:44 +08:00
parent 3d25e7bf8c
commit 8f9e6ec7d1
2 changed files with 96 additions and 3 deletions

View File

@ -138,7 +138,7 @@ type ContextParams struct {
c C.struct_llama_context_params
}
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, reranking bool) ContextParams {
params := C.llama_context_default_params()
params.n_ctx = C.uint(numCtx)
params.n_batch = C.uint(batchSize)
@ -147,6 +147,9 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
params.n_threads_batch = params.n_threads
params.embeddings = C.bool(true)
params.flash_attn = C.bool(flashAttention)
if reranking {
params.pooling_type = C.LLAMA_POOLING_TYPE_RANK
}
return ContextParams{c: params}
}
@ -212,6 +215,18 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}
// GetTokenBOS returns the model's beginning-of-sequence (BOS) token id.
func (c *Context) GetTokenBOS() C.llama_token {
	return C.llama_token_bos(c.Model().c)
}
// GetTokenEOS returns the model's end-of-sequence (EOS) token id.
func (c *Context) GetTokenEOS() C.llama_token {
	return C.llama_token_eos(c.Model().c)
}
// GetTokenSEP returns the model's separator (SEP) token id, used to join
// the query and document segments of a reranking prompt.
func (c *Context) GetTokenSEP() C.llama_token {
	return C.llama_token_sep(c.Model().c)
}
type ModelParams struct {
NumGpuLayers int
MainGpu int
@ -296,6 +311,10 @@ func (m *Model) AddBOSToken() bool {
return bool(C.llama_add_bos_token(m.c))
}
// AddEOSToken reports whether the model automatically appends an EOS token
// during tokenization (so callers need not add one manually).
func (m *Model) AddEOSToken() bool {
	return bool(C.llama_add_eos_token(m.c))
}
func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
cLoraPath := C.CString(loraPath)
defer C.free(unsafe.Pointer(cLoraPath))

View File

@ -771,6 +771,77 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
}
}
// RerankRequest is the JSON body accepted by the /rerank endpoint: a query
// plus a list of candidate documents to score against it.
type RerankRequest struct {
	Model string `json:"model"`
	Query string `json:"query"`
	// TopN requests that only the top N documents be returned.
	// NOTE(review): the handler does not currently apply this field — confirm
	// whether truncation is expected server-side.
	TopN      int      `json:"top_n"`     // return top N documents
	Documents []string `json:"documents"` // list of documents to rerank
	// CachePrompt enables reuse of a cached prompt prefix when loading the
	// sequence into a cache slot.
	CachePrompt bool `json:"cache_prompt"`
}
// RerankResponse is the JSON body returned by the /rerank endpoint: one
// result per input document, carrying the document's original index and its
// relevance score.
type RerankResponse struct {
	Results []struct {
		Index          int     `json:"index"`
		RelevanceScore float32 `json:"relevance_score"`
	} `json:"results"`
}
// rerank scores each document in the request against the query and writes a
// RerankResponse with one relevance score per document. Each document is run
// as its own embedding sequence; the first component of the pooled embedding
// is the rank score (the context must have been created with rank pooling
// enabled — see NewContextParams with reranking=true).
func (s *Server) rerank(w http.ResponseWriter, r *http.Request) {
	var req RerankRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		// Fixed typo: was "bad rereank request".
		http.Error(w, fmt.Sprintf("bad rerank request: %s", err), http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	var rsp RerankResponse
	rsp.Results = make([]struct {
		Index          int     `json:"index"`
		RelevanceScore float32 `json:"relevance_score"`
	}, len(req.Documents))

	// NOTE(review): req.TopN is decoded but never applied here — every
	// document is returned, in input order. Confirm whether top-N sorting and
	// truncation is expected in this handler or in the caller.
	for i, doc := range req.Documents {
		// Reranking prompt format: [BOS]query[EOS][SEP]doc[EOS].
		// BOS/EOS pieces are only prepended/appended manually when the model
		// does not insert them itself during tokenization.
		p := ""
		if !s.model.AddBOSToken() {
			p += s.model.TokenToPiece(int(s.lc.GetTokenBOS()))
		}
		p += req.Query + s.model.TokenToPiece(int(s.lc.GetTokenEOS())) + s.model.TokenToPiece(int(s.lc.GetTokenSEP())) + doc
		if !s.model.AddEOSToken() {
			p += s.model.TokenToPiece(int(s.lc.GetTokenEOS()))
		}

		seq, err := s.NewSequence(p, nil, NewSequenceParams{embedding: true})
		if err != nil {
			http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
			return
		}

		// Schedule the sequence into a free slot. The slot index is named
		// `slot` so it no longer shadows the outer document index i, which is
		// used again after this loop.
		s.mu.Lock()
		for slot, sq := range s.seqs {
			if sq == nil {
				seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
				if err != nil {
					s.mu.Unlock()
					http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
					return
				}
				s.seqs[slot] = seq
				s.cond.Signal()
				break
			}
		}
		s.mu.Unlock()

		// NOTE(review): if every slot was busy the sequence was never
		// scheduled and this receive blocks forever — consider waiting on
		// s.cond for a free slot. TODO confirm callers guarantee capacity.
		score := <-seq.embedding
		rsp.Results[i].Index = i
		rsp.Results[i].RelevanceScore = score[0]
	}

	if err := json.NewEncoder(w).Encode(&rsp); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}
func (s *Server) loadModel(
params llama.ModelParams,
mpath string,
@ -780,6 +851,7 @@ func (s *Server) loadModel(
flashAttention bool,
threads int,
multiUserCache bool,
reranking bool,
) {
llama.BackendInit()
@ -789,7 +861,7 @@ func (s *Server) loadModel(
panic(err)
}
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, reranking)
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
if err != nil {
panic(err)
@ -838,6 +910,7 @@ func main() {
tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
requirements := flag.Bool("requirements", false, "print json requirement information")
reranking := flag.Bool("reranking", false, "enable reranking (default: disabled)")
flag.Parse()
if *requirements {
@ -893,7 +966,7 @@ func main() {
}
server.ready.Add(1)
go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *reranking)
server.cond = sync.NewCond(&server.mu)
@ -912,6 +985,7 @@ func main() {
mux.HandleFunc("/embedding", server.embeddings)
mux.HandleFunc("/completion", server.completion)
mux.HandleFunc("/health", server.health)
mux.HandleFunc("/rerank", server.rerank)
httpServer := http.Server{
Handler: mux,