only allow a single image to be passed

This commit is contained in:
Patrick Devine 2024-10-08 18:30:07 -07:00
parent 03cf7627ec
commit 3a1c8da5e4
2 changed files with 82 additions and 24 deletions

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"log/slog"
"strings"
@ -16,16 +17,28 @@ import (
type tokenizeFunc func(context.Context, string) ([]int, error)
var errTooManyImages = errors.New("vision model only supports a single image per message")
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
// always include the last message
isMllama := checkMllamaModelFamily(m)
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
for i := n; i >= 0; i-- {
if isMllama && len(msgs[i].Images) > 1 {
return "", nil, errTooManyImages
}
// always include the last message
if i == n {
continue
}
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
@ -62,27 +75,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
currMsgIdx := n
if checkMllamaModelFamily(m) {
if isMllama {
lastMsgIdx := len(msgs) - 1
if len(msgs[lastMsgIdx].Images) == 1 {
data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0])
if err != nil {
return "", nil, err
}
for i := lastMsgIdx; i > currMsgIdx; i-- {
if len(msgs[i].Images) > 0 {
data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
if err != nil {
return "", nil, err
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.LittleEndian, data)
if err != nil {
return "", nil, err
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.LittleEndian, data)
if err != nil {
return "", nil, err
}
imgData := llm.ImageData{
Data: buf.Bytes(),
AspectRatioID: aspectRatioID,
}
imgData := llm.ImageData{
Data: buf.Bytes(),
AspectRatioID: aspectRatioID,
}
msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content)
images = append(images, imgData)
msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content)
images = append(images, imgData)
break
}
}
} else {
for cnt, msg := range msgs[currMsgIdx:] {

View File

@ -18,6 +18,7 @@ func TestChatPrompt(t *testing.T) {
prompt string
images [][]byte
aspectRatioID int
error error
}
tmpl, err := template.Parse(`
@ -30,15 +31,26 @@ func TestChatPrompt(t *testing.T) {
visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}}
img := image.NewRGBA(image.Rect(0, 0, 5, 5))
var buf bytes.Buffer
createImg := func(width, height int) ([]byte, error) {
img := image.NewRGBA(image.Rect(0, 0, 5, 5))
var buf bytes.Buffer
err = png.Encode(&buf, img)
if err := png.Encode(&buf, img); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
imgBuf, err := createImg(5, 5)
if err != nil {
t.Fatal(err)
}
imgBuf := buf.Bytes()
imgBuf2, err := createImg(6, 6)
if err != nil {
t.Fatal(err)
}
cases := []struct {
name string
@ -232,6 +244,34 @@ func TestChatPrompt(t *testing.T) {
aspectRatioID: 1,
},
},
{
name: "multiple messages with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{imgBuf}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf2},
aspectRatioID: 1,
},
},
{
name: "too many images with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf, imgBuf}},
},
expect: expect{
error: errTooManyImages,
},
},
}
for _, tt := range cases {
@ -239,8 +279,10 @@ func TestChatPrompt(t *testing.T) {
model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
if err != nil {
if tt.error == nil && err != nil {
t.Fatal(err)
} else if tt.error != nil && err != tt.error {
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
}
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {