Compare commits
2 Commits
main
...
mxyng/func
Author | SHA1 | Date | |
---|---|---|---|
|
97d5582c24 | ||
|
d44647efc8 |
@ -298,7 +298,7 @@ func detectContentType(r io.Reader) (string, error) {
|
|||||||
// mxyng: this only really works if the input contains tool calls in some JSON format
|
// mxyng: this only really works if the input contains tool calls in some JSON format
|
||||||
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
// create a subtree from the node that ranges over .ToolCalls
|
// create a subtree from the node that ranges over .ToolCalls
|
||||||
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
tmpl := m.Template.Sub(func(n parse.Node) bool {
|
||||||
if t, ok := n.(*parse.RangeNode); ok {
|
if t, ok := n.(*parse.RangeNode); ok {
|
||||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||||
}
|
}
|
||||||
@ -311,7 +311,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
if err := tmpl.Template().Execute(&b, map[string][]api.ToolCall{
|
||||||
"ToolCalls": {
|
"ToolCalls": {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
|
@ -515,18 +515,6 @@ func TestCreateTemplateSystem(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("template with undefined function", func(t *testing.T) {
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
|
||||||
Name: "test",
|
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
|
|
||||||
Stream: &stream,
|
|
||||||
})
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateLicenses(t *testing.T) {
|
func TestCreateLicenses(t *testing.T) {
|
||||||
|
@ -79,7 +79,7 @@ func Named(s string) (*named, error) {
|
|||||||
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
|
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
|
||||||
|
|
||||||
type Template struct {
|
type Template struct {
|
||||||
*template.Template
|
tree *parse.Tree
|
||||||
raw string
|
raw string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,17 +110,18 @@ var funcs = template.FuncMap{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Parse(s string) (*Template, error) {
|
func Parse(s string) (*Template, error) {
|
||||||
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
|
tree := parse.New("")
|
||||||
|
tree.Mode = tree.Mode | parse.SkipFuncCheck
|
||||||
|
|
||||||
tmpl, err := tmpl.Parse(s)
|
tree, err := tree.Parse(s, "", "", map[string]*parse.Tree{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
t := Template{Template: tmpl, raw: s}
|
t := Template{tree, s}
|
||||||
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
|
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
|
||||||
// touch up the template and append {{ .Response }}
|
// touch up the template and append {{ .Response }}
|
||||||
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
|
t.tree.Root.Nodes = append(t.tree.Root.Nodes, &response)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &t, nil
|
return &t, nil
|
||||||
@ -132,11 +133,9 @@ func (t *Template) String() string {
|
|||||||
|
|
||||||
func (t *Template) Vars() []string {
|
func (t *Template) Vars() []string {
|
||||||
var vars []string
|
var vars []string
|
||||||
for _, tt := range t.Templates() {
|
for _, n := range t.tree.Root.Nodes {
|
||||||
for _, n := range tt.Root.Nodes {
|
|
||||||
vars = append(vars, Identifiers(n)...)
|
vars = append(vars, Identifiers(n)...)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
set := make(map[string]struct{})
|
set := make(map[string]struct{})
|
||||||
for _, n := range vars {
|
for _, n := range vars {
|
||||||
@ -158,7 +157,8 @@ type Values struct {
|
|||||||
forceLegacy bool
|
forceLegacy bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
|
// Sub returns a new template with the subtree that matches the predicate
|
||||||
|
func (t *Template) Sub(fn func(parse.Node) bool) *Template {
|
||||||
var walk func(parse.Node) parse.Node
|
var walk func(parse.Node) parse.Node
|
||||||
walk = func(n parse.Node) parse.Node {
|
walk = func(n parse.Node) parse.Node {
|
||||||
if fn(n) {
|
if fn(n) {
|
||||||
@ -191,29 +191,34 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if n := walk(t.Tree.Root); n != nil {
|
if n := walk(t.tree.Root); n != nil {
|
||||||
return (&template.Template{
|
return &Template{
|
||||||
Tree: &parse.Tree{
|
tree: &parse.Tree{
|
||||||
Root: &parse.ListNode{
|
Root: &parse.ListNode{
|
||||||
Nodes: []parse.Node{n},
|
Nodes: []parse.Node{n},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}).Funcs(funcs)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Template) Template() *template.Template {
|
||||||
|
return template.Must(template.New("").Option("missingkey=zero").Funcs(funcs).AddParseTree("", t.tree))
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
|
tmpl := t.Template()
|
||||||
system, messages := collate(v.Messages)
|
system, messages := collate(v.Messages)
|
||||||
if v.Prompt != "" && v.Suffix != "" {
|
if v.Prompt != "" && v.Suffix != "" {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return tmpl.Execute(w, map[string]any{
|
||||||
"Prompt": v.Prompt,
|
"Prompt": v.Prompt,
|
||||||
"Suffix": v.Suffix,
|
"Suffix": v.Suffix,
|
||||||
"Response": "",
|
"Response": "",
|
||||||
})
|
})
|
||||||
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return tmpl.Execute(w, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Messages": messages,
|
"Messages": messages,
|
||||||
"Tools": v.Tools,
|
"Tools": v.Tools,
|
||||||
@ -226,7 +231,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
var prompt, response string
|
var prompt, response string
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
execute := func() error {
|
execute := func() error {
|
||||||
if err := t.Template.Execute(&b, map[string]any{
|
if err := tmpl.Execute(&b, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Prompt": prompt,
|
"Prompt": prompt,
|
||||||
"Response": response,
|
"Response": response,
|
||||||
@ -261,7 +266,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var cut bool
|
var cut bool
|
||||||
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
nodes := deleteNode(t.tree.Root.Copy(), func(n parse.Node) bool {
|
||||||
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
||||||
cut = true
|
cut = true
|
||||||
return false
|
return false
|
||||||
@ -271,7 +276,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
|
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
|
||||||
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
if err := template.Must(tmpl.AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Prompt": prompt,
|
"Prompt": prompt,
|
||||||
"Response": response,
|
"Response": response,
|
||||||
|
@ -53,7 +53,7 @@ func TestNamed(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tmpl.Tree.Root.String() == "" {
|
if tmpl.tree.Root.String() == "" {
|
||||||
t.Errorf("empty %s template", k)
|
t.Errorf("empty %s template", k)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -152,7 +152,7 @@ func TestTemplate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParse(t *testing.T) {
|
func TestParseVars(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
template string
|
template string
|
||||||
vars []string
|
vars []string
|
||||||
@ -180,6 +180,9 @@ func TestParse(t *testing.T) {
|
|||||||
{{ end }}<|im_start|>assistant
|
{{ end }}<|im_start|>assistant
|
||||||
{{ .Response }}<|im_end|>
|
{{ .Response }}<|im_end|>
|
||||||
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
|
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
|
||||||
|
{"{{ json .Messages }}", []string{"messages"}},
|
||||||
|
// undefined functions should not error
|
||||||
|
{"{{ undefined }}", []string{"response"}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
@ -196,6 +199,30 @@ func TestParse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseExecute(t *testing.T) {
|
||||||
|
t.Run("undefined function", func(t *testing.T) {
|
||||||
|
tmpl, err := Parse(`{{- if .Suffix }}{{ .Prompt }} {{ .Suffix }}{{- else }}{{ undefined }}{{- end }}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&b, Values{Prompt: "def add(", Suffix: " return c"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(b.String(), "def add( return c"); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tmpl.Execute(io.Discard, Values{}); err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
} else if !strings.Contains(err.Error(), "\"undefined\" is not a defined function") {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestExecuteWithMessages(t *testing.T) {
|
func TestExecuteWithMessages(t *testing.T) {
|
||||||
type template struct {
|
type template struct {
|
||||||
name string
|
name string
|
||||||
|
Loading…
x
Reference in New Issue
Block a user