From 016a3df89dd30691d44ec5ab29c1a35193d16aae Mon Sep 17 00:00:00 2001
From: Patrick Devine
Date: Tue, 12 Nov 2024 18:26:48 -0800
Subject: [PATCH] add pixtral

---
 models/pixtral/imageproc.go      |  81 +++++++++++++
 models/pixtral/imageproc_test.go | 201 +++++++++++++++++++++++++++++++
 2 files changed, 282 insertions(+)
 create mode 100644 models/pixtral/imageproc.go
 create mode 100644 models/pixtral/imageproc_test.go

diff --git a/models/pixtral/imageproc.go b/models/pixtral/imageproc.go
new file mode 100644
index 00000000..0c5187ae
--- /dev/null
+++ b/models/pixtral/imageproc.go
@@ -0,0 +1,81 @@
+package pixtral
+
+import (
+	"fmt"
+	"image"
+	_ "image/jpeg"
+	_ "image/png"
+	"io"
+	"math"
+
+	"github.com/ollama/ollama/models/imageproc"
+)
+
+// getNumImageTokens returns the number of patches of size patchSize needed to
+// cover imageSize along each axis; a partially covered patch counts as a whole
+// one.
+func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
+	return image.Point{
+		(imageSize.X-1)/patchSize.X + 1,
+		(imageSize.Y-1)/patchSize.Y + 1,
+	}
+}
+
+// getResizeOutputImageSize returns the size img should be resized to. If either
+// side is longer than longestEdge, the size is scaled down (keeping the aspect
+// ratio) so that the longer side fits; smaller images are not scaled. Each side
+// is then rounded up to a whole multiple of patchSize.
+func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
+	// Measure the extent of the bounds rather than reading Bounds().Max: an
+	// image's bounds do not necessarily start at (0, 0), e.g. for sub-images.
+	newSize := img.Bounds().Size()
+	le := float64(longestEdge)
+	ratio := math.Max(float64(newSize.Y)/le, float64(newSize.X)/le)
+
+	if ratio > 1.0 {
+		newSize = image.Point{
+			int(math.Ceil(float64(newSize.X) / ratio)),
+			int(math.Ceil(float64(newSize.Y) / ratio)),
+		}
+	}
+
+	tokens := getNumImageTokens(newSize, patchSize)
+	return image.Point{
+		tokens.X * patchSize.X,
+		tokens.Y * patchSize.Y,
+	}
+}
+
+// resizeImage resizes img to the size computed by getResizeOutputImageSize.
+// Images decoded as "png" are first passed through imageproc.Composite.
+func resizeImage(img image.Image, format string, longestEdge int, patchSize image.Point) image.Image {
+	if format == "png" {
+		img = imageproc.Composite(img)
+	}
+
+	newSize := getResizeOutputImageSize(img, longestEdge, patchSize)
+
+	// todo should be ResizeBicubic, but it doesn't exist
+	return imageproc.Resize(img, newSize, imageproc.ResizeBilinear)
+}
+
+// Preprocess decodes an image from imageData (this package registers the JPEG
+// and PNG decoders), resizes it for a longest edge of 1024 and a 16x16 patch
+// size, and returns the output of imageproc.Normalize with the CLIP default
+// mean and standard deviation. The returned options map is always empty.
+func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
+	img, format, err := image.Decode(imageData)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to decode image: %w", err)
+	}
+
+	longestEdge := 1024
+	patchSize := image.Point{16, 16}
+
+	img = resizeImage(img, format, longestEdge, patchSize)
+
+	data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
+
+	opts := map[string]any{}
+	return data, opts, nil
+}
diff --git a/models/pixtral/imageproc_test.go b/models/pixtral/imageproc_test.go
new file mode 100644
index 00000000..698a56e4
--- /dev/null
+++ b/models/pixtral/imageproc_test.go
@@ -0,0 +1,201 @@
+package pixtral
+
+import (
+	"bytes"
+	"image"
+	"image/png"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+// TestGetNumImageTokens checks the per-axis patch counts for a range of image sizes.
+func TestGetNumImageTokens(t *testing.T) {
+	type numImageTokensCase struct {
+		ImageSize image.Point
+		PatchSize image.Point
+		Expected  image.Point
+	}
+
+	cases := []numImageTokensCase{
+		{
+			ImageSize: image.Point{1024, 764},
+			PatchSize: image.Point{16, 16},
+			Expected:  image.Point{64, 48},
+		},
+		{
+			ImageSize: image.Point{800, 600},
+			PatchSize: image.Point{16, 16},
+			Expected:  image.Point{50, 38},
+		},
+		{
+			ImageSize: image.Point{640, 480},
+			PatchSize: image.Point{16, 16},
+			Expected:  image.Point{40, 30},
+		},
+		{
+			ImageSize: image.Point{320, 200},
+			PatchSize: image.Point{16, 16},
+			Expected:  image.Point{20, 13},
+		},
+		{
+			ImageSize: image.Point{1320, 200},
+			PatchSize: image.Point{16, 16},
+			Expected:  image.Point{83, 13},
+		},
+		{
+			ImageSize: image.Point{2000, 200},
+			PatchSize: image.Point{16, 16},
+			Expected:  image.Point{125, 13},
+		},
+		{
+			ImageSize: image.Point{10000, 200},
+			PatchSize: image.Point{16, 16},
+			Expected:  image.Point{625, 13},
+		},
+		{
+			ImageSize: image.Point{1131, 577},
+			PatchSize: image.Point{16, 16},
+			Expected:  image.Point{71, 37},
+		},
+		{
+			ImageSize: image.Point{16, 16},
+			PatchSize: image.Point{16, 16},
+			Expected:  image.Point{1, 1},
+		},
+	}
+
+	for _, c := range cases {
+		actual := getNumImageTokens(c.ImageSize, c.PatchSize)
+
+		if diff := cmp.Diff(actual, c.Expected); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	}
+}
+
+// TestGetResizeOutputImageSize checks the computed target size, including non-zero-origin bounds.
+func TestGetResizeOutputImageSize(t *testing.T) {
+	type resizeCase struct {
+		Image       image.Image
+		LongestEdge int
+		PatchSize   image.Point
+		Expected    image.Point
+	}
+
+	cases := []resizeCase{
+		{
+			Image:       image.NewRGBA(image.Rect(0, 0, 1024, 768)),
+			LongestEdge: 1024,
+			PatchSize:   image.Point{16, 16},
+			Expected:    image.Point{1024, 768},
+		},
+		{
+			Image:       image.NewRGBA(image.Rect(0, 0, 1162, 690)),
+			LongestEdge: 1024,
+			PatchSize:   image.Point{16, 16},
+			Expected:    image.Point{1024, 624},
+		},
+		{
+			Image:       image.NewRGBA(image.Rect(0, 0, 300, 200)),
+			LongestEdge: 1024,
+			PatchSize:   image.Point{16, 16},
+			Expected:    image.Point{304, 208},
+		},
+		{
+			Image:       image.NewRGBA(image.Rect(0, 0, 1862, 522)),
+			LongestEdge: 1024,
+			PatchSize:   image.Point{16, 16},
+			Expected:    image.Point{1024, 288},
+		},
+		{
+			// Bounds that do not start at (0, 0): only the 300x200 extent matters.
+			Image:       image.NewRGBA(image.Rect(100, 100, 400, 300)),
+			LongestEdge: 1024,
+			PatchSize:   image.Point{16, 16},
+			Expected:    image.Point{304, 208},
+		},
+	}
+
+	for _, c := range cases {
+		actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize)
+
+		if diff := cmp.Diff(actual, c.Expected); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	}
+}
+
+// TestResize checks that resizeImage produces an image with the expected bounds.
+func TestResize(t *testing.T) {
+	type resizeCase struct {
+		Image       image.Image
+		LongestEdge int
+		PatchSize   image.Point
+		Expected    image.Image
+	}
+
+	cases := []resizeCase{
+		{
+			Image:       image.NewRGBA(image.Rect(0, 0, 1862, 522)),
+			LongestEdge: 1024,
+			PatchSize:   image.Point{16, 16},
+			Expected:    image.NewRGBA(image.Rect(0, 0, 1024, 288)),
+		},
+		{
+			Image:       image.NewRGBA(image.Rect(0, 0, 10, 10)),
+			LongestEdge: 1024,
+			PatchSize:   image.Point{16, 16},
+			Expected:    image.NewRGBA(image.Rect(0, 0, 16, 16)),
+		},
+	}
+
+	for _, c := range cases {
+		actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize)
+
+		if actual.Bounds() != c.Expected.Bounds() {
+			t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
+		}
+	}
+}
+
+// TestPreprocess checks the length of the pixel data returned for encoded PNG input.
+func TestPreprocess(t *testing.T) {
+	type preprocessCase struct {
+		TestImage   image.Image
+		ExpectedLen int
+	}
+
+	cases := []preprocessCase{
+		{
+			TestImage:   image.NewRGBA(image.Rect(0, 0, 10, 10)),
+			ExpectedLen: 16 * 16 * 3 * 1,
+		},
+		{
+			TestImage:   image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
+			ExpectedLen: 1024 * 1024 * 3 * 1,
+		},
+	}
+
+	for _, c := range cases {
+		var buf bytes.Buffer
+		err := png.Encode(&buf, c.TestImage)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		imgData, _, err := Preprocess(&buf)
+		if err != nil {
+			t.Fatalf("error processing: %q", err)
+		}
+
+		switch len(imgData) {
+		case 0:
+			t.Errorf("no image data returned")
+		case c.ExpectedLen:
+			// ok
+		default:
+			t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen)
+		}
+	}
+}