switch to role based messages

Bruce MacDonald 2023-11-14 18:01:33 -05:00
parent 9c21d23a35
commit 4718ecc62e
5 changed files with 61 additions and 45 deletions

View File

@@ -32,7 +32,7 @@ func (e StatusError) Error() string {
 type GenerateRequest struct {
 	Model string `json:"model"`
-	Prompt string `json:"prompt"`
+	Prompt string `json:"prompt"` // prompt sends a message as the user
 	System string `json:"system"`
 	Template string `json:"template"`
 	Context []int `json:"context,omitempty"` // DEPRECATED: context is deprecated, use messages instead
@@ -45,8 +45,8 @@ type GenerateRequest struct {
 }

 type Message struct {
-	Prompt string `json:"prompt"`
-	Response string `json:"response"`
+	Role string `json:"role"`
+	Content string `json:"content"`
 }

 // Options specfied in GenerateRequest, if you add a new option here add it to the API docs also
@@ -96,7 +96,8 @@ type Runner struct {
 type GenerateResponse struct {
 	Model string `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
-	Response string `json:"response"`
+	Response string `json:"response"` // the last response chunk when streaming
+	Message Message `json:"message"`

 	Done bool `json:"done"`
 	Context []int `json:"context,omitempty"`

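Note: a minimal client-side sketch of how these role-based types fit together, not code from this commit. It assumes GenerateRequest also carries a Messages []Message field (implied by the DEPRECATED note on Context and by req.Messages in the handler below), that an api.ClientFromEnvironment constructor exists, that the streaming callback has the shape func(api.GenerateResponse) error, and a hypothetical model name.

package main

import (
	"context"
	"fmt"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// assumed constructor; substitute however the client is normally created
	client, err := api.ClientFromEnvironment()
	if err != nil {
		panic(err)
	}

	history := []api.Message{
		{Role: "system", Content: "You are a concise assistant."}, // hypothetical contents
		{Role: "user", Content: "Why is the sky blue?"},
	}

	req := api.GenerateRequest{
		Model:    "llama2", // hypothetical model name
		Messages: history,  // assumed field name, per the DEPRECATED note on Context
	}

	fn := func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response) // while streaming, Response holds only the latest chunk
		if resp.Done {
			// the full assistant turn arrives once, ready to append to the history
			history = append(history, resp.Message)
		}
		return nil
	}

	if err := client.Generate(context.Background(), &req, fn); err != nil {
		panic(err)
	}
}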
View File

@@ -529,10 +529,7 @@ func generate(cmd *cobra.Command, model, prompt string, messages []api.Message,
 	if err := client.Generate(cancelCtx, &request, fn); err != nil {
 		if strings.Contains(err.Error(), "context canceled") && abort {
-			return &api.Message{
-				Prompt: prompt,
-				Response: fullResponse.String(),
-			}, nil
+			return nil, nil
 		}
 		return nil, err
 	}
@@ -543,10 +540,7 @@ func generate(cmd *cobra.Command, model, prompt string, messages []api.Message,
 	if !latest.Done {
 		if abort {
-			return &api.Message{
-				Prompt: prompt,
-				Response: fullResponse.String(),
-			}, nil
+			return nil, nil
 		}
 		return nil, errors.New("unexpected end of response")
 	}
@@ -560,10 +554,7 @@ func generate(cmd *cobra.Command, model, prompt string, messages []api.Message,
 		latest.Summary()
 	}

-	return &api.Message{
-		Prompt: prompt,
-		Response: fullResponse.String(),
-	}, nil
+	return &latest.Message, nil
 }

 func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
@@ -765,11 +756,12 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
 		}
 		if len(line) > 0 && line[0] != '/' {
-			message, err := generate(cmd, model, line, messages, wordWrap, format)
+			assistant, err := generate(cmd, model, line, messages, wordWrap, format)
 			if err != nil {
 				return err
 			}

-			messages = append(messages, *message)
+			messages = append(messages, api.Message{Role: "user", Content: line})
+			messages = append(messages, *assistant)
 		}
 	}
 }

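Note: generate now returns the assistant's api.Message taken from latest.Message (or nil when the request is aborted), and the interactive loop records both sides of each exchange. A small self-contained sketch of how that history accumulates; the message contents here are invented for illustration.

package main

import (
	"fmt"

	"github.com/jmorganca/ollama/api"
)

func main() {
	var messages []api.Message

	// first exchange: the user line is appended explicitly, then the
	// assistant message returned by generate is appended after it
	messages = append(messages, api.Message{Role: "user", Content: "hello"})
	messages = append(messages, api.Message{Role: "assistant", Content: "Hi! How can I help?"})

	// the next call to generate is handed the whole slice, so the server
	// can rebuild the conversation when templating the prompt
	messages = append(messages, api.Message{Role: "user", Content: "tell me a joke"})

	for _, m := range messages {
		fmt.Printf("%s: %s\n", m.Role, m.Content)
	}
}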
View File

@@ -54,7 +54,7 @@ type PromptVars struct {
 	Prompt string
 }

-func (m *Model) Prompt(vars PromptVars, reqTemplate string) (string, error) {
+func (m *Model) Prompt(vars *PromptVars, reqTemplate string) (string, error) {
 	t := m.Template
 	if reqTemplate != "" {
 		// override the model template if one is specified

View File

@@ -6,7 +6,7 @@ import (
 func TestModelPrompt(t *testing.T) {
 	var m Model
-	s, err := m.Prompt(PromptVars{
+	s, err := m.Prompt(&PromptVars{
 		First: true,
 		Prompt: "<h1>",
 	}, "a{{ .Prompt }}b")

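Note: the updated test still only exercises templating of the Prompt field. A stripped-down sketch of what that templating amounts to, using Go's text/template directly rather than the real Model.Prompt; with the inputs from the test it should print a<h1>b, since text/template does not HTML-escape.

package main

import (
	"os"
	"text/template"
)

// Stand-in for PromptVars, with the fields referenced in this diff.
type PromptVars struct {
	First  bool
	System string
	Prompt string
}

func main() {
	// same template string and inputs as the test above
	t := template.Must(template.New("prompt").Parse("a{{ .Prompt }}b"))
	_ = t.Execute(os.Stdout, PromptVars{First: true, Prompt: "<h1>"}) // prints: a<h1>b
}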
View File

@@ -219,19 +219,34 @@ func GenerateHandler(c *gin.Context) {
 		prompt.WriteString(prevCtx)
 	}

 	// build the prompt history from messages
-	for i, m := range req.Messages {
-		// apply the template to the prompt
-		p, err := model.Prompt(PromptVars{
-			First: i == 0,
-			Prompt: m.Prompt,
-			System: req.System,
-		}, req.Template)
-		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
-		}
+	vars := &PromptVars{
+		First: true,
+	}
+	for _, m := range req.Messages {
+		if m.Role == "system" {
+			if vars.System != "" {
+				flush(vars, req.Template, model)
+			}
+			vars.System = m.Content
+		}
-		prompt.WriteString(p)
-		prompt.WriteString(m.Response)
+		if m.Role == "user" {
+			if vars.Prompt != "" {
+				flush(vars, req.Template, model)
+			}
+			vars.Prompt = m.Content
+		}
+		if m.Role == "assistant" {
+			if vars.Prompt != "" || vars.System != "" {
+				flush(vars, req.Template, model)
+			}
+			prompt.WriteString(m.Content)
+		}
 	}
+	if vars.Prompt != "" || vars.System != "" {
+		flush(vars, req.Template, model)
+	}

 	// finally, add the current prompt as the most recent message
@@ -240,7 +255,7 @@ func GenerateHandler(c *gin.Context) {
 		prompt.WriteString(req.Prompt)
 	} else if strings.TrimSpace(req.Prompt) != "" {
 		// template the request prompt before adding it
-		p, err := model.Prompt(PromptVars{
+		p, err := model.Prompt(&PromptVars{
 			First: first,
 			System: req.System,
 			Prompt: req.Prompt,
@@ -253,11 +268,7 @@ func GenerateHandler(c *gin.Context) {
 	}
 	sendContext := first || isContextSet
-	var respCtx strings.Builder
-	if _, err := respCtx.WriteString(prompt.String()); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	var msgContent strings.Builder

 	ch := make(chan any)
 	go func() {
 		defer close(ch)
@@ -273,20 +284,21 @@ func GenerateHandler(c *gin.Context) {
 			r.Model = req.Model
 			r.CreatedAt = time.Now().UTC()
-			// if the final response expects a context, build the context as we go
-			if sendContext {
-				if _, err := respCtx.WriteString(r.Response); err != nil {
-					ch <- gin.H{"error": err.Error()}
-					return
-				}
+			if _, err := msgContent.WriteString(r.Response); err != nil {
+				ch <- gin.H{"error": err.Error()}
+				return
 			}

 			if r.Done {
 				r.TotalDuration = time.Since(checkpointStart)
 				r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+				r.Message = api.Message{
+					Role: "assistant",
+					Content: msgContent.String(),
+				}

 				// if the response expects a context, encode it and send it back
 				if sendContext {
-					embd, err := loaded.runner.Encode(c.Request.Context(), respCtx.String())
+					embd, err := loaded.runner.Encode(c.Request.Context(), prompt.String()+msgContent.String())
 					if err != nil {
 						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 						return
@@ -328,6 +340,17 @@ func GenerateHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }

+func flush(prompt *PromptVars, template string, model *Model) (string, error) {
+	p, err := model.Prompt(prompt, template)
+	if err != nil {
+		return "", err
+	}
+	prompt.First = false
+	prompt.Prompt = ""
+	prompt.System = ""
+	return p, nil
+}
+
 func EmbeddingHandler(c *gin.Context) {
 	loaded.mu.Lock()
 	defer loaded.mu.Unlock()
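Note: a standalone sketch of the accumulate-and-flush prompt assembly the new GenerateHandler loop performs: system and user content are buffered in PromptVars, templated out when the same role repeats or an assistant message is reached, and assistant content is appended verbatim. The types and the render helper below are simplified stand-ins, not the server code, and unlike the handler above this sketch writes the rendered text into the prompt builder rather than discarding flush's return value.

package main

import (
	"fmt"
	"strings"
)

// Simplified stand-ins for the server types used in the hunk above.
type Message struct {
	Role    string
	Content string
}

type PromptVars struct {
	First  bool
	System string
	Prompt string
}

// render plays the role of flush/model.Prompt: it turns the buffered system
// and user content into prompt text, then resets the buffer.
func render(v *PromptVars, out *strings.Builder) {
	if v.System != "" {
		fmt.Fprintf(out, "[system: %s]", v.System)
	}
	if v.Prompt != "" {
		fmt.Fprintf(out, "[user: %s]", v.Prompt)
	}
	v.First = false
	v.System = ""
	v.Prompt = ""
}

func main() {
	msgs := []Message{
		{Role: "system", Content: "be brief"},
		{Role: "user", Content: "hi"},
		{Role: "assistant", Content: "hello"},
		{Role: "user", Content: "why is the sky blue?"},
	}

	var prompt strings.Builder
	vars := &PromptVars{First: true}
	for _, m := range msgs {
		switch m.Role {
		case "system":
			if vars.System != "" {
				render(vars, &prompt) // a second system message flushes the first
			}
			vars.System = m.Content
		case "user":
			if vars.Prompt != "" {
				render(vars, &prompt) // back-to-back user turns flush in between
			}
			vars.Prompt = m.Content
		case "assistant":
			if vars.Prompt != "" || vars.System != "" {
				render(vars, &prompt) // template pending system/user before the reply
			}
			prompt.WriteString("[assistant: " + m.Content + "]")
		}
	}
	if vars.Prompt != "" || vars.System != "" {
		render(vars, &prompt) // trailing turn that has not been answered yet
	}

	fmt.Println(prompt.String())
	// prints: [system: be brief][user: hi][assistant: hello][user: why is the sky blue?]
}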