add messages to /generate api

- deprecate generation context, but continue to support it
- on first request generation context will still be returned
- if messages are specified context is not returned
- rebuild generation context from prompt/reply messages
- update generate docs with messages parameter
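A rough, illustrative sketch of the new flow (not part of this commit's diff): a client can pass earlier prompt/response pairs through `Messages` instead of threading the deprecated `Context` slice. This assumes the `api` package from this repo, a running server, and a pulled `mistral` model.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "mistral",
		Prompt: "what did I just ask?",
		// prior exchanges are sent as messages; no context slice is needed
		Messages: []api.Message{
			{
				Prompt:   "why is the sky blue?",
				Response: "The sky appears blue because of a phenomenon called Rayleigh scattering.",
			},
		},
	}

	// stream the reply to stdout as it is generated
	if err := client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		fmt.Print(r.Response)
		return nil
	}); err != nil {
		log.Fatal(err)
	}
}
```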
Bruce MacDonald 2023-11-03 17:50:21 -04:00
parent 6066c70edd
commit 9c21d23a35
8 changed files with 205 additions and 105 deletions

View File

@@ -31,18 +31,24 @@ func (e StatusError) Error() string {
}
type GenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"` // DEPRECATED: context is deprecated, use messages instead
Messages []Message `json:"messages,omitempty"` // messages sent in the conversation so far
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Options map[string]interface{} `json:"options"`
}
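// Message is a single prompt/response exchange from earlier in the conversation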
type Message struct {
Prompt string `json:"prompt"`
Response string `json:"response"`
}
// Options specified in GenerateRequest; if you add a new option here, add it to the API docs also
type Options struct {
Runner
@@ -87,6 +93,22 @@ type Runner struct {
NumThread int `json:"num_thread,omitempty"`
}
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
Context []int `json:"context,omitempty"`
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
@@ -164,22 +186,6 @@ type TokenResponse struct {
Token string `json:"token"`
}
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
Context []int `json:"context,omitempty"`
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
func (r *GenerateResponse) Summary() {
if r.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)

View File

@@ -427,7 +427,11 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
// output is being piped
if !term.IsTerminal(int(os.Stdout.Fd())) {
return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
_, err := generate(cmd, args[0], strings.Join(prompts, " "), nil, false, format)
if err != nil {
return err
}
return nil
}
wordWrap := os.Getenv("TERM") == "xterm-256color"
@@ -442,18 +446,20 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
// prompts are provided via stdin or args so don't enter interactive mode
if len(prompts) > 0 {
return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
_, err := generate(cmd, args[0], strings.Join(prompts, " "), nil, wordWrap, format)
if err != nil {
return err
}
return nil
}
return generateInteractive(cmd, args[0], wordWrap, format)
}
type generateContextKey string
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error {
func generate(cmd *cobra.Command, model, prompt string, messages []api.Message, wordWrap bool, format string) (*api.Message, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
return nil, err
}
p := progress.NewProgress(os.Stderr)
@@ -464,11 +470,6 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
var latest api.GenerateResponse
generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
if !ok {
generateContext = []int{}
}
termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
wordWrap = false
@@ -490,14 +491,16 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
var currentLineLength int
var wordBuffer string
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format}
fn := func(response api.GenerateResponse) error {
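// accumulate the streamed chunks so the full reply can be returned as an api.Message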
var fullResponse strings.Builder
request := api.GenerateRequest{Model: model, Prompt: prompt, Messages: messages, Format: format}
fn := func(generated api.GenerateResponse) error {
p.StopAndClear()
latest = response
latest = generated
fullResponse.WriteString(generated.Response)
if wordWrap {
for _, ch := range response.Response {
for _, ch := range generated.Response {
if currentLineLength+1 > termWidth-5 {
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
@@ -518,7 +521,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
}
}
} else {
fmt.Print(response.Response)
fmt.Print(generated.Response)
}
return nil
@@ -526,9 +529,12 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
if err := client.Generate(cancelCtx, &request, fn); err != nil {
if strings.Contains(err.Error(), "context canceled") && abort {
return nil
return &api.Message{
Prompt: prompt,
Response: fullResponse.String(),
}, nil
}
return err
return nil, err
}
if prompt != "" {
fmt.Println()
@@ -537,30 +543,32 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
if !latest.Done {
if abort {
return nil
return &api.Message{
Prompt: prompt,
Response: fullResponse.String(),
}, nil
}
return errors.New("unexpected end of response")
return nil, errors.New("unexpected end of response")
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return err
return nil, err
}
if verbose {
latest.Summary()
}
ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
return nil
return &api.Message{
Prompt: prompt,
Response: fullResponse.String(),
}, nil
}
func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
// load the model
if err := generate(cmd, model, "", false, ""); err != nil {
if _, err := generate(cmd, model, "", nil, false, ""); err != nil {
return err
}
@@ -614,6 +622,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
defer fmt.Printf(readline.EndBracketedPaste)
var multiLineBuffer string
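// running history of prompt/response pairs built up during this interactive session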
messages := make([]api.Message, 0)
for {
line, err := scanner.Readline()
@@ -756,9 +765,11 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
}
if len(line) > 0 && line[0] != '/' {
if err := generate(cmd, model, line, wordWrap, format); err != nil {
message, err := generate(cmd, model, line, messages, wordWrap, format)
if err != nil {
return err
}
messages = append(messages, *message)
}
}
}

View File

@@ -45,9 +45,13 @@ Advanced parameters (optional):
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system prompt (overrides what is defined in the `Modelfile`)
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `messages`: the messages of the conversation so far; this can be used to keep a conversational memory
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
- `raw`: if `true` no formatting will be applied to the prompt, and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing messages yourself.
Deprecated parameters (optional):
- `context`: the context parameter returned from a previous request to `/generate`; this can be used to keep a short conversational memory
### JSON mode
@@ -89,8 +93,8 @@ The final response in the stream also includes additional data about the generat
- `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
- `eval_count`: number of tokens in the response
- `eval_duration`: time in nanoseconds spent generating the response
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
- `response`: empty if the response was streamed; if not streamed, this will contain the full response
- `context`: optionally, if no messages were specified, the context will be returned as an encoding of the conversation used in this response; this field is deprecated and will be removed in a future version
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` by `eval_duration`, converting `eval_duration` from nanoseconds to seconds.
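For example, with the sample response below (`eval_count: 12`, `eval_duration: 938680000`), the rate is 12 / 0.93868 s ≈ 12.8 tokens per second.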
@@ -146,6 +150,41 @@ If `stream` is set to `false`, the response will be a single JSON object:
#### Request (Raw mode)
To continue a conversation, you can provide a `messages` parameter with the conversation so far. This is a list of prompts and responses.
```shell
curl -X POST http://localhost:11434/api/generate -d '{
"model": "mistral",
"prompt": "what did I just ask?",
"messages": [
{
"prompt": "why is the sky blue?",
"response": "The sky appears blue because of a phenomenon called Rayleigh scattering."
}
],
"stream": false,
}'
```
#### Response
```json
{
"model": "mistral",
"created_at": "2023-11-03T21:56:04.806917Z",
"response": "You asked for an explanation of why the sky is blue.",
"done": true,
"total_duration": 5211750166,
"load_duration": 3714731708,
"prompt_eval_count": 44,
"prompt_eval_duration": 532827000,
"eval_count": 12,
"eval_duration": 938680000
}
```
#### Request
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
```shell

View File

@@ -527,21 +527,9 @@ type prediction struct {
const maxBufferSize = 512 * format.KiloByte
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext)
if err != nil {
return err
}
// Remove leading spaces from prevConvo if present
prevConvo = strings.TrimPrefix(prevConvo, " ")
var nextContext strings.Builder
nextContext.WriteString(prevConvo)
nextContext.WriteString(prompt)
func (llm *llama) Predict(ctx context.Context, prompt string, format string, fn func(api.GenerateResponse)) error {
request := map[string]any{
"prompt": nextContext.String(),
"prompt": prompt,
"stream": true,
"n_predict": llm.NumPredict,
"n_keep": llm.NumKeep,
@@ -621,18 +609,12 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
if p.Content != "" {
fn(api.GenerateResponse{Response: p.Content})
nextContext.WriteString(p.Content)
}
if p.Stop {
embd, err := llm.Encode(ctx, nextContext.String())
if err != nil {
return fmt.Errorf("encoding context: %v", err)
}
fn(api.GenerateResponse{
Done: true,
Context: embd,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,

View File

@@ -14,7 +14,7 @@ import (
)
type LLM interface {
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
Predict(context.Context, string, string, func(api.GenerateResponse)) error
Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error)

View File

@@ -48,10 +48,17 @@ type Model struct {
Options map[string]interface{}
}
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
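// PromptVars are the values made available to the prompt template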
type PromptVars struct {
First bool
System string
Prompt string
}
func (m *Model) Prompt(vars PromptVars, reqTemplate string) (string, error) {
t := m.Template
if request.Template != "" {
t = request.Template
if reqTemplate != "" {
// override the model template if one is specified
t = reqTemplate
}
tmpl, err := template.New("").Parse(t)
@@ -59,18 +66,9 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
return "", err
}
var vars struct {
First bool
System string
Prompt string
}
vars.First = len(request.Context) == 0
vars.System = m.System
vars.Prompt = request.Prompt
if request.System != "" {
vars.System = request.System
if vars.System == "" {
// use the default system prompt for this model if one is not specified
vars.System = m.System
}
var sb strings.Builder

View File

@@ -2,17 +2,14 @@ package server
import (
"testing"
"github.com/jmorganca/ollama/api"
)
func TestModelPrompt(t *testing.T) {
var m Model
req := api.GenerateRequest{
Template: "a{{ .Prompt }}b",
Prompt: "<h1>",
}
s, err := m.Prompt(req)
s, err := m.Prompt(PromptVars{
First: true,
Prompt: "<h1>",
}, "a{{ .Prompt }}b")
if err != nil {
t.Fatal(err)
}

View File

@@ -161,6 +161,8 @@ func GenerateHandler(c *gin.Context) {
}
// validate the request
isContextSet := req.Context != nil && len(req.Context) > 0
areMessagesSet := req.Messages != nil && len(req.Messages) > 0
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
@@ -168,9 +170,12 @@ func GenerateHandler(c *gin.Context) {
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
case req.Raw && (req.Template != "" || req.System != "" || isContextSet || areMessagesSet):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, context, or messages"})
return
case areMessagesSet && isContextSet:
// this makes rebuilding the prompt history too complicated, so don't allow it
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "only one of messages or context may be specified"})
}
model, err := GetModel(req.Model)
@@ -199,20 +204,65 @@ func GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now()
prompt := req.Prompt
if !req.Raw {
prompt, err = model.Prompt(req)
var prompt strings.Builder
if req.Context != nil {
// TODO: context is deprecated, at some point the context logic within this conditional should be removed
// if the request has a context rather than messages, decode it and add it to the prompt
prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Remove leading spaces from prevCtx if present
prevCtx = strings.TrimPrefix(prevCtx, " ")
prompt.WriteString(prevCtx)
}
// build the prompt history from messages
for i, m := range req.Messages {
// apply the template to the prompt
p, err := model.Prompt(PromptVars{
First: i == 0,
Prompt: m.Prompt,
System: req.System,
}, req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt.WriteString(p)
prompt.WriteString(m.Response)
}
// finally, add the current prompt as the most recent message
first := !isContextSet && !areMessagesSet
if req.Raw {
prompt.WriteString(req.Prompt)
} else if strings.TrimSpace(req.Prompt) != "" {
// template the request prompt before adding it
p, err := model.Prompt(PromptVars{
First: first,
System: req.System,
Prompt: req.Prompt,
}, req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt.WriteString(p)
}
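// only return a context when the caller is not using messages: either this is the first request or the deprecated context field was sent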
sendContext := first || isContextSet
var respCtx strings.Builder
if _, err := respCtx.WriteString(prompt.String()); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
go func() {
defer close(ch)
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
if req.Prompt == "" && req.Template == "" && req.System == "" && !areMessagesSet {
ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
return
}
@@ -223,9 +273,26 @@ func GenerateHandler(c *gin.Context) {
r.Model = req.Model
r.CreatedAt = time.Now().UTC()
// if the final response expects a context, build the context as we go
if sendContext {
if _, err := respCtx.WriteString(r.Response); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
}
if r.Done {
r.TotalDuration = time.Since(checkpointStart)
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
// if the response expects a context, encode it and send it back
if sendContext {
embd, err := loaded.runner.Encode(c.Request.Context(), respCtx.String())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
r.Context = embd
}
}
if req.Raw {
@@ -236,7 +303,7 @@ func GenerateHandler(c *gin.Context) {
ch <- r
}
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
if err := loaded.runner.Predict(c.Request.Context(), prompt.String(), req.Format, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()