From 00d06619a11356a155362013b8fc0bc9d0d8a146 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Mon, 4 Dec 2023 21:16:27 -0800 Subject: [PATCH] Revert "chat api (#991)" while context variable is fixed This reverts commit 7a0899d62dee8a55810446dd7655b9e682ddf8ac. --- api/client.go | 13 -- api/types.go | 74 +++------- docs/api.md | 152 ++------------------ llm/llama.go | 54 +++----- llm/llm.go | 2 +- server/images.go | 85 +++--------- server/images_test.go | 10 +- server/routes.go | 313 +++++++++--------------------------------- 8 files changed, 144 insertions(+), 559 deletions(-) diff --git a/api/client.go b/api/client.go index 250711dd..44af222c 100644 --- a/api/client.go +++ b/api/client.go @@ -221,19 +221,6 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate }) } -type ChatResponseFunc func(ChatResponse) error - -func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error { - return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error { - var resp ChatResponse - if err := json.Unmarshal(bts, &resp); err != nil { - return err - } - - return fn(resp) - }) -} - type PullProgressFunc func(ProgressResponse) error func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { diff --git a/api/types.go b/api/types.go index 14a7059e..692c4445 100644 --- a/api/types.go +++ b/api/types.go @@ -36,7 +36,7 @@ type GenerateRequest struct { Prompt string `json:"prompt"` System string `json:"system"` Template string `json:"template"` - Context []int `json:"context,omitempty"` // DEPRECATED: context is deprecated, use the /chat endpoint instead for chat history + Context []int `json:"context,omitempty"` Stream *bool `json:"stream,omitempty"` Raw bool `json:"raw,omitempty"` Format string `json:"format"` @@ -44,41 +44,6 @@ type GenerateRequest struct { Options map[string]interface{} `json:"options"` } -type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Template string `json:"template"` - Stream *bool `json:"stream,omitempty"` - Format string `json:"format"` - - Options map[string]interface{} `json:"options"` -} - -type Message struct { - Role string `json:"role"` // one of ["system", "user", "assistant"] - Content string `json:"content"` -} - -type ChatResponse struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - Message *Message `json:"message,omitempty"` - - Done bool `json:"done"` - Context []int `json:"context,omitempty"` - - EvalMetrics -} - -type EvalMetrics struct { - TotalDuration time.Duration `json:"total_duration,omitempty"` - LoadDuration time.Duration `json:"load_duration,omitempty"` - PromptEvalCount int `json:"prompt_eval_count,omitempty"` - PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` - EvalCount int `json:"eval_count,omitempty"` - EvalDuration time.Duration `json:"eval_duration,omitempty"` -} - // Options specfied in GenerateRequest, if you add a new option here add it to the API docs also type Options struct { Runner @@ -208,34 +173,39 @@ type GenerateResponse struct { Done bool `json:"done"` Context []int `json:"context,omitempty"` - EvalMetrics + TotalDuration time.Duration `json:"total_duration,omitempty"` + LoadDuration time.Duration `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration time.Duration 
`json:"eval_duration,omitempty"` } -func (m *EvalMetrics) Summary() { - if m.TotalDuration > 0 { - fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) +func (r *GenerateResponse) Summary() { + if r.TotalDuration > 0 { + fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration) } - if m.LoadDuration > 0 { - fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration) + if r.LoadDuration > 0 { + fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration) } - if m.PromptEvalCount > 0 { - fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", m.PromptEvalCount) + if r.PromptEvalCount > 0 { + fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount) } - if m.PromptEvalDuration > 0 { - fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration) - fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds()) + if r.PromptEvalDuration > 0 { + fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration) + fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds()) } - if m.EvalCount > 0 { - fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", m.EvalCount) + if r.EvalCount > 0 { + fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount) } - if m.EvalDuration > 0 { - fmt.Fprintf(os.Stderr, "eval duration: %s\n", m.EvalDuration) - fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds()) + if r.EvalDuration > 0 { + fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration) + fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds()) } } diff --git a/docs/api.md b/docs/api.md index 9e39cb9b..0595fadd 100644 --- a/docs/api.md +++ b/docs/api.md @@ -24,7 +24,7 @@ All durations are returned in nanoseconds. ### Streaming responses -Certain endpoints stream responses as JSON objects. +Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character. ## Generate a completion @@ -32,12 +32,10 @@ Certain endpoints stream responses as JSON objects. POST /api/generate ``` -Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request. +Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request. ### Parameters -`model` is required. - - `model`: (required) the [model name](#model-names) - `prompt`: the prompt to generate a response for @@ -45,10 +43,11 @@ Advanced parameters (optional): - `format`: the format to return a response in. 
Currently the only accepted value is `json` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` -- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`) - `system`: system prompt to (overrides what is defined in the `Modelfile`) +- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`) +- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects -- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API. +- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself. ### JSON mode @@ -58,7 +57,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur ### Examples -#### Request (Prompt) +#### Request ```shell curl http://localhost:11434/api/generate -d '{ @@ -90,7 +89,7 @@ The final response in the stream also includes additional data about the generat - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt - `eval_count`: number of tokens the response - `eval_duration`: time in nanoseconds spent generating the response -- `context`: deprecated, an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory +- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory - `response`: empty if the response was streamed, if not streamed, this will contain the full response To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`. @@ -115,8 +114,6 @@ To calculate how fast the response is generated in tokens per second (token/s), #### Request (No streaming) -A response can be recieved in one reply when streaming is off. - ```shell curl http://localhost:11434/api/generate -d '{ "model": "llama2", @@ -147,9 +144,9 @@ If `stream` is set to `false`, the response will be a single JSON object: } ``` -#### Request (Raw Mode) +#### Request (Raw mode) -In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting. +In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context. 
```shell curl http://localhost:11434/api/generate -d '{ @@ -167,7 +164,6 @@ curl http://localhost:11434/api/generate -d '{ "model": "mistral", "created_at": "2023-11-03T15:36:02.583064Z", "response": " The sky appears blue because of a phenomenon called Rayleigh scattering.", - "context": [1, 2, 3], "done": true, "total_duration": 14648695333, "load_duration": 3302671417, @@ -279,6 +275,7 @@ curl http://localhost:11434/api/generate -d '{ "model": "llama2", "created_at": "2023-08-04T19:22:45.499127Z", "response": "The sky is blue because it is the color of the sky.", + "context": [1, 2, 3], "done": true, "total_duration": 5589157167, "load_duration": 3013701500, @@ -291,135 +288,6 @@ curl http://localhost:11434/api/generate -d '{ } ``` -## Send Chat Messages -```shell -POST /api/chat -``` - -Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request. - -### Parameters - -`model` is required. - -- `model`: (required) the [model name](#model-names) -- `messages`: the messages of the chat, this can be used to keep a chat memory - -Advanced parameters (optional): - -- `format`: the format to return a response in. Currently the only accepted value is `json` -- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` -- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`) -- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - -### Examples - -#### Request -Send a chat message with a streaming response. - -```shell -curl http://localhost:11434/api/generate -d '{ - "model": "llama2", - "messages": [ - { - "role": "user", - "content": "why is the sky blue?" - } - ] -}' -``` - -#### Response - -A stream of JSON objects is returned: - -```json -{ - "model": "llama2", - "created_at": "2023-08-04T08:52:19.385406455-07:00", - "message": { - "role": "assisant", - "content": "The" - }, - "done": false -} -``` - -Final response: - -```json -{ - "model": "llama2", - "created_at": "2023-08-04T19:22:45.499127Z", - "done": true, - "total_duration": 5589157167, - "load_duration": 3013701500, - "sample_count": 114, - "sample_duration": 81442000, - "prompt_eval_count": 46, - "prompt_eval_duration": 1160282000, - "eval_count": 113, - "eval_duration": 1325948000 -} -``` - -#### Request (With History) -Send a chat message with a conversation history. - -```shell -curl http://localhost:11434/api/generate -d '{ - "model": "llama2", - "messages": [ - { - "role": "user", - "content": "why is the sky blue?" - }, - { - "role": "assistant", - "content": "due to rayleigh scattering." - }, - { - "role": "user", - "content": "how is that different than mie scattering?" 
- } - ] -}' -``` - -#### Response - -A stream of JSON objects is returned: - -```json -{ - "model": "llama2", - "created_at": "2023-08-04T08:52:19.385406455-07:00", - "message": { - "role": "assisant", - "content": "The" - }, - "done": false -} -``` - -Final response: - -```json -{ - "model": "llama2", - "created_at": "2023-08-04T19:22:45.499127Z", - "done": true, - "total_duration": 5589157167, - "load_duration": 3013701500, - "sample_count": 114, - "sample_duration": 81442000, - "prompt_eval_count": 46, - "prompt_eval_duration": 1160282000, - "eval_count": 113, - "eval_duration": 1325948000 -} -``` - ## Create a Model ```shell diff --git a/llm/llama.go b/llm/llama.go index 3cce7fef..fc033258 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -531,31 +531,21 @@ type prediction struct { const maxBufferSize = 512 * format.KiloByte -type PredictRequest struct { - Model string - Prompt string - Format string - CheckpointStart time.Time - CheckpointLoaded time.Time -} +func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error { + prevConvo, err := llm.Decode(ctx, prevContext) + if err != nil { + return err + } -type PredictResponse struct { - Model string - CreatedAt time.Time - TotalDuration time.Duration - LoadDuration time.Duration - Content string - Done bool - PromptEvalCount int - PromptEvalDuration time.Duration - EvalCount int - EvalDuration time.Duration - Context []int -} + // Remove leading spaces from prevConvo if present + prevConvo = strings.TrimPrefix(prevConvo, " ") + + var nextContext strings.Builder + nextContext.WriteString(prevConvo) + nextContext.WriteString(prompt) -func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(PredictResponse)) error { request := map[string]any{ - "prompt": predict.Prompt, + "prompt": nextContext.String(), "stream": true, "n_predict": llm.NumPredict, "n_keep": llm.NumKeep, @@ -577,7 +567,7 @@ func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(P "stop": llm.Stop, } - if predict.Format == "json" { + if format == "json" { request["grammar"] = jsonGrammar } @@ -634,25 +624,25 @@ func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(P } if p.Content != "" { - fn(PredictResponse{ - Model: predict.Model, - CreatedAt: time.Now().UTC(), - Content: p.Content, - }) + fn(api.GenerateResponse{Response: p.Content}) + nextContext.WriteString(p.Content) } if p.Stop { - fn(PredictResponse{ - Model: predict.Model, - CreatedAt: time.Now().UTC(), - TotalDuration: time.Since(predict.CheckpointStart), + embd, err := llm.Encode(ctx, nextContext.String()) + if err != nil { + return fmt.Errorf("encoding context: %v", err) + } + fn(api.GenerateResponse{ Done: true, + Context: embd, PromptEvalCount: p.Timings.PromptN, PromptEvalDuration: parseDurationMs(p.Timings.PromptMS), EvalCount: p.Timings.PredictedN, EvalDuration: parseDurationMs(p.Timings.PredictedMS), }) + return nil } } diff --git a/llm/llm.go b/llm/llm.go index 703ea012..4901d9fe 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -14,7 +14,7 @@ import ( ) type LLM interface { - Predict(context.Context, PredictRequest, func(PredictResponse)) error + Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error Embedding(context.Context, string) ([]float64, error) Encode(context.Context, string) ([]int, error) Decode(context.Context, []int) (string, error) diff --git a/server/images.go b/server/images.go index efc5e8bc..294fdf2b 100644 --- 
a/server/images.go +++ b/server/images.go @@ -47,82 +47,37 @@ type Model struct { Options map[string]interface{} } -type PromptVars struct { - System string - Prompt string - Response string -} +func (m *Model) Prompt(request api.GenerateRequest) (string, error) { + t := m.Template + if request.Template != "" { + t = request.Template + } -func (m *Model) Prompt(p PromptVars) (string, error) { - var prompt strings.Builder - tmpl, err := template.New("").Parse(m.Template) + tmpl, err := template.New("").Parse(t) if err != nil { return "", err } - if p.System == "" { - // use the default system prompt for this model if one is not specified - p.System = m.System + var vars struct { + First bool + System string + Prompt string + } + + vars.First = len(request.Context) == 0 + vars.System = m.System + vars.Prompt = request.Prompt + + if request.System != "" { + vars.System = request.System } var sb strings.Builder - if err := tmpl.Execute(&sb, p); err != nil { + if err := tmpl.Execute(&sb, vars); err != nil { return "", err } - prompt.WriteString(sb.String()) - prompt.WriteString(p.Response) - return prompt.String(), nil -} -func (m *Model) ChatPrompt(msgs []api.Message) (string, error) { - // build the prompt from the list of messages - var prompt strings.Builder - currentVars := PromptVars{} - - writePrompt := func() error { - p, err := m.Prompt(currentVars) - if err != nil { - return err - } - prompt.WriteString(p) - currentVars = PromptVars{} - return nil - } - - for _, msg := range msgs { - switch msg.Role { - case "system": - if currentVars.Prompt != "" || currentVars.System != "" { - if err := writePrompt(); err != nil { - return "", err - } - } - currentVars.System = msg.Content - case "user": - if currentVars.Prompt != "" || currentVars.System != "" { - if err := writePrompt(); err != nil { - return "", err - } - } - currentVars.Prompt = msg.Content - case "assistant": - currentVars.Response = msg.Content - if err := writePrompt(); err != nil { - return "", err - } - default: - return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) - } - } - - // Append the last set of vars if they are non-empty - if currentVars.Prompt != "" || currentVars.System != "" { - if err := writePrompt(); err != nil { - return "", err - } - } - - return prompt.String(), nil + return sb.String(), nil } type ManifestV2 struct { diff --git a/server/images_test.go b/server/images_test.go index 85e8d4bd..5e6a197b 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -2,15 +2,17 @@ package server import ( "testing" + + "github.com/jmorganca/ollama/api" ) func TestModelPrompt(t *testing.T) { - m := Model{ + var m Model + req := api.GenerateRequest{ Template: "a{{ .Prompt }}b", + Prompt: "

", } - s, err := m.Prompt(PromptVars{ - Prompt: "

", - }) + s, err := m.Prompt(req) if err != nil { t.Fatal(err) } diff --git a/server/routes.go b/server/routes.go index 385af66a..bc8ea804 100644 --- a/server/routes.go +++ b/server/routes.go @@ -60,26 +60,17 @@ var loaded struct { var defaultSessionDuration = 5 * time.Minute // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function -func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) { - model, err := GetModel(modelName) - if err != nil { - return nil, err - } - - workDir := c.GetString("workDir") - +func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error { opts := api.DefaultOptions() if err := opts.FromMap(model.Options); err != nil { log.Printf("could not load model options: %v", err) - return nil, err + return err } if err := opts.FromMap(reqOpts); err != nil { - return nil, err + return err } - ctx := c.Request.Context() - // check if the loaded model is still running in a subprocess, in case something unexpected happened if loaded.runner != nil { if err := loaded.runner.Ping(ctx); err != nil { @@ -115,7 +106,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName) } - return nil, err + return err } loaded.Model = model @@ -149,7 +140,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess } loaded.expireTimer.Reset(sessionDuration) - return model, nil + return nil } func GenerateHandler(c *gin.Context) { @@ -182,262 +173,88 @@ func GenerateHandler(c *gin.Context) { return } - sessionDuration := defaultSessionDuration - model, err := load(c, req.Model, req.Options, sessionDuration) + model, err := GetModel(req.Model) if err != nil { var pErr *fs.PathError - switch { - case errors.As(err, &pErr): + if errors.As(err, &pErr) { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) - case errors.Is(err, api.ErrInvalidOpts): - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - default: - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - return - } - - // an empty request loads the model - if req.Prompt == "" && req.Template == "" && req.System == "" { - c.JSON(http.StatusOK, api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}) - return - } - - checkpointLoaded := time.Now() - - var prompt string - sendContext := false - switch { - case req.Raw: - prompt = req.Prompt - case req.Prompt != "": - if req.Template != "" { - // override the default model template - model.Template = req.Template - } - - var rebuild strings.Builder - if req.Context != nil { - // TODO: context is deprecated, at some point the context logic within this conditional should be removed - prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // Remove leading spaces from prevCtx if present - prevCtx = strings.TrimPrefix(prevCtx, " ") - rebuild.WriteString(prevCtx) - } - p, err := model.Prompt(PromptVars{ - System: req.System, - Prompt: req.Prompt, - }) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": 
err.Error()}) return } - rebuild.WriteString(p) - prompt = rebuild.String() - sendContext = true - } - - ch := make(chan any) - var generated strings.Builder - go func() { - defer close(ch) - - fn := func(r llm.PredictResponse) { - // Update model expiration - loaded.expireAt = time.Now().Add(sessionDuration) - loaded.expireTimer.Reset(sessionDuration) - - // Build up the full response - if _, err := generated.WriteString(r.Content); err != nil { - ch <- gin.H{"error": err.Error()} - return - } - - resp := api.GenerateResponse{ - Model: r.Model, - CreatedAt: r.CreatedAt, - Done: r.Done, - Response: r.Content, - EvalMetrics: api.EvalMetrics{ - TotalDuration: r.TotalDuration, - LoadDuration: r.LoadDuration, - PromptEvalCount: r.PromptEvalCount, - PromptEvalDuration: r.PromptEvalDuration, - EvalCount: r.EvalCount, - EvalDuration: r.EvalDuration, - }, - } - - if r.Done && sendContext { - embd, err := loaded.runner.Encode(c.Request.Context(), req.Prompt+generated.String()) - if err != nil { - ch <- gin.H{"error": err.Error()} - return - } - r.Context = embd - } - - ch <- resp - } - - // Start prediction - predictReq := llm.PredictRequest{ - Model: model.Name, - Prompt: prompt, - Format: req.Format, - CheckpointStart: checkpointStart, - CheckpointLoaded: checkpointLoaded, - } - if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { - ch <- gin.H{"error": err.Error()} - } - }() - - if req.Stream != nil && !*req.Stream { - // Wait for the channel to close - var r api.GenerateResponse - var sb strings.Builder - for resp := range ch { - var ok bool - if r, ok = resp.(api.GenerateResponse); !ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - sb.WriteString(r.Response) - } - r.Response = sb.String() - c.JSON(http.StatusOK, r) - return - } - - streamResponse(c, ch) -} - -func ChatHandler(c *gin.Context) { - loaded.mu.Lock() - defer loaded.mu.Unlock() - - checkpointStart := time.Now() - - var req api.ChatRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) - return - case err != nil: - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // validate the request - switch { - case req.Model == "": - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - case len(req.Format) > 0 && req.Format != "json": - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"}) - return - } - - sessionDuration := defaultSessionDuration - model, err := load(c, req.Model, req.Options, sessionDuration) - if err != nil { - var pErr *fs.PathError - switch { - case errors.As(err, &pErr): - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) - case errors.Is(err, api.ErrInvalidOpts): - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - default: - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - return - } - - // an empty request loads the model - if len(req.Messages) == 0 { - c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}) - return - } - - checkpointLoaded := time.Now() - - if req.Template != "" { - // override the default model template - model.Template = req.Template - } - prompt, err := model.ChatPrompt(req.Messages) - if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } 
- ch := make(chan any) + workDir := c.GetString("workDir") + // TODO: set this duration from the request if specified + sessionDuration := defaultSessionDuration + if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil { + if errors.Is(err, api.ErrInvalidOpts) { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + checkpointLoaded := time.Now() + + prompt := req.Prompt + if !req.Raw { + prompt, err = model.Prompt(req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + ch := make(chan any) go func() { defer close(ch) + // an empty request loads the model + if req.Prompt == "" && req.Template == "" && req.System == "" { + ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true} + return + } - fn := func(r llm.PredictResponse) { - // Update model expiration + fn := func(r api.GenerateResponse) { loaded.expireAt = time.Now().Add(sessionDuration) loaded.expireTimer.Reset(sessionDuration) - resp := api.ChatResponse{ - Model: r.Model, - CreatedAt: r.CreatedAt, - Done: r.Done, - EvalMetrics: api.EvalMetrics{ - TotalDuration: r.TotalDuration, - LoadDuration: r.LoadDuration, - PromptEvalCount: r.PromptEvalCount, - PromptEvalDuration: r.PromptEvalDuration, - EvalCount: r.EvalCount, - EvalDuration: r.EvalDuration, - }, + r.Model = req.Model + r.CreatedAt = time.Now().UTC() + if r.Done { + r.TotalDuration = time.Since(checkpointStart) + r.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - if !r.Done { - resp.Message = &api.Message{Role: "assistant", Content: r.Content} + if req.Raw { + // in raw mode the client must manage history on their own + r.Context = nil } - ch <- resp + ch <- r } - // Start prediction - predictReq := llm.PredictRequest{ - Model: model.Name, - Prompt: prompt, - Format: req.Format, - CheckpointStart: checkpointStart, - CheckpointLoaded: checkpointLoaded, - } - if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { + if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() if req.Stream != nil && !*req.Stream { - // Wait for the channel to close - var r api.ChatResponse - var sb strings.Builder + var response api.GenerateResponse + generated := "" for resp := range ch { - var ok bool - if r, ok = resp.(api.ChatResponse); !ok { + if r, ok := resp.(api.GenerateResponse); ok { + generated += r.Response + response = r + } else { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - if r.Message != nil { - sb.WriteString(r.Message.Content) - } } - r.Message = &api.Message{Role: "assistant", Content: sb.String()} - c.JSON(http.StatusOK, r) + response.Response = generated + c.JSON(http.StatusOK, response) return } @@ -464,18 +281,15 @@ func EmbeddingHandler(c *gin.Context) { return } - sessionDuration := defaultSessionDuration - _, err = load(c, req.Model, req.Options, sessionDuration) + model, err := GetModel(req.Model) if err != nil { - var pErr *fs.PathError - switch { - case errors.As(err, &pErr): - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) - case errors.Is(err, api.ErrInvalidOpts): - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - default: - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } + 
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + workDir := c.GetString("workDir") + if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -953,7 +767,6 @@ func Serve(ln net.Listener, allowOrigins []string) error { r.POST("/api/pull", PullModelHandler) r.POST("/api/generate", GenerateHandler) - r.POST("/api/chat", ChatHandler) r.POST("/api/embeddings", EmbeddingHandler) r.POST("/api/create", CreateModelHandler) r.POST("/api/push", PushModelHandler)
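
For reference, below is a minimal, illustrative sketch of the `context`-based conversational flow that this revert restores, as described in the `docs/api.md` changes above: the `context` array returned by `/api/generate` is passed back on the next request to keep a short conversational memory. It assumes an Ollama server running at `http://localhost:11434` with the `llama2` model pulled; the `generateRequest`/`generateResponse` structs and the `generate` helper are illustrative names, not part of the codebase.

```go
// Sketch of chaining two /api/generate calls via the returned context field.
// Assumes: Ollama server on localhost:11434, llama2 pulled. Helper names are
// illustrative only; error handling is kept minimal for brevity.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

type generateRequest struct {
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Context []int  `json:"context,omitempty"` // context from a previous response
	Stream  bool   `json:"stream"`            // false: return a single JSON object
}

type generateResponse struct {
	Response string `json:"response"`
	Context  []int  `json:"context,omitempty"`
	Done     bool   `json:"done"`
}

func generate(req generateRequest) (generateResponse, error) {
	var resp generateResponse

	body, err := json.Marshal(req)
	if err != nil {
		return resp, err
	}

	httpResp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		return resp, err
	}
	defer httpResp.Body.Close()

	// With "stream": false the server replies with one JSON object.
	err = json.NewDecoder(httpResp.Body).Decode(&resp)
	return resp, err
}

func main() {
	first, err := generate(generateRequest{Model: "llama2", Prompt: "Why is the sky blue?"})
	if err != nil {
		panic(err)
	}
	fmt.Println(first.Response)

	// Pass the returned context back in to continue the same conversation.
	second, err := generate(generateRequest{
		Model:   "llama2",
		Prompt:  "How is that different from Mie scattering?",
		Context: first.Context,
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(second.Response)
}
```

The same flow works with streaming left at its default: each call then returns newline-delimited JSON objects instead of a single object, with `context` arriving on the final object where `done` is `true`, per the docs above.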