Merge bfae776f34250386c0e05ae4e18126d42cfb940e into 67691e410db7a50b07a64858820b14de9aa91314
commit 9cd032b57e
api/types.go
@@ -242,6 +242,7 @@ type Runner struct {
 	UseMMap   *bool `json:"use_mmap,omitempty"`
 	UseMLock  bool  `json:"use_mlock,omitempty"`
 	NumThread int   `json:"num_thread,omitempty"`
+	Reranking bool  `json:"reranking,omitempty"`
 }

 // EmbedRequest is the request passed to [Client.Embed].
@@ -293,6 +294,22 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }

+type RerankRequest struct {
+	Model     string                 `json:"model"`
+	Query     string                 `json:"query"`
+	TopN      int                    `json:"top_n"`     // return top N documents
+	Documents []string               `json:"documents"` // list of documents to rerank
+	KeepAlive *Duration              `json:"keep_alive,omitempty"`
+	Options   map[string]interface{} `json:"options,omitempty"`
+}
+
+type RerankResponse struct {
+	Results []struct {
+		Document       string  `json:"document"`
+		RelevanceScore float32 `json:"relevance_score"`
+	} `json:"results"`
+}
+
 // CreateRequest is the request passed to [Client.Create].
 type CreateRequest struct {
 	Model string `json:"model"`
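For orientation, here is an illustrative request built from these new types, and the JSON it marshals to. This is not part of the change; the model name is a placeholder, and a `TopN` of 0 means "return every document" per the handler later in this diff.

```go
// Illustrative only — not part of the change. The model name is a placeholder
// for any model served with reranking enabled.
req := api.RerankRequest{
	Model:     "bge-reranker-v2-m3",
	Query:     "what is a panda?",
	TopN:      2, // 0 would return a score for every document
	Documents: []string{"The giant panda is a bear species endemic to China.", "Go is a programming language."},
}
// Marshals to (KeepAlive and Options are omitted because of omitempty):
// {"model":"bge-reranker-v2-m3","query":"what is a panda?","top_n":2,
//  "documents":["The giant panda is a bear species endemic to China.","Go is a programming language."]}
```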
@@ -140,7 +140,7 @@ type ContextParams struct {
 	c C.struct_llama_context_params
 }

-func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
+func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, reranking bool) ContextParams {
 	params := C.llama_context_default_params()
 	params.n_ctx = C.uint(numCtx)
 	params.n_batch = C.uint(batchSize)
@@ -149,6 +149,9 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
 	params.n_threads_batch = params.n_threads
 	params.embeddings = C.bool(true)
 	params.flash_attn = C.bool(flashAttention)
+	if reranking {
+		params.pooling_type = C.LLAMA_POOLING_TYPE_RANK
+	}
 	return ContextParams{c: params}
 }
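Setting LLAMA_POOLING_TYPE_RANK tells llama.cpp to produce a single query/document relevance score per sequence instead of an embedding vector, which is what the runner's rerank handler later in this diff reads back as `score[0]`. A rough sketch of that effect, assuming the existing `GetEmbeddingsSeq` accessor in this package and a context created with `reranking` enabled; tokenization and decoding are elided:

```go
// Sketch only, not part of the change: with rank pooling enabled the per-sequence
// "embedding" holds one float, the relevance score for the decoded query/document pair.
func relevanceScore(lc *llama.Context, seqId int) float32 {
	// Assumes the prompt "[BOS]query[EOS][SEP]doc[EOS]" has already been decoded
	// as sequence seqId on a context built via NewContextParams(..., reranking=true).
	return lc.GetEmbeddingsSeq(seqId)[0]
}
```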
@@ -214,6 +217,18 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 }

+func (c *Context) GetTokenBOS() C.llama_token {
+	return C.llama_token_bos(c.Model().c)
+}
+
+func (c *Context) GetTokenEOS() C.llama_token {
+	return C.llama_token_eos(c.Model().c)
+}
+
+func (c *Context) GetTokenSEP() C.llama_token {
+	return C.llama_token_sep(c.Model().c)
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
@@ -298,6 +313,10 @@ func (m *Model) AddBOSToken() bool {
 	return bool(C.llama_add_bos_token(m.c))
 }

+func (m *Model) AddEOSToken() bool {
+	return bool(C.llama_add_eos_token(m.c))
+}
+
 func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
 	cLoraPath := C.CString(loraPath)
 	defer C.free(unsafe.Pointer(cLoraPath))
@@ -753,6 +753,77 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	}
 }

+type RerankRequest struct {
+	Model       string   `json:"model"`
+	Query       string   `json:"query"`
+	TopN        int      `json:"top_n"`     // return top N documents
+	Documents   []string `json:"documents"` // list of documents to rerank
+	CachePrompt bool     `json:"cache_prompt"`
+}
+
+type RerankResponse struct {
+	Results []struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	} `json:"results"`
+}
+
+func (s *Server) rerank(w http.ResponseWriter, r *http.Request) {
+	var req RerankRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, fmt.Sprintf("bad rerank request: %s", err), http.StatusBadRequest)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+
+	var rsp RerankResponse
+	rsp.Results = make([]struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	}, len(req.Documents))
+
+	for i, doc := range req.Documents {
+		// reranking prompt format: [BOS]query[EOS][SEP]doc[EOS]
+		p := ""
+		if !s.model.AddBOSToken() {
+			p += s.model.TokenToPiece(int(s.lc.GetTokenBOS()))
+		}
+		p += req.Query + s.model.TokenToPiece(int(s.lc.GetTokenEOS())) + s.model.TokenToPiece(int(s.lc.GetTokenSEP())) + doc
+		if !s.model.AddEOSToken() {
+			p += s.model.TokenToPiece(int(s.lc.GetTokenEOS()))
+		}
+		seq, err := s.NewSequence(p, nil, NewSequenceParams{embedding: true})
+		if err != nil {
+			http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
+			return
+		}
+
+		s.mu.Lock()
+		for i, sq := range s.seqs {
+			if sq == nil {
+				seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+				if err != nil {
+					s.mu.Unlock()
+					http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
+					return
+				}
+				s.seqs[i] = seq
+				s.cond.Signal()
+				break
+			}
+		}
+		s.mu.Unlock()
+
+		score := <-seq.embedding
+		rsp.Results[i].Index = i
+		rsp.Results[i].RelevanceScore = score[0]
+	}
+
+	if err := json.NewEncoder(w).Encode(&rsp); err != nil {
+		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+	}
+}
+
 func (s *Server) loadModel(
 	params llama.ModelParams,
 	mpath string,
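To make the prompt layout concrete: for each document the handler above concatenates the query and document with the model's BOS/EOS/SEP pieces in the [BOS]query[EOS][SEP]doc[EOS] order. A purely illustrative rendering with placeholder marker strings; the real pieces come from TokenToPiece and vary by model:

```go
package main

import "fmt"

func main() {
	// Placeholders: "<s>" and "</s>" stand in for whatever the model's BOS/EOS/SEP
	// pieces actually are; BERT-style rerankers typically use tokens of this shape.
	query := "how tall is the Eiffel Tower?"
	doc := "The Eiffel Tower is 330 metres tall."
	prompt := "<s>" + query + "</s>" + "</s>" + doc + "</s>" // [BOS]query[EOS][SEP]doc[EOS]
	fmt.Println(prompt)
}
```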
@@ -762,6 +833,7 @@ func (s *Server) loadModel(
 	flashAttention bool,
 	threads int,
 	multiUserCache bool,
+	reranking bool,
 ) {
 	llama.BackendInit()

@@ -771,7 +843,7 @@
 		panic(err)
 	}

-	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
+	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, reranking)
 	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
 	if err != nil {
 		panic(err)
@@ -819,6 +891,7 @@ func main() {
 	tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
 	multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
 	requirements := flag.Bool("requirements", false, "print json requirement information")
+	reranking := flag.Bool("reranking", false, "enable reranking (default: disabled)")

 	flag.Parse()
 	if *requirements {
@@ -874,7 +947,7 @@ func main() {
 	}

 	server.ready.Add(1)
-	go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
+	go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache, *reranking)

 	server.cond = sync.NewCond(&server.mu)

@@ -893,6 +966,7 @@ func main() {
 	mux.HandleFunc("/embedding", server.embeddings)
 	mux.HandleFunc("/completion", server.completion)
 	mux.HandleFunc("/health", server.health)
+	mux.HandleFunc("/rerank", server.rerank)

 	httpServer := http.Server{
 		Handler: mux,
@@ -38,6 +38,7 @@ type LlamaServer interface {
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
 	Embedding(ctx context.Context, input string) ([]float32, error)
+	Rerank(ctx context.Context, req RerankRequest, fn func(RerankResponse)) error
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -188,6 +189,10 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		"--batch-size", strconv.Itoa(opts.NumBatch),
 	}

+	if opts.Reranking {
+		params = append(params, "--reranking")
+	}
+
 	if opts.NumGPU >= 0 {
 		params = append(params, "--n-gpu-layers", strconv.Itoa(opts.NumGPU))
 	}
@@ -911,6 +916,69 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
 	return e.Embedding, nil
 }

+type RerankRequest struct {
+	Model     string   `json:"model"`
+	Query     string   `json:"query"`
+	TopN      int      `json:"top_n"`     // return top N documents
+	Documents []string `json:"documents"` // list of documents to rerank
+}
+
+type RerankResponse struct {
+	Results []struct {
+		Index          int     `json:"index"`
+		RelevanceScore float32 `json:"relevance_score"`
+	} `json:"results"`
+}
+
+func (s *llmServer) Rerank(ctx context.Context, req RerankRequest, fn func(RerankResponse)) error {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return err
+	}
+	defer s.sem.Release(1)
+
+	status, err := s.getServerStatusRetry(ctx)
+	if err != nil {
+		return err
+	} else if status != ServerStatusReady {
+		return fmt.Errorf("unexpected server status: %s", status.ToString())
+	}
+
+	data, err := json.Marshal(req)
+	if err != nil {
+		return fmt.Errorf("error marshaling rerank data: %w", err)
+	}
+
+	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/rerank", s.port), bytes.NewBuffer(data))
+	if err != nil {
+		return fmt.Errorf("error creating rerank request: %w", err)
+	}
+	r.Header.Set("Content-Type", "application/json")
+
+	resp, err := http.DefaultClient.Do(r)
+	if err != nil {
+		return fmt.Errorf("do rerank request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("error reading rerank response: %w", err)
+	}
+
+	if resp.StatusCode >= 400 {
+		log.Printf("llm rerank error: %s", body)
+		return fmt.Errorf("%s", body)
+	}
+
+	var rr RerankResponse
+	if err := json.Unmarshal(body, &rr); err != nil {
+		return fmt.Errorf("unmarshal rerank response: %w", err)
+	}
+	fn(rr)
+
+	return nil
+}
+
 type TokenizeRequest struct {
 	Content string `json:"content"`
 }
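Note that this runner-level RerankResponse identifies each document only by Index; sorting, TopN truncation, and mapping indices back to document text are left to the HTTP layer (see RerankHandler below). A hypothetical helper, not part of the change, to make that mapping explicit:

```go
// printRanked is a hypothetical illustration: it shows how a caller of
// (*llmServer).Rerank can relate the index-based results back to its documents.
func printRanked(docs []string, rr RerankResponse) {
	for _, r := range rr.Results {
		fmt.Printf("%.3f  %s\n", r.RelevanceScore, docs[r.Index])
	}
}
```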
@@ -18,6 +18,7 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
+	"sort"
 	"strings"
 	"syscall"
 	"time"
@@ -348,6 +349,79 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }

+func (s *Server) RerankHandler(c *gin.Context) {
+	var req api.RerankRequest
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	} else if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	switch {
+	case len(req.Documents) == 0:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Documents cannot be empty"})
+		return
+	case req.Query == "":
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Query cannot be empty"})
+		return
+	case req.TopN < 0:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "TopN cannot be negative"})
+		return
+	}
+
+	if req.Options == nil {
+		req.Options = make(map[string]any)
+	}
+	req.Options["reranking"] = true
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	llmreq := llm.RerankRequest{
+		Model:     req.Model,
+		Query:     req.Query,
+		TopN:      req.TopN,
+		Documents: req.Documents,
+	}
+	err = r.Rerank(c.Request.Context(), llmreq, func(rr llm.RerankResponse) {
+		sort.SliceStable(rr.Results, func(i, j int) bool {
+			return rr.Results[i].RelevanceScore > rr.Results[j].RelevanceScore
+		})
+
+		var topn int
+		if req.TopN == 0 {
+			topn = len(rr.Results) // if TopN is unset, return all results
+		} else {
+			topn = min(len(rr.Results), req.TopN)
+		}
+		topResults := rr.Results[:topn]
+
+		rsp := api.RerankResponse{
+			Results: make([]struct {
+				Document       string  `json:"document"`
+				RelevanceScore float32 `json:"relevance_score"`
+			}, topn),
+		}
+
+		for i, result := range topResults {
+			rsp.Results[i].Document = req.Documents[result.Index]
+			rsp.Results[i].RelevanceScore = result.RelevanceScore
+		}
+
+		c.JSON(http.StatusOK, rsp)
+	})
+
+	if err != nil {
+		slog.Info(fmt.Sprintf("rerank failed: %v", err))
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to rerank"})
+		return
+	}
+}
+
 func (s *Server) EmbedHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 	var req api.EmbedRequest
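End to end, once a server is running with this change, the handler above can be exercised with a plain HTTP POST to /api/rerank. A minimal client sketch, assuming the default local address and a placeholder reranker model name; the response comes back sorted by relevance_score and truncated to top_n:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/ollama/ollama/api"
)

func main() {
	req := api.RerankRequest{
		Model: "bge-reranker-v2-m3", // placeholder model name
		Query: "What is the capital of France?",
		TopN:  2,
		Documents: []string{
			"Paris is the capital of France.",
			"Go is a programming language.",
			"France is in Western Europe.",
		},
	}

	body, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://localhost:11434/api/rerank", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Results arrive sorted by relevance_score and truncated to top_n.
	var rr api.RerankResponse
	if err := json.NewDecoder(resp.Body).Decode(&rr); err != nil {
		panic(err)
	}
	for _, r := range rr.Results {
		fmt.Printf("%.4f  %s\n", r.RelevanceScore, r.Document)
	}
}
```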
@@ -1158,6 +1232,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
 	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
 	r.GET("/api/ps", s.PsHandler)
+	r.POST("/api/rerank", s.RerankHandler)

 	// Compatibility endpoints
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)