From e0241118d053eed047ba5406c94d6cd38d394c2c Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 31 Jul 2024 11:08:09 -0700 Subject: [PATCH] Get embeddings working Truncation doesn't pass, but the other embeddings tests pass --- llama/runner/runner.go | 43 ++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 11c0e6c2..61f1e7f0 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -59,7 +59,7 @@ func (s *Sequence) prompt() bool { } func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence { - tokens, err := s.lc.Model().Tokenize(prompt, false, true) + tokens, err := s.lc.Model().Tokenize(prompt, embedding, true) if err != nil { panic(err) } @@ -345,11 +345,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } type EmbeddingRequest struct { - Prompt string `json:"prompt"` + Content []string `json:"content"` } type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` + Embedding [][]float32 `json:"embedding"` } // TODO (jmorganca): is it safe to do this concurrently with decoding? @@ -362,22 +362,37 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - seq := s.NewSequence(req.Prompt, 0, nil, nil, true) + slog.Debug("embedding request", "content", req.Content) + seqs := make([]*Sequence, len(req.Content)) + embeddings := make([][]float32, len(req.Content)) + var processed int + for i, content := range req.Content { + seqs[i] = s.NewSequence(content, 0, nil, nil, true) + } - s.mu.Lock() - for i, sq := range s.seqs { - if sq == nil { - s.seqs[i] = seq - s.cond.Signal() - break + // TODO - refactor to go routines to add seq's and drain the responses + // so we don't stall until each set is iterated through + for processed < len(seqs) { + s.mu.Lock() + for i, sq := range s.seqs { + if processed >= len(seqs) { + break + } + if sq == nil { + s.seqs[i] = seqs[processed] + processed += 1 + } + } + s.cond.Signal() + s.mu.Unlock() + + for i := range processed { + embeddings[i] = <-seqs[i].embedding } } - s.mu.Unlock() - - embedding := <-seq.embedding if err := json.NewEncoder(w).Encode(&EmbeddingResponse{ - Embedding: embedding, + Embedding: embeddings, }); err != nil { log.Println("Failed to encode result:", err) return