From ad7e64181525cd22f312f5d053aaf4eb84685b0b Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Sun, 14 Apr 2024 20:53:20 -0400
Subject: [PATCH] add batch embeddings

---
 api/types.go                  | 10 +++--
 docs/api.md                   | 33 +++++++++++++-
 integration/embedding_test.go | 64 +++++++++++++++++++++++++++
 integration/utils_test.go     | 83 +++++++++++++++++++++++++++++++++++
 llm/ext_server/server.cpp     | 55 ++++++-----------------
 llm/server.go                 | 22 +++++-----
 server/routes.go              | 44 +++++++++++++------
 server/sched_test.go          |  4 +-
 8 files changed, 243 insertions(+), 72 deletions(-)
 create mode 100644 integration/embedding_test.go

diff --git a/api/types.go b/api/types.go
index 9200949c..e5c1e41f 100644
--- a/api/types.go
+++ b/api/types.go
@@ -159,15 +159,17 @@ type Runner struct {
 }
 
 type EmbeddingRequest struct {
-	Model     string    `json:"model"`
-	Prompt    string    `json:"prompt"`
-	KeepAlive *Duration `json:"keep_alive,omitempty"`
+	Model       string    `json:"model"`
+	Prompt      string    `json:"prompt,omitempty"`
+	PromptBatch []string  `json:"prompt_batch,omitempty"`
+	KeepAlive   *Duration `json:"keep_alive,omitempty"`
 
 	Options map[string]interface{} `json:"options"`
 }
 
 type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding"`
+	Embedding      []float64   `json:"embedding,omitempty"`
+	EmbeddingBatch [][]float64 `json:"embedding_batch,omitempty"`
 }
 
 type CreateRequest struct {
diff --git a/docs/api.md b/docs/api.md
index 5fc946ce..2364163a 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -1010,7 +1010,8 @@ Generate embeddings from a model
 ### Parameters
 
 - `model`: name of model to generate embeddings from
-- `prompt`: text to generate embeddings for
+- `prompt`: string to generate the embedding for
+- `prompt_batch`: array of strings to generate a batch of embeddings for
 
 Advanced parameters:
 
@@ -1038,3 +1039,33 @@ curl http://localhost:11434/api/embeddings -d '{
   ]
 }
 ```
+
+#### Request (batch)
+
+```shell
+curl http://localhost:11434/api/embeddings -d '{
+  "model": "all-minilm",
+  "prompt_batch": [
+    "Here is an article about llamas...",
+    "Here is another article about llamas..."
+  ]
+}'
+```
+
+#### Response
+
+```json
+{
+  "embedding_batch": [
+    [
+      0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
+      0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
+    ],
+    [
+      0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
+      0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
+    ]
+  ]
+}
+```
diff --git a/integration/embedding_test.go b/integration/embedding_test.go
new file mode 100644
index 00000000..49e15b4b
--- /dev/null
+++ b/integration/embedding_test.go
@@ -0,0 +1,64 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+func TestAllMiniLMEmbedding(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	req := api.EmbeddingRequest{
+		Model:  "all-minilm",
+		Prompt: "why is the sky blue?",
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+		},
+	}
+
+	res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)
+
+	if len(res.Embedding) != 384 {
+		t.Fatalf("Expected 384 floats to be returned, got %v", len(res.Embedding))
+	}
+
+	if res.Embedding[0] != 0.146763876080513 {
+		t.Fatalf("Expected first embedding float to be 0.146763876080513, got %v", res.Embedding[0])
+	}
+}
+
+func TestAllMiniLMEmbeddings(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	req := api.EmbeddingRequest{
+		Model:       "all-minilm",
+		PromptBatch: []string{"why is the sky blue?", "why is the sky blue?"},
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+		},
+	}
+
+	res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)
+
+	if len(res.EmbeddingBatch) != 2 {
+		t.Fatal("Expected 2 embeddings to be returned")
+	}
+
+	if len(res.EmbeddingBatch[0]) != 384 {
+		t.Fatalf("Expected first embedding to have 384 floats, got %v", len(res.EmbeddingBatch[0]))
+	}
+
+	if res.EmbeddingBatch[0][0] != 0.146763876080513 || res.EmbeddingBatch[1][0] != 0.146763876080513 {
+		t.Fatalf("Expected first embedding floats to be 0.146763876080513, got %v, %v", res.EmbeddingBatch[0][0], res.EmbeddingBatch[1][0])
+	}
+}
diff --git a/integration/utils_test.go b/integration/utils_test.go
index 3e91187a..792e603d 100644
--- a/integration/utils_test.go
+++ b/integration/utils_test.go
@@ -5,6 +5,7 @@ package integration
 import (
 	"bytes"
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -24,6 +25,7 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
@@ -285,6 +287,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
 // Generate a set of requests
 // By default each request uses orca-mini as the model
 func GenerateRequests() ([]api.GenerateRequest, [][]string) {
+	stream := false
 	return []api.GenerateRequest{
 		{
 			Model: "orca-mini",
@@ -336,3 +339,83 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 		[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
 	}
 }
+
+func EmbeddingTestHelper(ctx context.Context, t *testing.T, client *http.Client, req api.EmbeddingRequest) api.EmbeddingResponse {
+	// TODO maybe stuff in an init routine?
+	lifecycle.InitLogging()
+
+	requestJSON, err := json.Marshal(req)
+	if err != nil {
+		t.Fatalf("Error serializing request: %v", err)
+	}
+
+	defer func() {
+		if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
+			defer serverProcMutex.Unlock()
+			if t.Failed() {
+				fp, err := os.Open(lifecycle.ServerLogFile)
+				if err != nil {
+					slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
+					return
+				}
+				data, err := io.ReadAll(fp)
+				if err != nil {
+					slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
+					return
+				}
+				slog.Warn("SERVER LOG FOLLOWS")
+				os.Stderr.Write(data)
+				slog.Warn("END OF SERVER")
+			}
+			err = os.Remove(lifecycle.ServerLogFile)
+			if err != nil && !os.IsNotExist(err) {
+				slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
+			}
+		}
+	}()
+
+	scheme, testEndpoint := GetTestEndpoint()
+
+	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
+		serverProcMutex.Lock()
+		fp, err := os.CreateTemp("", "ollama-server-*.log")
+		if err != nil {
+			t.Fatalf("failed to generate log file: %s", err)
+		}
+		lifecycle.ServerLogFile = fp.Name()
+		fp.Close()
+		assert.NoError(t, StartServer(ctx, testEndpoint))
+	}
+
+	err = PullIfMissing(ctx, client, scheme, testEndpoint, req.Model)
+	if err != nil {
+		t.Fatalf("Error pulling model: %v", err)
+	}
+
+	// Make the request and get the response
+	httpReq, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/embeddings", bytes.NewReader(requestJSON))
+	if err != nil {
+		t.Fatalf("Error creating request: %v", err)
+	}
+
+	// Set the content type for the request
+	httpReq.Header.Set("Content-Type", "application/json")
+
+	// Make the request with the HTTP client
+	response, err := client.Do(httpReq.WithContext(ctx))
+	if err != nil {
+		t.Fatalf("Error making request: %v", err)
+	}
+	defer response.Body.Close()
+
+	body, err := io.ReadAll(response.Body)
+	assert.NoError(t, err)
+	assert.Equal(t, 200, response.StatusCode, string(body))
+
+	// Verify the response is valid JSON
+	var res api.EmbeddingResponse
+	err = json.Unmarshal(body, &res)
+	assert.NoError(t, err, string(body))
+
+	return res
+}
diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp
index 22117037..93f3d2d0 100644
--- a/llm/ext_server/server.cpp
+++ b/llm/ext_server/server.cpp
@@ -3209,54 +3209,27 @@ int main(int argc, char **argv) {
             return res.set_content(data.dump(), "application/json; charset=utf-8");
         });
 
-    svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
+    svr.Post("/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 const json body = json::parse(req.body);
-                json prompt;
-                if (body.count("content") != 0)
-                {
-                    prompt = body["content"];
-                }
-                else
-                {
-                    prompt = "";
+
+                const int id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(id);
+                llama.request_completion(id, {{"prompt", body["contents"]}}, false, true, -1);
+
+                task_result recv = llama.queue_results.recv(id);
+                llama.queue_results.remove_waiting_task_id(id);
+
+                json embeddings = json::array();
+                for (auto & elem : recv.result_json["results"]) {
+                    embeddings.push_back(json_value(elem, "embedding", json::array()));
                 }
-                json image_data;
-                if (body.count("image_data") != 0) {
-                    image_data = body["image_data"];
-                }
-                else
-                {
-                    image_data = "";
-                }
-
-                // create and queue the task
-                const int task_id = llama.queue_tasks.get_new_id();
-                llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
-
-                // get the result
-                task_result result = llama.queue_results.recv(task_id);
-                llama.queue_results.remove_waiting_task_id(task_id);
-
-                // send the result
-                return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
+
+                json result = json{{"embeddings", embeddings}};
+                return res.set_content(result.dump(), "application/json; charset=utf-8");
             });
 
-    // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
-    // "Bus error: 10" - this is on macOS, it does not crash on Linux
-    //std::thread t2([&]()
-    /*{
-        bool running = true;
-        while (running)
-        {
-            running = llama.update_slots();
-        }
-    }*/
-    //);
-
     if (sparams.n_threads_http < 1) {
         // +2 threads for monitoring endpoints
         sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
diff --git a/llm/server.go b/llm/server.go
index 84babe46..3e289a17 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -32,7 +32,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Embeddings(ctx context.Context, prompts []string) ([][]float64, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -736,15 +736,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	return fmt.Errorf("max retries exceeded")
 }
 
-type EmbeddingRequest struct {
-	Content string `json:"content"`
+type EmbeddingsRequest struct {
+	Contents []string `json:"contents"`
 }
 
-type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding"`
+type EmbeddingsResponse struct {
+	Embeddings [][]float64 `json:"embeddings"`
 }
 
-func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embeddings(ctx context.Context, prompts []string) ([][]float64, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
@@ -758,12 +758,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(TokenizeRequest{Content: prompt})
+	data, err := json.Marshal(EmbeddingsRequest{Contents: prompts})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
 
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embeddings", s.port), bytes.NewBuffer(data))
 	if err != nil {
 		return nil, fmt.Errorf("error creating embed request: %w", err)
 	}
@@ -780,17 +780,17 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("error reading embed response: %w", err)
 	}
 
 	if resp.StatusCode >= 400 {
 		log.Printf("llm encode error: %s", body)
 		return nil, fmt.Errorf("%s", body)
 	}
 
-	var embedding EmbeddingResponse
+	var embedding EmbeddingsResponse
 	if err := json.Unmarshal(body, &embedding); err != nil {
-		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
+		return nil, fmt.Errorf("unmarshal embeddings response: %w", err)
 	}
 
-	return embedding.Embedding, nil
+	return embedding.Embeddings, nil
 }
 
 type TokenizeRequest struct {
diff --git a/server/routes.go b/server/routes.go
index b1962d23..000cc098 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -403,23 +403,39 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	// an empty request loads the model
-	if req.Prompt == "" {
-		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
-		return
-	}
+	switch {
+	// single embedding
+	case len(req.Prompt) > 0:
+		embeddings, err := runner.llama.Embeddings(c.Request.Context(), []string{req.Prompt})
+		if err != nil {
+			slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
+			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
+			return
+		}
 
-	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
-	if err != nil {
-		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
-		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
-		return
-	}
+		resp := api.EmbeddingResponse{Embedding: embeddings[0]}
+		c.JSON(http.StatusOK, resp)
 
-	resp := api.EmbeddingResponse{
-		Embedding: embedding,
+	// batch embeddings
+	case len(req.PromptBatch) > 0:
+		embeddings, err := runner.llama.Embeddings(c.Request.Context(), req.PromptBatch)
+		if err != nil {
+			slog.Info(fmt.Sprintf("batch embedding generation failed: %v", err))
+			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embeddings"})
+			return
+		}
+
+		resp := api.EmbeddingResponse{EmbeddingBatch: embeddings}
+		c.JSON(http.StatusOK, resp)
+
+	// an empty request loads the model
+	default:
+		if req.PromptBatch != nil {
+			c.JSON(http.StatusOK, api.EmbeddingResponse{EmbeddingBatch: [][]float64{}})
+		} else {
+			c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
+		}
 	}
-	c.JSON(http.StatusOK, resp)
 }
diff --git a/server/sched_test.go b/server/sched_test.go
index 86bd7846..98aa326b 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -530,7 +530,7 @@ type mockLlm struct {
 	pingResp         error
 	waitResp         error
 	completionResp   error
-	embeddingResp    []float64
+	embeddingResp    [][]float64
 	embeddingRespErr error
 	tokenizeResp     []int
 	tokenizeRespErr  error
@@ -546,7 +546,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return s.completionResp
 }
-func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *mockLlm) Embeddings(ctx context.Context, prompts []string) ([][]float64, error) {
 	return s.embeddingResp, s.embeddingRespErr
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
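
Not part of the patch: a minimal end-to-end sketch of the new batch API from the client side. It assumes an Ollama server built with this change is listening on the default port and that the `all-minilm` model has already been pulled; it uses only the Go standard library and mirrors the `prompt_batch`/`embedding_batch` JSON fields defined in `api/types.go` above.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Request body matching api.EmbeddingRequest with the new PromptBatch field.
	payload, err := json.Marshal(map[string]interface{}{
		"model": "all-minilm",
		"prompt_batch": []string{
			"Here is an article about llamas...",
			"Here is another article about llamas...",
		},
	})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://localhost:11434/api/embeddings", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Response body matching api.EmbeddingResponse with the new EmbeddingBatch field.
	var res struct {
		EmbeddingBatch [][]float64 `json:"embedding_batch"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
		panic(err)
	}

	// Expect one embedding vector per input prompt.
	fmt.Printf("got %d embeddings of dimension %d\n",
		len(res.EmbeddingBatch), len(res.EmbeddingBatch[0]))
}
```

The sketch assumes `embedding_batch` preserves request order; the loop over `result_json["results"]` in `server.cpp` above suggests, but does not explicitly document, that ordering.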