Compare commits

...

1 Commit

Author SHA1 Message Date
Patrick Devine
3c0d043b79 pass the template to the /api/chat endpoint 2024-07-10 14:17:39 -07:00
4 changed files with 34 additions and 8 deletions

View File

@ -84,6 +84,9 @@ type ChatRequest struct {
// Model is the model name, as in [GenerateRequest]. // Model is the model name, as in [GenerateRequest].
Model string `json:"model"` Model string `json:"model"`
// Template overrides the model's default prompt template.
Template string `json:"template"`
// Messages is the messages of the chat - can be used to keep a chat memory. // Messages is the messages of the chat - can be used to keep a chat memory.
Messages []Message `json:"messages"` Messages []Message `json:"messages"`

View File

@ -947,6 +947,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req := &api.ChatRequest{ req := &api.ChatRequest{
Model: opts.Model, Model: opts.Model,
Template: opts.Template,
Messages: opts.Messages, Messages: opts.Messages,
Format: opts.Format, Format: opts.Format,
Options: opts.Options, Options: opts.Options,

View File

@ -18,6 +18,7 @@ import (
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
) )
@ -205,9 +206,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset() sb.Reset()
case MultilineTemplate: case MultilineTemplate:
opts.Template = sb.String() mTemplate := sb.String()
fmt.Println("Set prompt template.")
sb.Reset() sb.Reset()
_, err := template.Parse(mTemplate)
if err != nil {
multiline = MultilineNone
scanner.Prompt.UseAlt = false
fmt.Println("The template is invalid.")
continue
}
opts.Template = mTemplate
fmt.Println("Set prompt template.")
} }
multiline = MultilineNone multiline = MultilineNone
@ -369,9 +378,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset() sb.Reset()
} else if args[1] == "template" { } else if args[1] == "template" {
opts.Template = sb.String() mTemplate := sb.String()
fmt.Println("Set prompt template.")
sb.Reset() sb.Reset()
_, err := template.Parse(mTemplate)
if err != nil {
fmt.Println("The template is invalid.")
continue
}
opts.Template = mTemplate
fmt.Println("Set prompt template.")
} }
sb.Reset() sb.Reset()

View File

@ -71,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options. // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) { func (s *Server) scheduleRunner(ctx context.Context, name string, mTemplate string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
if name == "" { if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired) return nil, nil, nil, fmt.Errorf("model %w", errRequired)
} }
@ -81,6 +81,13 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
return nil, nil, nil, err return nil, nil, nil, err
} }
if mTemplate != "" {
model.Template, err = template.Parse(mTemplate)
if err != nil {
return nil, nil, nil, err
}
}
if err := model.CheckCapabilities(caps...); err != nil { if err := model.CheckCapabilities(caps...); err != nil {
return nil, nil, nil, fmt.Errorf("%s %w", name, err) return nil, nil, nil, fmt.Errorf("%s %w", name, err)
} }
@ -120,7 +127,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, "", caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return return
@ -256,7 +263,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, "", []Capability{}, req.Options, req.KeepAlive)
if err != nil { if err != nil {
handleScheduleError(c, req.Model, err) handleScheduleError(c, req.Model, err)
return return
@ -1132,7 +1139,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, req.Template, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return return