go runner: add rerank support
Co-authored-by: Craig Hughes
Signed-off-by: Liu Yuan <namei.unix@gmail.com>
This commit is contained in:
parent
3d25e7bf8c
commit
8f9e6ec7d1
@ -138,7 +138,7 @@ type ContextParams struct {
|
||||
c C.struct_llama_context_params
|
||||
}
|
||||
|
||||
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
|
||||
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, reranking bool) ContextParams {
|
||||
params := C.llama_context_default_params()
|
||||
params.n_ctx = C.uint(numCtx)
|
||||
params.n_batch = C.uint(batchSize)
|
||||
@ -147,6 +147,9 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
|
||||
params.n_threads_batch = params.n_threads
|
||||
params.embeddings = C.bool(true)
|
||||
params.flash_attn = C.bool(flashAttention)
|
||||
if reranking {
|
||||
params.pooling_type = C.LLAMA_POOLING_TYPE_RANK
|
||||
}
|
||||
return ContextParams{c: params}
|
||||
}
|
||||
|
||||
@ -212,6 +215,18 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
|
||||
return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
|
||||
}
|
||||
|
||||
func (c *Context) GetTokenBOS() C.llama_token {
|
||||
return C.llama_token_bos(c.Model().c)
|
||||
}
|
||||
|
||||
func (c *Context) GetTokenEOS() C.llama_token {
|
||||
return C.llama_token_eos(c.Model().c)
|
||||
}
|
||||
|
||||
func (c *Context) GetTokenSEP() C.llama_token {
|
||||
return C.llama_token_sep(c.Model().c)
|
||||
}
|
||||
|
||||
type ModelParams struct {
|
||||
NumGpuLayers int
|
||||
MainGpu int
|
||||
@ -296,6 +311,10 @@ func (m *Model) AddBOSToken() bool {
|
||||
return bool(C.llama_add_bos_token(m.c))
|
||||
}
|
||||
|
||||
func (m *Model) AddEOSToken() bool {
|
||||
return bool(C.llama_add_eos_token(m.c))
|
||||
}
|
||||
|
||||
func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
|
||||
cLoraPath := C.CString(loraPath)
|
||||
defer C.free(unsafe.Pointer(cLoraPath))
|
||||
|
@ -771,6 +771,77 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// RerankRequest is the JSON body accepted by the /rerank endpoint.
type RerankRequest struct {
	Model       string   `json:"model"`        // NOTE(review): not read by the rerank handler; the loaded model is always used
	Query       string   `json:"query"`        // query each document is scored against
	TopN        int      `json:"top_n"`        // return top N documents — NOTE(review): decoded but currently unused by the handler
	Documents   []string `json:"documents"`    // list of documents to rerank
	CachePrompt bool     `json:"cache_prompt"` // passed to LoadCacheSlot to reuse cached prompt state
}
|
||||
|
||||
// RerankResponse is the JSON body returned by the /rerank endpoint:
// one result per input document, in request order.
type RerankResponse struct {
	Results []struct {
		Index          int     `json:"index"`           // index of the document in the request's Documents slice
		RelevanceScore float32 `json:"relevance_score"` // model-produced relevance score for the document
	} `json:"results"`
}
|
||||
|
||||
func (s *Server) rerank(w http.ResponseWriter, r *http.Request) {
|
||||
var req RerankRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("bad rereank request: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var rsp RerankResponse
|
||||
rsp.Results = make([]struct {
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float32 `json:"relevance_score"`
|
||||
}, len(req.Documents))
|
||||
|
||||
for i, doc := range req.Documents {
|
||||
// reranking prompt format: [BOS]query[EOS][SEP]doc[EOS]
|
||||
p := ""
|
||||
if !s.model.AddBOSToken() {
|
||||
p += s.model.TokenToPiece(int(s.lc.GetTokenBOS()))
|
||||
}
|
||||
p += req.Query + s.model.TokenToPiece(int(s.lc.GetTokenEOS())) + s.model.TokenToPiece(int(s.lc.GetTokenSEP())) + doc
|
||||
if !s.model.AddEOSToken() {
|
||||
p += s.model.TokenToPiece(int(s.lc.GetTokenEOS()))
|
||||
}
|
||||
seq, err := s.NewSequence(p, nil, NewSequenceParams{embedding: true})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
score := <-seq.embedding
|
||||
rsp.Results[i].Index = i
|
||||
rsp.Results[i].RelevanceScore = score[0]
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&rsp); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) loadModel(
|
||||
params llama.ModelParams,
|
||||
mpath string,
|
||||
@ -780,6 +851,7 @@ func (s *Server) loadModel(
|
||||
flashAttention bool,
|
||||
threads int,
|
||||
multiUserCache bool,
|
||||
reranking bool,
|
||||
) {
|
||||
llama.BackendInit()
|
||||
|
||||
@ -789,7 +861,7 @@ func (s *Server) loadModel(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
|
||||
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, reranking)
|
||||
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -838,6 +910,7 @@ func main() {
|
||||
tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
|
||||
multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
|
||||
requirements := flag.Bool("requirements", false, "print json requirement information")
|
||||
reranking := flag.Bool("reranking", false, "enable reranking (default: disabled)")
|
||||
|
||||
flag.Parse()
|
||||
if *requirements {
|
||||
@ -893,7 +966,7 @@ func main() {
|
||||
}
|
||||
|
||||
server.ready.Add(1)
|
||||
go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
|
||||
go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *reranking)
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
@ -912,6 +985,7 @@ func main() {
|
||||
mux.HandleFunc("/embedding", server.embeddings)
|
||||
mux.HandleFunc("/completion", server.completion)
|
||||
mux.HandleFunc("/health", server.health)
|
||||
mux.HandleFunc("/rerank", server.rerank)
|
||||
|
||||
httpServer := http.Server{
|
||||
Handler: mux,
|
||||
|
Loading…
x
Reference in New Issue
Block a user