switch to role based messages
This commit is contained in:
parent
9c21d23a35
commit
4718ecc62e
@ -32,7 +32,7 @@ func (e StatusError) Error() string {
|
||||
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Prompt string `json:"prompt"` // prompt sends a message as the user
|
||||
System string `json:"system"`
|
||||
Template string `json:"template"`
|
||||
Context []int `json:"context,omitempty"` // DEPRECATED: context is deprecated, use messages instead
|
||||
@ -45,8 +45,8 @@ type GenerateRequest struct {
|
||||
}
|
||||
|
||||
// Message is one entry in a conversation history exchanged with the API.
// NOTE(review): this view interleaves two revisions of the struct; the
// Prompt/Response pair appears to be the shape being replaced by
// Role/Content — confirm against the commit before relying on either.
type Message struct {
|
||||
Prompt string `json:"prompt"` // user text of a prompt/response pair
|
||||
Response string `json:"response"` // model text of a prompt/response pair
|
||||
Role string `json:"role"` // speaker of the message; "system", "user" and "assistant" are the values used in this change
|
||||
Content string `json:"content"` // text of the message for the given role
|
||||
}
|
||||
|
||||
// Options specified in GenerateRequest; if you add a new option here, add it to the API docs as well
|
||||
@ -96,7 +96,8 @@ type Runner struct {
|
||||
type GenerateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response"`
|
||||
Response string `json:"response"` // the last response chunk when streaming
|
||||
Message Message `json:"message"`
|
||||
|
||||
Done bool `json:"done"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
|
20
cmd/cmd.go
20
cmd/cmd.go
@ -529,10 +529,7 @@ func generate(cmd *cobra.Command, model, prompt string, messages []api.Message,
|
||||
|
||||
if err := client.Generate(cancelCtx, &request, fn); err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") && abort {
|
||||
return &api.Message{
|
||||
Prompt: prompt,
|
||||
Response: fullResponse.String(),
|
||||
}, nil
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@ -543,10 +540,7 @@ func generate(cmd *cobra.Command, model, prompt string, messages []api.Message,
|
||||
|
||||
if !latest.Done {
|
||||
if abort {
|
||||
return &api.Message{
|
||||
Prompt: prompt,
|
||||
Response: fullResponse.String(),
|
||||
}, nil
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.New("unexpected end of response")
|
||||
}
|
||||
@ -560,10 +554,7 @@ func generate(cmd *cobra.Command, model, prompt string, messages []api.Message,
|
||||
latest.Summary()
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Prompt: prompt,
|
||||
Response: fullResponse.String(),
|
||||
}, nil
|
||||
return &latest.Message, nil
|
||||
}
|
||||
|
||||
func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
|
||||
@ -765,11 +756,12 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||
}
|
||||
|
||||
if len(line) > 0 && line[0] != '/' {
|
||||
message, err := generate(cmd, model, line, messages, wordWrap, format)
|
||||
assistant, err := generate(cmd, model, line, messages, wordWrap, format)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
messages = append(messages, *message)
|
||||
messages = append(messages, api.Message{Role: "user", Content: line})
|
||||
messages = append(messages, *assistant)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ type PromptVars struct {
|
||||
Prompt string
|
||||
}
|
||||
|
||||
func (m *Model) Prompt(vars PromptVars, reqTemplate string) (string, error) {
|
||||
func (m *Model) Prompt(vars *PromptVars, reqTemplate string) (string, error) {
|
||||
t := m.Template
|
||||
if reqTemplate != "" {
|
||||
// override the model template if one is specified
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
|
||||
func TestModelPrompt(t *testing.T) {
|
||||
var m Model
|
||||
s, err := m.Prompt(PromptVars{
|
||||
s, err := m.Prompt(&PromptVars{
|
||||
First: true,
|
||||
Prompt: "<h1>",
|
||||
}, "a{{ .Prompt }}b")
|
||||
|
@ -219,19 +219,34 @@ func GenerateHandler(c *gin.Context) {
|
||||
prompt.WriteString(prevCtx)
|
||||
}
|
||||
// build the prompt history from messages
|
||||
for i, m := range req.Messages {
|
||||
// apply the template to the prompt
|
||||
p, err := model.Prompt(PromptVars{
|
||||
First: i == 0,
|
||||
Prompt: m.Prompt,
|
||||
System: req.System,
|
||||
}, req.Template)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
vars := &PromptVars{
|
||||
First: true,
|
||||
}
|
||||
for _, m := range req.Messages {
|
||||
if m.Role == "system" {
|
||||
if vars.System != "" {
|
||||
flush(vars, req.Template, model)
|
||||
}
|
||||
vars.System = m.Content
|
||||
}
|
||||
prompt.WriteString(p)
|
||||
prompt.WriteString(m.Response)
|
||||
|
||||
if m.Role == "user" {
|
||||
if vars.Prompt != "" {
|
||||
flush(vars, req.Template, model)
|
||||
}
|
||||
vars.Prompt = m.Content
|
||||
}
|
||||
|
||||
if m.Role == "assistant" {
|
||||
if vars.Prompt != "" || vars.System != "" {
|
||||
flush(vars, req.Template, model)
|
||||
}
|
||||
prompt.WriteString(m.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if vars.Prompt != "" || vars.System != "" {
|
||||
flush(vars, req.Template, model)
|
||||
}
|
||||
|
||||
// finally, add the current prompt as the most recent message
|
||||
@ -240,7 +255,7 @@ func GenerateHandler(c *gin.Context) {
|
||||
prompt.WriteString(req.Prompt)
|
||||
} else if strings.TrimSpace(req.Prompt) != "" {
|
||||
// template the request prompt before adding it
|
||||
p, err := model.Prompt(PromptVars{
|
||||
p, err := model.Prompt(&PromptVars{
|
||||
First: first,
|
||||
System: req.System,
|
||||
Prompt: req.Prompt,
|
||||
@ -253,11 +268,7 @@ func GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
sendContext := first || isContextSet
|
||||
var respCtx strings.Builder
|
||||
if _, err := respCtx.WriteString(prompt.String()); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var msgContent strings.Builder
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
@ -273,20 +284,21 @@ func GenerateHandler(c *gin.Context) {
|
||||
|
||||
r.Model = req.Model
|
||||
r.CreatedAt = time.Now().UTC()
|
||||
// if the final response expects a context, build the context as we go
|
||||
if sendContext {
|
||||
if _, err := respCtx.WriteString(r.Response); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
if _, err := msgContent.WriteString(r.Response); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
r.TotalDuration = time.Since(checkpointStart)
|
||||
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
r.Message = api.Message{
|
||||
Role: "assistant",
|
||||
Content: msgContent.String(),
|
||||
}
|
||||
// if the response expects a context, encode it and send it back
|
||||
if sendContext {
|
||||
embd, err := loaded.runner.Encode(c.Request.Context(), respCtx.String())
|
||||
embd, err := loaded.runner.Encode(c.Request.Context(), prompt.String()+msgContent.String())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@ -328,6 +340,17 @@ func GenerateHandler(c *gin.Context) {
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
// flush renders the pending template variables in prompt via model.Prompt
// (passing template as the request-level template) and then resets prompt so
// it can collect the next turn.
//
// On success it returns the rendered text and clears prompt.First,
// prompt.Prompt and prompt.System. On error it returns "" and the error and
// leaves prompt unmodified.
//
// NOTE(review): the call sites visible in GenerateHandler invoke flush as a
// bare statement, discarding both results, so the rendered text is never
// written to the prompt builder and template errors are dropped. Callers
// need to append the returned string and handle the error.
func flush(prompt *PromptVars, template string, model *Model) (string, error) {
|
||||
p, err := model.Prompt(prompt, template)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
prompt.First = false // only the first rendered segment is marked First
|
||||
prompt.Prompt = "" // consumed; cleared so the next flush does not repeat it
|
||||
prompt.System = "" // consumed; cleared so the next flush does not repeat it
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func EmbeddingHandler(c *gin.Context) {
|
||||
loaded.mu.Lock()
|
||||
defer loaded.mu.Unlock()
|
||||
|
Loading…
x
Reference in New Issue
Block a user