diff --git a/api/types.go b/api/types.go
index 2a36a1f6..1940ba40 100644
--- a/api/types.go
+++ b/api/types.go
@@ -31,18 +31,24 @@ func (e StatusError) Error() string {
 }
 
 type GenerateRequest struct {
-	Model    string `json:"model"`
-	Prompt   string `json:"prompt"`
-	System   string `json:"system"`
-	Template string `json:"template"`
-	Context  []int  `json:"context,omitempty"`
-	Stream   *bool  `json:"stream,omitempty"`
-	Raw      bool   `json:"raw,omitempty"`
-	Format   string `json:"format"`
+	Model    string    `json:"model"`
+	Prompt   string    `json:"prompt"`
+	System   string    `json:"system"`
+	Template string    `json:"template"`
+	Context  []int     `json:"context,omitempty"` // DEPRECATED: context is deprecated, use messages instead
+	Messages []Message `json:"messages,omitempty"` // messages sent in the conversation so far
+	Stream   *bool     `json:"stream,omitempty"`
+	Raw      bool      `json:"raw,omitempty"`
+	Format   string    `json:"format"`
 
 	Options map[string]interface{} `json:"options"`
 }
 
+type Message struct {
+	Prompt   string `json:"prompt"`
+	Response string `json:"response"`
+}
+
 // Options specfied in GenerateRequest, if you add a new option here add it to the API docs also
 type Options struct {
 	Runner
@@ -87,6 +93,22 @@ type Runner struct {
 	NumThread int `json:"num_thread,omitempty"`
 }
 
+type GenerateResponse struct {
+	Model     string    `json:"model"`
+	CreatedAt time.Time `json:"created_at"`
+	Response  string    `json:"response"`
+
+	Done    bool  `json:"done"`
+	Context []int `json:"context,omitempty"`
+
+	TotalDuration      time.Duration `json:"total_duration,omitempty"`
+	LoadDuration       time.Duration `json:"load_duration,omitempty"`
+	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
+	EvalCount          int           `json:"eval_count,omitempty"`
+	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
+}
+
 type EmbeddingRequest struct {
 	Model  string `json:"model"`
 	Prompt string `json:"prompt"`
@@ -164,22 +186,6 @@ type TokenResponse struct {
 	Token string `json:"token"`
 }
 
-type GenerateResponse struct {
-	Model     string    `json:"model"`
-	CreatedAt time.Time `json:"created_at"`
-	Response  string    `json:"response"`
-
-	Done    bool  `json:"done"`
-	Context []int `json:"context,omitempty"`
-
-	TotalDuration      time.Duration `json:"total_duration,omitempty"`
-	LoadDuration       time.Duration `json:"load_duration,omitempty"`
-	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
-	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
-	EvalCount          int           `json:"eval_count,omitempty"`
-	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
-}
-
 func (r *GenerateResponse) Summary() {
 	if r.TotalDuration > 0 {
 		fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
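
Not part of the patch: a minimal client-side sketch of how the new `Messages` field is meant to be filled in, using the same `api.ClientFromEnvironment` and `client.Generate` calls that `cmd/cmd.go` relies on below. The model name and message contents are placeholder values.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := api.GenerateRequest{
		Model:  "mistral",              // placeholder model name
		Prompt: "what did I just ask?", // the newest prompt in the conversation
		// prior turns are sent as prompt/response pairs instead of a context slice
		Messages: []api.Message{
			{Prompt: "why is the sky blue?", Response: "Rayleigh scattering."},
		},
	}

	// each streamed chunk arrives through the callback as a partial GenerateResponse
	err = client.Generate(context.Background(), &req, func(r api.GenerateResponse) error {
		fmt.Print(r.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```
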
diff --git a/cmd/cmd.go b/cmd/cmd.go
index 2c48ca80..42136cef 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -427,7 +427,11 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 
 	// output is being piped
 	if !term.IsTerminal(int(os.Stdout.Fd())) {
-		return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
+		_, err := generate(cmd, args[0], strings.Join(prompts, " "), nil, false, format)
+		if err != nil {
+			return err
+		}
+		return nil
 	}
 
 	wordWrap := os.Getenv("TERM") == "xterm-256color"
@@ -442,18 +446,20 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 
 	// prompts are provided via stdin or args so don't enter interactive mode
 	if len(prompts) > 0 {
-		return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
+		_, err := generate(cmd, args[0], strings.Join(prompts, " "), nil, wordWrap, format)
+		if err != nil {
+			return err
+		}
+		return nil
 	}
 
 	return generateInteractive(cmd, args[0], wordWrap, format)
 }
 
-type generateContextKey string
-
-func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error {
+func generate(cmd *cobra.Command, model, prompt string, messages []api.Message, wordWrap bool, format string) (*api.Message, error) {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	p := progress.NewProgress(os.Stderr)
@@ -464,11 +470,6 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 
 	var latest api.GenerateResponse
 
-	generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
-	if !ok {
-		generateContext = []int{}
-	}
-
 	termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
 	if err != nil {
 		wordWrap = false
@@ -490,14 +491,16 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 	var currentLineLength int
 	var wordBuffer string
 
-	request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format}
-	fn := func(response api.GenerateResponse) error {
+	var fullResponse strings.Builder
+	request := api.GenerateRequest{Model: model, Prompt: prompt, Messages: messages, Format: format}
+	fn := func(generated api.GenerateResponse) error {
 		p.StopAndClear()
 
-		latest = response
+		latest = generated
+		fullResponse.WriteString(generated.Response)
 
 		if wordWrap {
-			for _, ch := range response.Response {
+			for _, ch := range generated.Response {
 				if currentLineLength+1 > termWidth-5 {
 					// backtrack the length of the last word and clear to the end of the line
 					fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
@@ -518,7 +521,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 				}
 			}
 		} else {
-			fmt.Print(response.Response)
+			fmt.Print(generated.Response)
 		}
 
 		return nil
@@ -526,9 +529,12 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 
 	if err := client.Generate(cancelCtx, &request, fn); err != nil {
 		if strings.Contains(err.Error(), "context canceled") && abort {
-			return nil
+			return &api.Message{
+				Prompt:   prompt,
+				Response: fullResponse.String(),
+			}, nil
 		}
-		return err
+		return nil, err
 	}
 	if prompt != "" {
 		fmt.Println()
 	}
@@ -537,30 +543,32 @@
 
 	if !latest.Done {
 		if abort {
-			return nil
+			return &api.Message{
+				Prompt:   prompt,
+				Response: fullResponse.String(),
+			}, nil
 		}
-		return errors.New("unexpected end of response")
+		return nil, errors.New("unexpected end of response")
 	}
 
 	verbose, err := cmd.Flags().GetBool("verbose")
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	if verbose {
 		latest.Summary()
 	}
 
-	ctx := cmd.Context()
-	ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
-	cmd.SetContext(ctx)
-
-	return nil
+	return &api.Message{
+		Prompt:   prompt,
+		Response: fullResponse.String(),
+	}, nil
 }
 
 func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
 	// load the model
-	if err := generate(cmd, model, "", false, ""); err != nil {
+	if _, err := generate(cmd, model, "", nil, false, ""); err != nil {
 		return err
 	}
 
@@ -614,6 +622,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
 	defer fmt.Printf(readline.EndBracketedPaste)
 
 	var multiLineBuffer string
+	messages := make([]api.Message, 0)
 
 	for {
 		line, err := scanner.Readline()
@@ -756,9 +765,11 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
 		}
 
 		if len(line) > 0 && line[0] != '/' {
-			if err := generate(cmd, model, line, wordWrap, format); err != nil {
+			message, err := generate(cmd, model, line, messages, wordWrap, format)
+			if err != nil {
 				return err
 			}
+			messages = append(messages, *message)
 		}
 	}
 }
diff --git a/docs/api.md b/docs/api.md
index 99378ac3..cf6bf86a 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -45,9 +45,13 @@ Advanced parameters (optional):
 - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
 - `system`: system prompt to (overrides what is defined in the `Modelfile`)
 - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
-- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
+- `messages`: the messages of the conversation so far; this can be used to keep a conversational memory
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
-- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
+- `raw`: if `true` no formatting will be applied to the prompt, and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing messages yourself.
+
+Deprecated parameters (optional):
+
+- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
 
 ### JSON mode
 
@@ -89,8 +93,8 @@ The final response in the stream also includes additional data about the generat
 - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
 - `eval_count`: number of tokens the response
 - `eval_duration`: time in nanoseconds spent generating the response
-- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
 - `response`: empty if the response was streamed, if not streamed, this will contain the full response
+- `context`: returned only if no `messages` were specified; an encoding of the conversation used in this response. This field is deprecated and will be removed in a future version
 
 To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.
 
@@ -146,6 +150,41 @@ If `stream` is set to `false`, the response will be a single JSON object:
 
 #### Request (Raw mode)
 
+To continue a conversation, you can provide a `messages` parameter with the conversation so far. This is a list of prompts and responses.
+
+```shell
+curl -X POST http://localhost:11434/api/generate -d '{
+  "model": "mistral",
+  "prompt": "what did I just ask?",
+  "messages": [
+    {
+      "prompt": "why is the sky blue?",
+      "response": "The sky appears blue because of a phenomenon called Rayleigh scattering."
+    }
+  ],
+  "stream": false
+}'
+```
+
+#### Response
+
+```json
+{
+  "model": "mistral",
+  "created_at": "2023-11-03T21:56:04.806917Z",
+  "response": "You asked for an explanation of why the sky is blue.",
+  "done": true,
+  "total_duration": 5211750166,
+  "load_duration": 3714731708,
+  "prompt_eval_count": 44,
+  "prompt_eval_duration": 532827000,
+  "eval_count": 12,
+  "eval_duration": 938680000
+}
+```
+
+#### Request
+
 In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
 
 ```shell
diff --git a/llm/llama.go b/llm/llama.go
index 4eab751d..9b9f3b03 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -527,21 +527,9 @@ type prediction struct {
 
 const maxBufferSize = 512 * format.KiloByte
 
-func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
-	prevConvo, err := llm.Decode(ctx, prevContext)
-	if err != nil {
-		return err
-	}
-
-	// Remove leading spaces from prevConvo if present
-	prevConvo = strings.TrimPrefix(prevConvo, " ")
-
-	var nextContext strings.Builder
-	nextContext.WriteString(prevConvo)
-	nextContext.WriteString(prompt)
-
+func (llm *llama) Predict(ctx context.Context, prompt string, format string, fn func(api.GenerateResponse)) error {
 	request := map[string]any{
-		"prompt":    nextContext.String(),
+		"prompt":    prompt,
 		"stream":    true,
 		"n_predict": llm.NumPredict,
 		"n_keep":    llm.NumKeep,
@@ -621,18 +609,12 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			if p.Content != "" {
 				fn(api.GenerateResponse{Response: p.Content})
-				nextContext.WriteString(p.Content)
 			}
 
 			if p.Stop {
-				embd, err := llm.Encode(ctx, nextContext.String())
-				if err != nil {
-					return fmt.Errorf("encoding context: %v", err)
-				}
 				fn(api.GenerateResponse{
 					Done:               true,
-					Context:            embd,
 					PromptEvalCount:    p.Timings.PromptN,
 					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
 					EvalCount:          p.Timings.PredictedN,
diff --git a/llm/llm.go b/llm/llm.go
index 22706da5..4303d26c 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -14,7 +14,7 @@ import (
 )
 
 type LLM interface {
-	Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
+	Predict(context.Context, string, string, func(api.GenerateResponse)) error
 	Embedding(context.Context, string) ([]float64, error)
 	Encode(context.Context, string) ([]int, error)
 	Decode(context.Context, []int) (string, error)
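
A hedged illustration of the narrowed `Predict` contract (the `fakeLLM` type is hypothetical and not part of the patch; the remaining `LLM` methods are omitted): the runner now receives only the fully assembled prompt string, so it no longer decodes a `prevContext` slice or re-encodes the conversation itself.

```go
// fakeLLM is a hypothetical stand-in used only to show the new Predict shape.
type fakeLLM struct{}

func (fakeLLM) Predict(ctx context.Context, prompt, format string, fn func(api.GenerateResponse)) error {
	// the caller (server/routes.go) assembles the whole prompt, including history,
	// before calling Predict; nothing is decoded or re-encoded here anymore
	fn(api.GenerateResponse{Response: "echo: " + prompt})
	fn(api.GenerateResponse{Done: true})
	return nil
}
```

Compare with the removed code above, which decoded `prevContext`, concatenated the prompt, and encoded the running conversation into `Context` on the final response.
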
diff --git a/server/images.go b/server/images.go
index 8d2af15b..5aa051e0 100644
--- a/server/images.go
+++ b/server/images.go
@@ -48,10 +48,17 @@ type Model struct {
 	Options       map[string]interface{}
 }
 
-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
+type PromptVars struct {
+	First  bool
+	System string
+	Prompt string
+}
+
+func (m *Model) Prompt(vars PromptVars, reqTemplate string) (string, error) {
 	t := m.Template
-	if request.Template != "" {
-		t = request.Template
+	if reqTemplate != "" {
+		// override the model template if one is specified
+		t = reqTemplate
 	}
 
 	tmpl, err := template.New("").Parse(t)
@@ -59,18 +66,9 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 		return "", err
 	}
 
-	var vars struct {
-		First  bool
-		System string
-		Prompt string
-	}
-
-	vars.First = len(request.Context) == 0
-	vars.System = m.System
-	vars.Prompt = request.Prompt
-
-	if request.System != "" {
-		vars.System = request.System
+	if vars.System == "" {
+		// use the default system prompt for this model if one is not specified
+		vars.System = m.System
 	}
 
 	var sb strings.Builder
diff --git a/server/images_test.go b/server/images_test.go
index 5e6a197b..d700bf0b 100644
--- a/server/images_test.go
+++ b/server/images_test.go
@@ -2,17 +2,14 @@ package server
 
 import (
 	"testing"
-
-	"github.com/jmorganca/ollama/api"
 )
 
 func TestModelPrompt(t *testing.T) {
 	var m Model
-	req := api.GenerateRequest{
-		Template: "a{{ .Prompt }}b",
-		Prompt:   "<h1>",
-	}
-	s, err := m.Prompt(req)
+	s, err := m.Prompt(PromptVars{
+		First:  true,
+		Prompt: "<h1>",
+	}, "a{{ .Prompt }}b")
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/server/routes.go b/server/routes.go
index 8a5a5a24..617eaa5e 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -161,6 +161,8 @@ func GenerateHandler(c *gin.Context) {
 	}
 
 	// validate the request
+	isContextSet := req.Context != nil && len(req.Context) > 0
+	areMessagesSet := req.Messages != nil && len(req.Messages) > 0
 	switch {
 	case req.Model == "":
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
@@ -168,9 +170,12 @@ func GenerateHandler(c *gin.Context) {
 	case len(req.Format) > 0 && req.Format != "json":
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
 		return
-	case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
+	case req.Raw && (req.Template != "" || req.System != "" || isContextSet || areMessagesSet):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, context, or messages"})
 		return
+	case areMessagesSet && isContextSet:
+		// this makes rebuilding the prompt history too complicated, so don't allow it
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "only one of messages or context may be specified"})
 	}
 
 	model, err := GetModel(req.Model)
@@ -199,20 +204,65 @@ func GenerateHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
-	prompt := req.Prompt
-	if !req.Raw {
-		prompt, err = model.Prompt(req)
+	var prompt strings.Builder
+	if req.Context != nil {
+		// TODO: context is deprecated, at some point the context logic within this conditional should be removed
+		// if the request has a context rather than messages, decode it and add it to the prompt
+		prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}
+
+		// Remove leading spaces from prevCtx if present
+		prevCtx = strings.TrimPrefix(prevCtx, " ")
+		prompt.WriteString(prevCtx)
+	}
+	// build the prompt history from messages
+	for i, m := range req.Messages {
+		// apply the template to the prompt
+		p, err := model.Prompt(PromptVars{
+			First:  i == 0,
+			Prompt: m.Prompt,
+			System: req.System,
+		}, req.Template)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		prompt.WriteString(p)
+		prompt.WriteString(m.Response)
 	}
+	// finally, add the current prompt as the most recent message
+	first := !isContextSet && !areMessagesSet
+	if req.Raw {
+		prompt.WriteString(req.Prompt)
+	} else if strings.TrimSpace(req.Prompt) != "" {
+		// template the request prompt before adding it
+		p, err := model.Prompt(PromptVars{
+			First:  first,
+			System: req.System,
+			Prompt: req.Prompt,
+		}, req.Template)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		prompt.WriteString(p)
+	}
+
+	sendContext := first || isContextSet
+	var respCtx strings.Builder
+	if _, err := respCtx.WriteString(prompt.String()); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
 
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
 		// an empty request loads the model
-		if req.Prompt == "" && req.Template == "" && req.System == "" {
+		if req.Prompt == "" && req.Template == "" && req.System == "" && !areMessagesSet {
 			ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
 			return
 		}
@@ -223,9 +273,26 @@ func GenerateHandler(c *gin.Context) {
 			r.Model = req.Model
 			r.CreatedAt = time.Now().UTC()
 
+			// if the final response expects a context, build the context as we go
+			if sendContext {
+				if _, err := respCtx.WriteString(r.Response); err != nil {
+					ch <- gin.H{"error": err.Error()}
+					return
+				}
+			}
+
 			if r.Done {
 				r.TotalDuration = time.Since(checkpointStart)
 				r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+				// if the response expects a context, encode it and send it back
+				if sendContext {
+					embd, err := loaded.runner.Encode(c.Request.Context(), respCtx.String())
+					if err != nil {
+						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+						return
+					}
+					r.Context = embd
+				}
 			}
 
 			if req.Raw {
@@ -236,7 +303,7 @@ func GenerateHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
+		if err := loaded.runner.Predict(c.Request.Context(), prompt.String(), req.Format, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
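
To make the handler's prompt assembly concrete, here is a self-contained sketch (not part of the patch) of the loop in `GenerateHandler`: each historical message's prompt is templated and its stored response appended verbatim, then the current request prompt is templated last. The `[INST]`-style template is a made-up example; real templates come from the model's Modelfile.

```go
package main

import (
	"fmt"
	"strings"
	"text/template"
)

type promptVars struct {
	First  bool
	System string
	Prompt string
}

type message struct{ Prompt, Response string }

func render(tmpl string, v promptVars) string {
	var sb strings.Builder
	if err := template.Must(template.New("").Parse(tmpl)).Execute(&sb, v); err != nil {
		panic(err)
	}
	return sb.String()
}

func main() {
	tmpl := "[INST] {{ .Prompt }} [/INST]" // hypothetical template

	history := []message{
		{Prompt: "why is the sky blue?", Response: " Rayleigh scattering."},
	}

	var prompt strings.Builder
	for i, m := range history {
		// mirror the handler: template the prompt, then append the stored response
		prompt.WriteString(render(tmpl, promptVars{First: i == 0, Prompt: m.Prompt}))
		prompt.WriteString(m.Response)
	}
	// finally, the current request prompt is templated and appended
	prompt.WriteString(render(tmpl, promptVars{Prompt: "what did I just ask?"}))

	fmt.Println(prompt.String())
	// prints:
	// [INST] why is the sky blue? [/INST] Rayleigh scattering.[INST] what did I just ask? [/INST]
}
```
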