Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Yang
97d5582c24 test: rm undefined function test 2024-07-22 10:48:49 -07:00
Michael Yang
d44647efc8 template: disable func checking 2024-07-22 10:48:49 -07:00
4 changed files with 57 additions and 37 deletions

View File

@ -298,7 +298,7 @@ func detectContentType(r io.Reader) (string, error) {
// mxyng: this only really works if the input contains tool calls in some JSON format // mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls // create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool { tmpl := m.Template.Sub(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok { if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
} }
@ -311,7 +311,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
} }
var b bytes.Buffer var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{ if err := tmpl.Template().Execute(&b, map[string][]api.ToolCall{
"ToolCalls": { "ToolCalls": {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{

View File

@ -498,7 +498,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
}) })
if w.Code != http.StatusBadRequest { if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code) t.Fatalf("expected status code 400, actual %d", w.Code)
} }
@ -510,19 +510,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
}) })
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with undefined function", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest { if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code) t.Fatalf("expected status code 400, actual %d", w.Code)
} }

View File

@ -79,8 +79,8 @@ func Named(s string) (*named, error) {
var DefaultTemplate, _ = Parse("{{ .Prompt }}") var DefaultTemplate, _ = Parse("{{ .Prompt }}")
type Template struct { type Template struct {
*template.Template tree *parse.Tree
raw string raw string
} }
// response is a template node that can be added to templates that don't already have one // response is a template node that can be added to templates that don't already have one
@ -110,17 +110,18 @@ var funcs = template.FuncMap{
} }
func Parse(s string) (*Template, error) { func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs) tree := parse.New("")
tree.Mode = tree.Mode | parse.SkipFuncCheck
tmpl, err := tmpl.Parse(s) tree, err := tree.Parse(s, "", "", map[string]*parse.Tree{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
t := Template{Template: tmpl, raw: s} t := Template{tree, s}
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") { if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
// touch up the template and append {{ .Response }} // touch up the template and append {{ .Response }}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response) t.tree.Root.Nodes = append(t.tree.Root.Nodes, &response)
} }
return &t, nil return &t, nil
@ -132,10 +133,8 @@ func (t *Template) String() string {
func (t *Template) Vars() []string { func (t *Template) Vars() []string {
var vars []string var vars []string
for _, tt := range t.Templates() { for _, n := range t.tree.Root.Nodes {
for _, n := range tt.Root.Nodes { vars = append(vars, Identifiers(n)...)
vars = append(vars, Identifiers(n)...)
}
} }
set := make(map[string]struct{}) set := make(map[string]struct{})
@ -158,7 +157,8 @@ type Values struct {
forceLegacy bool forceLegacy bool
} }
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template { // Sub returns a new template with the subtree that matches the predicate
func (t *Template) Sub(fn func(parse.Node) bool) *Template {
var walk func(parse.Node) parse.Node var walk func(parse.Node) parse.Node
walk = func(n parse.Node) parse.Node { walk = func(n parse.Node) parse.Node {
if fn(n) { if fn(n) {
@ -191,29 +191,34 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
return nil return nil
} }
if n := walk(t.Tree.Root); n != nil { if n := walk(t.tree.Root); n != nil {
return (&template.Template{ return &Template{
Tree: &parse.Tree{ tree: &parse.Tree{
Root: &parse.ListNode{ Root: &parse.ListNode{
Nodes: []parse.Node{n}, Nodes: []parse.Node{n},
}, },
}, },
}).Funcs(funcs) }
} }
return nil return nil
} }
func (t *Template) Template() *template.Template {
return template.Must(template.New("").Option("missingkey=zero").Funcs(funcs).AddParseTree("", t.tree))
}
func (t *Template) Execute(w io.Writer, v Values) error { func (t *Template) Execute(w io.Writer, v Values) error {
tmpl := t.Template()
system, messages := collate(v.Messages) system, messages := collate(v.Messages)
if v.Prompt != "" && v.Suffix != "" { if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{ return tmpl.Execute(w, map[string]any{
"Prompt": v.Prompt, "Prompt": v.Prompt,
"Suffix": v.Suffix, "Suffix": v.Suffix,
"Response": "", "Response": "",
}) })
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { } else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{ return tmpl.Execute(w, map[string]any{
"System": system, "System": system,
"Messages": messages, "Messages": messages,
"Tools": v.Tools, "Tools": v.Tools,
@ -226,7 +231,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
var prompt, response string var prompt, response string
for _, m := range messages { for _, m := range messages {
execute := func() error { execute := func() error {
if err := t.Template.Execute(&b, map[string]any{ if err := tmpl.Execute(&b, map[string]any{
"System": system, "System": system,
"Prompt": prompt, "Prompt": prompt,
"Response": response, "Response": response,
@ -261,7 +266,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
} }
var cut bool var cut bool
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool { nodes := deleteNode(t.tree.Root.Copy(), func(n parse.Node) bool {
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") { if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
cut = true cut = true
return false return false
@ -271,7 +276,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
}) })
tree := parse.Tree{Root: nodes.(*parse.ListNode)} tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{ if err := template.Must(tmpl.AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system, "System": system,
"Prompt": prompt, "Prompt": prompt,
"Response": response, "Response": response,

View File

@ -53,7 +53,7 @@ func TestNamed(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if tmpl.Tree.Root.String() == "" { if tmpl.tree.Root.String() == "" {
t.Errorf("empty %s template", k) t.Errorf("empty %s template", k)
} }
}) })
@ -152,7 +152,7 @@ func TestTemplate(t *testing.T) {
} }
} }
func TestParse(t *testing.T) { func TestParseVars(t *testing.T) {
cases := []struct { cases := []struct {
template string template string
vars []string vars []string
@ -180,6 +180,9 @@ func TestParse(t *testing.T) {
{{ end }}<|im_start|>assistant {{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|> {{ .Response }}<|im_end|>
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}}, {{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
{"{{ json .Messages }}", []string{"messages"}},
// undefined functions should not error
{"{{ undefined }}", []string{"response"}},
} }
for _, tt := range cases { for _, tt := range cases {
@ -196,6 +199,30 @@ func TestParse(t *testing.T) {
} }
} }
func TestParseExecute(t *testing.T) {
t.Run("undefined function", func(t *testing.T) {
tmpl, err := Parse(`{{- if .Suffix }}{{ .Prompt }} {{ .Suffix }}{{- else }}{{ undefined }}{{- end }}`)
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Prompt: "def add(", Suffix: " return c"}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(b.String(), "def add( return c"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if err := tmpl.Execute(io.Discard, Values{}); err == nil {
t.Fatal("expected error")
} else if !strings.Contains(err.Error(), "\"undefined\" is not a defined function") {
t.Fatal(err)
}
})
}
func TestExecuteWithMessages(t *testing.T) { func TestExecuteWithMessages(t *testing.T) {
type template struct { type template struct {
name string name string