From 46a7c682f217ef8b396058505c8e6e23254713ae Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Tue, 27 Aug 2024 13:59:33 -0700
Subject: [PATCH] runner.go: Fix embeddings endpoint

The embeddings endpoint only takes a single input and provides a single
output, instead of multiple as the current implementation expected.
Fixing this also allows the implementation to be simplified and a few
embedding-specific issues to be addressed.
---
 llama/llama.go         |  2 +-
 llama/runner/runner.go | 53 ++++++++++++++++++------------------------
 2 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index 92adef86..704d9e8f 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -429,7 +429,7 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
 }
 
 func (s *SamplingContext) Free() {
-	if s.c != nil {
+	if s != nil {
 		C.llama_sampling_cfree(s.c)
 	}
 }
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index ecf9c9dc..0b25e42a 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -88,9 +88,15 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence
 	if params.numKeep < 0 {
 		params.numKeep = len(tokens)
 	}
-	// Subtracting 4 ensures that at least 1 token can be discarded during shift
-	params.numKeep = min(params.numKeep, s.numCtx-4)
-	params.numKeep += s.bosToken
+
+	if !params.embedding {
+		// Subtracting 4 ensures that at least 1 token can be discarded during shift
+		params.numKeep = min(params.numKeep, s.numCtx-4)
+		params.numKeep += s.bosToken
+	} else {
+		// Embeddings are 1 shot - just truncate to the context window, without ever shifting
+		params.numKeep = min(params.numKeep, s.numCtx)
+	}
 
 	// truncate to fit in context window
 	if len(tokens) > s.numCtx {
@@ -523,14 +529,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 }
 
 type EmbeddingRequest struct {
-	Content []string `json:"content"`
+	Content string `json:"content"`
 }
 
 type EmbeddingResponse struct {
-	Embedding [][]float32 `json:"embedding"`
+	Embedding []float32 `json:"embedding"`
 }
 
-// TODO (jmorganca): is it safe to do this concurrently with decoding?
 func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	var req EmbeddingRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@@ -541,36 +546,24 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 
 	slog.Debug("embedding request", "content", req.Content)
 
-	seqs := make([]*Sequence, len(req.Content))
-	embeddings := make([][]float32, len(req.Content))
-	var processed int
-	for i, content := range req.Content {
-		seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
-	}
-	// TODO - refactor to go routines to add seq's and drain the responses
-	// so we don't stall until each set is iterated through
-	for processed < len(seqs) {
-		s.mu.Lock()
-		for i, sq := range s.seqs {
-			if processed >= len(seqs) {
-				break
-			}
-			if sq == nil {
-				s.seqs[i] = seqs[processed]
-				processed += 1
-			}
-		}
-		s.cond.Signal()
-		s.mu.Unlock()
+	seq := s.NewSequence(req.Content, NewSequenceParams{embedding: true})
 
-		for i := range processed {
-			embeddings[i] = <-seqs[i].embedding
+	// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
+	s.mu.Lock()
+	for i, sq := range s.seqs {
+		if sq == nil {
+			s.seqs[i] = seq
+			s.cond.Signal()
+			break
 		}
 	}
+	s.mu.Unlock()
+
+	embedding := <-seq.embedding
 
 	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
-		Embedding: embeddings,
+		Embedding: embedding,
 	}); err != nil {
 		log.Println("Failed to encode result:", err)
 		return
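
For reference, a minimal client-side sketch against the new single-input/single-output shape of the endpoint. The struct fields mirror the EmbeddingRequest and EmbeddingResponse types in the diff above; the endpoint path and address used here are assumptions for illustration only and depend on how the runner is actually launched.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// These mirror the runner's types after this patch: one input string in,
// one embedding vector out.
type embeddingRequest struct {
	Content string `json:"content"`
}

type embeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}

func main() {
	// NOTE: address and route are assumed for this sketch; adjust to match
	// wherever the runner is listening.
	body, err := json.Marshal(embeddingRequest{Content: "hello world"})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://127.0.0.1:8080/embedding", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var er embeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
		panic(err)
	}
	fmt.Println("embedding length:", len(er.Embedding))
}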