diff --git a/api/types.go b/api/types.go index 1940ba40..4d9489ca 100644 --- a/api/types.go +++ b/api/types.go @@ -32,7 +32,7 @@ func (e StatusError) Error() string { type GenerateRequest struct { Model string `json:"model"` - Prompt string `json:"prompt"` + Prompt string `json:"prompt"` // prompt sends a message as the user System string `json:"system"` Template string `json:"template"` Context []int `json:"context,omitempty"` // DEPRECATED: context is deprecated, use messages instead @@ -45,8 +45,8 @@ type GenerateRequest struct { } type Message struct { - Prompt string `json:"prompt"` - Response string `json:"response"` + Role string `json:"role"` + Content string `json:"content"` } // Options specfied in GenerateRequest, if you add a new option here add it to the API docs also @@ -96,7 +96,8 @@ type Runner struct { type GenerateResponse struct { Model string `json:"model"` CreatedAt time.Time `json:"created_at"` - Response string `json:"response"` + Response string `json:"response"` // the last response chunk when streaming + Message Message `json:"message"` Done bool `json:"done"` Context []int `json:"context,omitempty"` diff --git a/cmd/cmd.go b/cmd/cmd.go index 42136cef..a801bedd 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -529,10 +529,7 @@ func generate(cmd *cobra.Command, model, prompt string, messages []api.Message, if err := client.Generate(cancelCtx, &request, fn); err != nil { if strings.Contains(err.Error(), "context canceled") && abort { - return &api.Message{ - Prompt: prompt, - Response: fullResponse.String(), - }, nil + return nil, nil } return nil, err } @@ -543,10 +540,7 @@ func generate(cmd *cobra.Command, model, prompt string, messages []api.Message, if !latest.Done { if abort { - return &api.Message{ - Prompt: prompt, - Response: fullResponse.String(), - }, nil + return nil, nil } return nil, errors.New("unexpected end of response") } @@ -560,10 +554,7 @@ func generate(cmd *cobra.Command, model, prompt string, messages []api.Message, latest.Summary() } - return &api.Message{ - Prompt: prompt, - Response: fullResponse.String(), - }, nil + return &latest.Message, nil } func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error { @@ -765,11 +756,12 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format } if len(line) > 0 && line[0] != '/' { - message, err := generate(cmd, model, line, messages, wordWrap, format) + assistant, err := generate(cmd, model, line, messages, wordWrap, format) if err != nil { return err } - messages = append(messages, *message) + messages = append(messages, api.Message{Role: "user", Content: line}) + messages = append(messages, *assistant) } } } diff --git a/server/images.go b/server/images.go index 5aa051e0..fd07d83b 100644 --- a/server/images.go +++ b/server/images.go @@ -54,7 +54,7 @@ type PromptVars struct { Prompt string } -func (m *Model) Prompt(vars PromptVars, reqTemplate string) (string, error) { +func (m *Model) Prompt(vars *PromptVars, reqTemplate string) (string, error) { t := m.Template if reqTemplate != "" { // override the model template if one is specified diff --git a/server/images_test.go b/server/images_test.go index d700bf0b..b34fdddb 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -6,7 +6,7 @@ import ( func TestModelPrompt(t *testing.T) { var m Model - s, err := m.Prompt(PromptVars{ + s, err := m.Prompt(&PromptVars{ First: true, Prompt: "

", }, "a{{ .Prompt }}b") diff --git a/server/routes.go b/server/routes.go index 617eaa5e..7618f621 100644 --- a/server/routes.go +++ b/server/routes.go @@ -219,19 +219,34 @@ func GenerateHandler(c *gin.Context) { prompt.WriteString(prevCtx) } // build the prompt history from messages - for i, m := range req.Messages { - // apply the template to the prompt - p, err := model.Prompt(PromptVars{ - First: i == 0, - Prompt: m.Prompt, - System: req.System, - }, req.Template) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + vars := &PromptVars{ + First: true, + } + for _, m := range req.Messages { + if m.Role == "system" { + if vars.System != "" { + flush(vars, req.Template, model) + } + vars.System = m.Content } - prompt.WriteString(p) - prompt.WriteString(m.Response) + + if m.Role == "user" { + if vars.Prompt != "" { + flush(vars, req.Template, model) + } + vars.Prompt = m.Content + } + + if m.Role == "assistant" { + if vars.Prompt != "" || vars.System != "" { + flush(vars, req.Template, model) + } + prompt.WriteString(m.Content) + } + } + + if vars.Prompt != "" || vars.System != "" { + flush(vars, req.Template, model) } // finally, add the current prompt as the most recent message @@ -240,7 +255,7 @@ func GenerateHandler(c *gin.Context) { prompt.WriteString(req.Prompt) } else if strings.TrimSpace(req.Prompt) != "" { // template the request prompt before adding it - p, err := model.Prompt(PromptVars{ + p, err := model.Prompt(&PromptVars{ First: first, System: req.System, Prompt: req.Prompt, @@ -253,11 +268,7 @@ func GenerateHandler(c *gin.Context) { } sendContext := first || isContextSet - var respCtx strings.Builder - if _, err := respCtx.WriteString(prompt.String()); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + var msgContent strings.Builder ch := make(chan any) go func() { defer close(ch) @@ -273,20 +284,21 @@ func GenerateHandler(c *gin.Context) { r.Model = req.Model r.CreatedAt = time.Now().UTC() - // if the final response expects a context, build the context as we go - if sendContext { - if _, err := respCtx.WriteString(r.Response); err != nil { - ch <- gin.H{"error": err.Error()} - return - } + if _, err := msgContent.WriteString(r.Response); err != nil { + ch <- gin.H{"error": err.Error()} + return } if r.Done { r.TotalDuration = time.Since(checkpointStart) r.LoadDuration = checkpointLoaded.Sub(checkpointStart) + r.Message = api.Message{ + Role: "assistant", + Content: msgContent.String(), + } // if the response expects a context, encode it and send it back if sendContext { - embd, err := loaded.runner.Encode(c.Request.Context(), respCtx.String()) + embd, err := loaded.runner.Encode(c.Request.Context(), prompt.String()+msgContent.String()) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -328,6 +340,17 @@ func GenerateHandler(c *gin.Context) { streamResponse(c, ch) } +func flush(prompt *PromptVars, template string, model *Model) (string, error) { + p, err := model.Prompt(prompt, template) + if err != nil { + return "", err + } + prompt.First = false + prompt.Prompt = "" + prompt.System = "" + return p, nil +} + func EmbeddingHandler(c *gin.Context) { loaded.mu.Lock() defer loaded.mu.Unlock()