forked from third-party-mirrors/ollama
add pixtral
This commit is contained in:
parent
685125ab03
commit
016a3df89d
68
models/pixtral/imageproc.go
Normal file
68
models/pixtral/imageproc.go
Normal file
@ -0,0 +1,68 @@
|
||||
package pixtral
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/models/imageproc"
|
||||
)
|
||||
|
||||
func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
|
||||
return image.Point{
|
||||
(imageSize.X-1)/patchSize.X + 1,
|
||||
(imageSize.Y-1)/patchSize.Y + 1,
|
||||
}
|
||||
}
|
||||
|
||||
func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
|
||||
b := img.Bounds()
|
||||
le := float64(longestEdge)
|
||||
ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le)
|
||||
|
||||
newSize := img.Bounds().Max
|
||||
|
||||
if ratio > 1.0 {
|
||||
newSize = image.Point{
|
||||
int(math.Ceil(float64(b.Max.X) / ratio)),
|
||||
int(math.Ceil(float64(b.Max.Y) / ratio)),
|
||||
}
|
||||
}
|
||||
|
||||
tokens := getNumImageTokens(newSize, patchSize)
|
||||
return image.Point{
|
||||
tokens.X * patchSize.X,
|
||||
tokens.Y * patchSize.Y,
|
||||
}
|
||||
}
|
||||
|
||||
func resizeImage(img image.Image, format string, longestEdge int, patchSize image.Point) image.Image {
|
||||
if format == "png" {
|
||||
img = imageproc.Composite(img)
|
||||
}
|
||||
|
||||
newSize := getResizeOutputImageSize(img, longestEdge, patchSize)
|
||||
|
||||
// todo should be ResizeBicubic, but it doesn't exist
|
||||
return imageproc.Resize(img, newSize, imageproc.ResizeBilinear)
|
||||
}
|
||||
|
||||
func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
|
||||
img, format, err := image.Decode(imageData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
longestEdge := 1024
|
||||
patchSize := image.Point{16, 16}
|
||||
|
||||
img = resizeImage(img, format, longestEdge, patchSize)
|
||||
|
||||
data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
|
||||
|
||||
opts := map[string]any{}
|
||||
return data, opts, nil
|
||||
}
|
190
models/pixtral/imageproc_test.go
Normal file
190
models/pixtral/imageproc_test.go
Normal file
@ -0,0 +1,190 @@
|
||||
package pixtral
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/png"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestGetNumImageTokens(t *testing.T) {
|
||||
type numImageTokensCase struct {
|
||||
ImageSize image.Point
|
||||
PatchSize image.Point
|
||||
Expected image.Point
|
||||
}
|
||||
|
||||
cases := []numImageTokensCase{
|
||||
{
|
||||
ImageSize: image.Point{1024, 764},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{64, 48},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{800, 600},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{50, 38},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{640, 480},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{40, 30},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{320, 200},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{20, 13},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{1320, 200},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{83, 13},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{2000, 200},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{125, 13},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{10000, 200},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{625, 13},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{1131, 577},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{71, 37},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{16, 16},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{1, 1},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
actual := getNumImageTokens(c.ImageSize, c.PatchSize)
|
||||
|
||||
if diff := cmp.Diff(actual, c.Expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetResizeOutputImageSize(t *testing.T) {
|
||||
type resizeCase struct {
|
||||
Image image.Image
|
||||
LongestEdge int
|
||||
PatchSize image.Point
|
||||
Expected image.Point
|
||||
}
|
||||
|
||||
cases := []resizeCase{
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{1024, 768},
|
||||
},
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 1162, 690)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{1024, 624},
|
||||
},
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 300, 200)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{304, 208},
|
||||
},
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{1024, 288},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize)
|
||||
|
||||
if diff := cmp.Diff(actual, c.Expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResize(t *testing.T) {
|
||||
type resizeCase struct {
|
||||
Image image.Image
|
||||
LongestEdge int
|
||||
PatchSize image.Point
|
||||
Expected image.Image
|
||||
}
|
||||
|
||||
cases := []resizeCase{
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.NewRGBA(image.Rect(0, 0, 1024, 288)),
|
||||
},
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 10, 10)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.NewRGBA(image.Rect(0, 0, 16, 16)),
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize)
|
||||
|
||||
if actual.Bounds() != c.Expected.Bounds() {
|
||||
t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreprocess(t *testing.T) {
|
||||
type preprocessCase struct {
|
||||
TestImage image.Image
|
||||
ExpectedLen int
|
||||
}
|
||||
|
||||
cases := []preprocessCase{
|
||||
{
|
||||
TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)),
|
||||
ExpectedLen: 16 * 16 * 3 * 1,
|
||||
},
|
||||
{
|
||||
TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
|
||||
ExpectedLen: 1024 * 1024 * 3 * 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
var buf bytes.Buffer
|
||||
err := png.Encode(&buf, c.TestImage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
imgData, _, err := Preprocess(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("error processing: %q", err)
|
||||
}
|
||||
|
||||
switch len(imgData) {
|
||||
case 0:
|
||||
t.Errorf("no image data returned")
|
||||
case c.ExpectedLen:
|
||||
// ok
|
||||
default:
|
||||
t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user