only allow a single image to be passed

This commit is contained in:
parent 03cf7627ec
commit 3a1c8da5e4
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"log/slog"
 	"strings"
@@ -16,16 +17,28 @@ import (
 
 type tokenizeFunc func(context.Context, string) ([]int, error)
 
+var errTooManyImages = errors.New("vision model only supports a single image per message")
+
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
 func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 
-	// always include the last message
+	isMllama := checkMllamaModelFamily(m)
+
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
-	for i := n - 1; i >= 0; i-- {
+	for i := n; i >= 0; i-- {
+		if isMllama && len(msgs[i].Images) > 1 {
+			return "", nil, errTooManyImages
+		}
+
+		// always include the last message
+		if i == n {
+			continue
+		}
+
 		system = make([]api.Message, 0)
 		for j := range i {
 			if msgs[j].Role == "system" {
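
Taken together, the hunks above move the "always include the last message" rule inside the truncation loop and add an up-front guard: while walking the messages in reverse, any message carrying more than one image on an mllama model aborts the whole call with errTooManyImages. A rough standalone sketch of that guard, with an illustrative message type standing in for api.Message (not the server's real definitions):

package main

import (
	"errors"
	"fmt"
)

// message is an illustrative stand-in for api.Message.
type message struct {
	Role   string
	Images [][]byte
}

var errTooManyImages = errors.New("vision model only supports a single image per message")

// validateImages mirrors the guard added above: walk the messages in
// reverse and reject any that carry more than one image.
func validateImages(msgs []message) error {
	for i := len(msgs) - 1; i >= 0; i-- {
		if len(msgs[i].Images) > 1 {
			return errTooManyImages
		}
	}
	return nil
}

func main() {
	msgs := []message{{Role: "user", Images: [][]byte{{0x01}, {0x02}}}}
	fmt.Println(validateImages(msgs))
}
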
@@ -62,27 +75,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 	currMsgIdx := n
 
-	if checkMllamaModelFamily(m) {
+	if isMllama {
 		lastMsgIdx := len(msgs) - 1
-		if len(msgs[lastMsgIdx].Images) == 1 {
-			data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0])
-			if err != nil {
-				return "", nil, err
-			}
+		for i := lastMsgIdx; i > currMsgIdx; i-- {
+			if len(msgs[i].Images) > 0 {
+				data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
+				if err != nil {
+					return "", nil, err
+				}
 
-			buf := new(bytes.Buffer)
-			err = binary.Write(buf, binary.LittleEndian, data)
-			if err != nil {
-				return "", nil, err
-			}
+				buf := new(bytes.Buffer)
+				err = binary.Write(buf, binary.LittleEndian, data)
+				if err != nil {
+					return "", nil, err
+				}
 
-			imgData := llm.ImageData{
-				Data:          buf.Bytes(),
-				AspectRatioID: aspectRatioID,
-			}
+				imgData := llm.ImageData{
+					Data:          buf.Bytes(),
+					AspectRatioID: aspectRatioID,
+				}
 
-			msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content)
-			images = append(images, imgData)
+				msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content)
+				images = append(images, imgData)
+				break
+			}
 		}
 	} else {
 		for cnt, msg := range msgs[currMsgIdx:] {
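
Where the old code only looked at the final message, the rewritten branch scans backwards from the last message, preprocesses the most recent image it finds, and stops at the first hit via break, so exactly one image reaches the runner. The serialization step in the middle is ordinary binary.Write of the preprocessed pixel data into a buffer; a minimal sketch of just that step, with a made-up float32 slice standing in for the output of imageproc.Preprocess:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	// made-up stand-in for the tensor returned by imageproc.Preprocess
	data := []float32{0.1, 0.2, 0.3}

	// binary.Write encodes fixed-size values (including slices of them)
	// into the buffer in little-endian byte order
	buf := new(bytes.Buffer)
	if err := binary.Write(buf, binary.LittleEndian, data); err != nil {
		panic(err)
	}

	fmt.Println(buf.Len()) // 12 bytes: three float32s at four bytes each
}
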
@@ -18,6 +18,7 @@ func TestChatPrompt(t *testing.T) {
 		prompt        string
 		images        [][]byte
 		aspectRatioID int
+		error         error
 	}
 
 	tmpl, err := template.Parse(`
@@ -30,15 +31,26 @@ func TestChatPrompt(t *testing.T) {
 	visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 	mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}}
 
-	img := image.NewRGBA(image.Rect(0, 0, 5, 5))
-	var buf bytes.Buffer
+	createImg := func(width, height int) ([]byte, error) {
+		img := image.NewRGBA(image.Rect(0, 0, width, height))
+		var buf bytes.Buffer
 
-	err = png.Encode(&buf, img)
+		if err := png.Encode(&buf, img); err != nil {
+			return nil, err
+		}
+
+		return buf.Bytes(), nil
+	}
+
+	imgBuf, err := createImg(5, 5)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	imgBuf := buf.Bytes()
+	imgBuf2, err := createImg(6, 6)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 	cases := []struct {
 		name          string
@@ -232,6 +244,34 @@ func TestChatPrompt(t *testing.T) {
 				aspectRatioID: 1,
 			},
 		},
+		{
+			name:  "multiple messages with mllama",
+			model: mllamaModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{imgBuf}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
+			},
+			expect: expect{
+				prompt:        "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
+				images:        [][]byte{imgBuf2},
+				aspectRatioID: 1,
+			},
+		},
+		{
+			name:  "too many images with mllama",
+			model: mllamaModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf, imgBuf}},
+			},
+			expect: expect{
+				error: errTooManyImages,
+			},
+		},
 	}
 
 	for _, tt := range cases {
@@ -239,8 +279,10 @@ func TestChatPrompt(t *testing.T) {
 			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 			prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
-			if err != nil {
+			if tt.error == nil && err != nil {
 				t.Fatal(err)
+			} else if tt.error != nil && err != tt.error {
+				t.Fatalf("expected err '%q', got '%q'", tt.error, err)
 			}
 
 			if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
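
The reworked assertion distinguishes expected from unexpected errors: unexpected errors still fail fast, while an expected sentinel must match exactly under ==. That holds as long as chatPrompt returns errTooManyImages unwrapped; if the sentinel were ever wrapped with %w, a variant using errors.Is would be needed (a hypothetical adjustment, not part of this commit):

// hypothetical variant, not in this commit: errors.Is also matches
// the sentinel when it has been wrapped with fmt.Errorf("...: %w", err)
if tt.error == nil && err != nil {
	t.Fatal(err)
} else if tt.error != nil && !errors.Is(err, tt.error) {
	t.Fatalf("expected err %q, got %q", tt.error, err)
}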