diff --git a/go.mod b/go.mod index 6e437c73..1e61d3cb 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/mattn/go-runewidth v0.0.14 github.com/nlpodyssey/gopickle v0.3.0 github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c + golang.org/x/image v0.14.0 ) require ( diff --git a/go.sum b/go.sum index 926ed26d..d4d1c9a9 100644 --- a/go.sum +++ b/go.sum @@ -230,6 +230,8 @@ golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+o golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= +golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= diff --git a/llm/server.go b/llm/server.go index 6c504f14..f8b81c6d 100644 --- a/llm/server.go +++ b/llm/server.go @@ -673,8 +673,10 @@ ws ::= ([ \t\n] ws)? const maxBufferSize = 512 * format.KiloByte type ImageData struct { - Data []byte `json:"data"` - ID int `json:"id"` + Data []byte `json:"data"` + ID int `json:"id"` + ImageData []float32 `json:"image_data"` + AspectRatioID int `json:"aspect_ratio_id"` } type completion struct { diff --git a/server/imageproc/images.go b/server/imageproc/images.go new file mode 100644 index 00000000..652aab75 --- /dev/null +++ b/server/imageproc/images.go @@ -0,0 +1,238 @@ +package imageproc + +import ( + "bytes" + "fmt" + "image" + _ "image/jpeg" + _ "image/png" + "math" + + "golang.org/x/image/draw" +) + +func GetSupportedAspectRatios(maxTiles int) []image.Point { + ratios := []image.Point{} + + for w := range maxTiles { + for h := range maxTiles { + if (w+1)*(h+1) <= maxTiles { + ratios = append(ratios, image.Point{w + 1, h + 1}) + } + } + } + + return ratios +} + +func clip(a, a_min, a_max int) int { + if a < a_min { + return a_min + } else if a > a_max { + return a_max + } + + return a +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func GetImageSizeFitToCanvas(imageSize, canvasSize image.Point, tileSize int) image.Point { + targetWidth := clip(imageSize.X, tileSize, canvasSize.X) + targetHeight := clip(imageSize.Y, tileSize, canvasSize.Y) + + scaleWidth := float64(targetWidth) / float64(imageSize.X) + scaleHeight := float64(targetHeight) / float64(imageSize.Y) + + var w, h int + + if scaleWidth < scaleHeight { + w = targetWidth + h = min(int(math.Floor(float64(imageSize.Y)*scaleWidth)), targetHeight) + } else { + w = min(int(math.Floor(float64(imageSize.X)*scaleHeight)), targetWidth) + h = targetHeight + } + + return image.Point{w, h} +} + +func GetOptimalTiledCanvas(imageSize image.Point, maxImageTiles, tileSize int) image.Point { + possibleTileArrangements := GetSupportedAspectRatios(maxImageTiles) + possibleCanvasSizes := []image.Point{} + for _, pta := range possibleTileArrangements { + possibleCanvasSizes = append(possibleCanvasSizes, image.Point{pta.X * tileSize, pta.Y * tileSize}) + } + + scales := []float64{} + + for _, pcs := range possibleCanvasSizes { + scaleHeight := float64(pcs.Y) / float64(imageSize.Y) + scaleWidth := float64(pcs.X) / float64(imageSize.X) + + if scaleWidth > scaleHeight { + scales = append(scales, scaleHeight) + } else { + scales = append(scales, scaleWidth) + } + } + + var minUpscale float64 + var maxDownscale float64 + var upscale bool + + for _, s := range scales { + if s > 1.0 { + upscale = true + if minUpscale == 0 { + minUpscale = s + } else { + minUpscale = math.Min(minUpscale, s) + } + } else { + maxDownscale = math.Max(maxDownscale, s) + } + } + + selectedScale := maxDownscale + if upscale { + selectedScale = minUpscale + } + + selectedCanvas := possibleCanvasSizes[0] + for n, pcs := range possibleCanvasSizes { + if scales[n] == selectedScale { + // choose the largest possible canvas + if pcs.X*pcs.Y > selectedCanvas.X*selectedCanvas.Y { + selectedCanvas = pcs + } + } + } + return selectedCanvas +} + +func SplitToTiles(img image.Image, numTilesSize image.Point) []image.Image { + b := img.Bounds() + width := b.Max.X - b.Min.X + height := b.Max.Y - b.Min.Y + tileHeight := height / numTilesSize.Y + tileWidth := width / numTilesSize.X + + images := []image.Image{} + + for h := range numTilesSize.Y { + for w := range numTilesSize.X { + rect := image.Rect(tileWidth*w, tileHeight*h, tileWidth*(w+1), tileHeight*(h+1)) + images = append(images, img.(interface { + SubImage(image.Rectangle) image.Image + }).SubImage(rect)) + } + } + + return images +} + +func ResizeImage(img image.Image, outputSize image.Point, maxImageTiles int) (image.Image, image.Point) { + b := img.Bounds() + tileSize := outputSize.Y + + canvasSize := GetOptimalTiledCanvas(b.Max, maxImageTiles, tileSize) + aspectRatio := image.Point{canvasSize.X / tileSize, canvasSize.Y / tileSize} + + newSize := GetImageSizeFitToCanvas(b.Max, canvasSize, tileSize) + + dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y)) + draw.ApproxBiLinear.Scale(dst, dst.Rect, img, b, draw.Over, nil) + + return dst, aspectRatio +} + +func PadImage(img image.Image, outputSize, aspectRatio image.Point) image.Image { + paddedSize := image.Point{ + X: outputSize.X * aspectRatio.X, + Y: outputSize.Y * aspectRatio.Y, + } + + dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y)) + centerX := (paddedSize.X - img.Bounds().Max.X) / 2 + centerY := (paddedSize.Y - img.Bounds().Max.Y) / 2 + pos := image.Rect(centerX, centerY, centerX+img.Bounds().Max.X, centerY+img.Bounds().Max.Y) + + draw.Draw(dst, pos, img, image.Point{0, 0}, draw.Over) + + return dst +} + +func PackImages(img image.Image, aspectRatio image.Point, mean, std [3]float32) []float32 { + subImages := SplitToTiles(img, aspectRatio) + + var pixelVals []float32 + + for _, subImg := range subImages { + bounds := subImg.Bounds() + rVals := []float32{} + gVals := []float32{} + bVals := []float32{} + for y := bounds.Min.Y; y < bounds.Max.Y; y++ { + for x := bounds.Min.X; x < bounds.Max.X; x++ { + c := subImg.At(x, y) + r, g, b, _ := c.RGBA() + rVal := float32(r>>8) / 255.0 + gVal := float32(g>>8) / 255.0 + bVal := float32(b>>8) / 255.0 + + rVal = (rVal - mean[0]) / std[0] + gVal = (gVal - mean[1]) / std[1] + bVal = (bVal - mean[2]) / std[2] + + rVals = append(rVals, rVal) + gVals = append(gVals, gVal) + bVals = append(bVals, bVal) + } + } + pixelVals = append(pixelVals, rVals...) + pixelVals = append(pixelVals, gVals...) + pixelVals = append(pixelVals, bVals...) + } + + return pixelVals +} + +func Preprocess(imageData []byte) ([]float32, int, error) { + // todo: need guard in here for bad image data + + // mllama values + outputSize := image.Point{560, 560} + maxTiles := 4 + + // clip values + mean := [3]float32{0.48145466, 0.4578275, 0.40821073} + std := [3]float32{0.26862954, 0.26130258, 0.27577711} + + img, _, err := image.Decode(bytes.NewReader(imageData)) + if err != nil { + return nil, 0, fmt.Errorf("failed to decode image: %w", err) + } + + newImage, aspectRatio := ResizeImage(img, outputSize, maxTiles) + newImage = PadImage(newImage, outputSize, aspectRatio) + + // todo: need to scale (dim) by 1/256 + + data := PackImages(newImage, aspectRatio, mean, std) + supportedRatios := GetSupportedAspectRatios(maxTiles) + var aspectRatioIndex int + for n, r := range supportedRatios { + if r == aspectRatio { + aspectRatioIndex = n+1 + break + } + } + + return data, aspectRatioIndex, nil +} diff --git a/server/imageproc/images_test.go b/server/imageproc/images_test.go new file mode 100644 index 00000000..ce30cfde --- /dev/null +++ b/server/imageproc/images_test.go @@ -0,0 +1,305 @@ +package imageproc + +import ( + "image" + "reflect" + "testing" +) + +func testEq(a, b any) bool { + va := reflect.ValueOf(a) + vb := reflect.ValueOf(b) + + if va.Kind() != reflect.Slice || vb.Kind() != reflect.Slice { + return false + } + + if va.Len() != vb.Len() { + return false + } + + for i := range va.Len() { + if !reflect.DeepEqual(va.Index(i).Interface(), vb.Index(i).Interface()) { + return false + } + } + return true +} + +func TestAspectRatios(t *testing.T) { + type AspectCase struct { + MaxTiles int + Expected []image.Point + } + + cases := []AspectCase{ + { + MaxTiles: 1, + Expected: []image.Point{{1, 1}}, + }, + { + MaxTiles: 2, + Expected: []image.Point{{1, 1}, {1, 2}, {2, 1}}, + }, + { + MaxTiles: 3, + Expected: []image.Point{{1, 1}, {1, 2}, {1, 3}, {2, 1}, {3, 1}}, + }, + { + MaxTiles: 4, + Expected: []image.Point{{1, 1}, {1, 2}, {1, 3}, {1, 4}, {2, 1}, {2, 2}, {3, 1}, {4, 1}}, + }, + } + + for _, c := range cases { + actual := GetSupportedAspectRatios(c.MaxTiles) + + if !testEq(actual, c.Expected) { + t.Errorf("incorrect aspect ratio: '%#v'. expected: '%#v'", actual, c.Expected) + } + } +} + +func TestGetImageSizeFitToCanvas(t *testing.T) { + type ImageSizeCase struct { + ImageRect image.Point + CanvasRect image.Point + TileSize int + Expected image.Point + } + + cases := []ImageSizeCase{ + { + ImageRect: image.Point{400, 400}, + CanvasRect: image.Point{640, 480}, + TileSize: 200, + Expected: image.Point{400, 400}, + }, + { + ImageRect: image.Point{1024, 768}, + CanvasRect: image.Point{640, 480}, + TileSize: 200, + Expected: image.Point{640, 480}, + }, + { + ImageRect: image.Point{500, 500}, + CanvasRect: image.Point{1000, 1000}, + TileSize: 750, + Expected: image.Point{750, 750}, + }, + { + ImageRect: image.Point{500, 1000}, + CanvasRect: image.Point{2000, 2000}, + TileSize: 2000, + Expected: image.Point{1000, 2000}, + }, + { + ImageRect: image.Point{4000, 3000}, + CanvasRect: image.Point{2000, 1000}, + TileSize: 1000, + Expected: image.Point{1333, 1000}, + }, + { + ImageRect: image.Point{667, 1000}, + CanvasRect: image.Point{1000, 1000}, + TileSize: 560, + Expected: image.Point{667, 1000}, + }, + } + + for _, c := range cases { + actual := GetImageSizeFitToCanvas(c.ImageRect, c.CanvasRect, c.TileSize) + + if actual != c.Expected { + t.Errorf("incorrect image rect: '%#v'. expected: '%#v'", actual, c.Expected) + } + } +} + +func TestGetOptimalTiledCanvas(t *testing.T) { + type TiledCanvasSizeCase struct { + ImageSize image.Point + MaxImageTiles int + TileSize int + Expected image.Point + } + + cases := []TiledCanvasSizeCase{ + { + ImageSize: image.Point{1024, 768}, + MaxImageTiles: 4, + TileSize: 1000, + Expected: image.Point{4000, 1000}, + }, + { + ImageSize: image.Point{1024, 768}, + MaxImageTiles: 4, + TileSize: 560, + Expected: image.Point{1120, 1120}, + }, + } + + for _, c := range cases { + actual := GetOptimalTiledCanvas(c.ImageSize, c.MaxImageTiles, c.TileSize) + + if actual != c.Expected { + t.Errorf("incorrect tiled canvas: '%#v'. expected: '%#v'", actual, c.Expected) + } + } +} + +func TestSplitToTiles(t *testing.T) { + type SplitCase struct { + TestImage image.Image + NumTilesSize image.Point + Expected []image.Image + } + + cases := []SplitCase{ + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)), + NumTilesSize: image.Point{1, 1}, + Expected: []image.Image{image.NewRGBA(image.Rect(0, 0, 1024, 768))}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1000, 500)), + NumTilesSize: image.Point{2, 1}, + Expected: []image.Image{ + image.NewRGBA(image.Rect(0, 0, 500, 500)), + image.NewRGBA(image.Rect(500, 0, 1000, 500)), + }, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1000, 1000)), + NumTilesSize: image.Point{2, 2}, + Expected: []image.Image{ + image.NewRGBA(image.Rect(0, 0, 500, 500)), + image.NewRGBA(image.Rect(500, 0, 1000, 500)), + image.NewRGBA(image.Rect(0, 500, 500, 1000)), + image.NewRGBA(image.Rect(500, 500, 1000, 1000)), + }, + }, + } + + for _, c := range cases { + actual := SplitToTiles(c.TestImage, c.NumTilesSize) + + if len(actual) != len(c.Expected) { + t.Errorf("incorrect number of images '%d': expected: '%d'", len(actual), len(c.Expected)) + } + + for i := range actual { + if actual[i].Bounds() != c.Expected[i].Bounds() { + t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual[i].Bounds(), c.Expected[i].Bounds()) + } + } + } +} + +func TestResize(t *testing.T) { + type ResizeCase struct { + TestImage image.Image + OutputSize image.Point + MaxImageTiles int + ExpectedImage image.Image + ExpectedAspectRatio image.Point + } + + cases := []ResizeCase{ + { + TestImage: image.NewRGBA(image.Rect(0, 0, 200, 200)), + OutputSize: image.Point{100, 100}, + MaxImageTiles: 1, + ExpectedImage: image.NewRGBA(image.Rect(0, 0, 100, 100)), + ExpectedAspectRatio: image.Point{1, 1}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 200, 200)), + OutputSize: image.Point{100, 100}, + MaxImageTiles: 2, + ExpectedImage: image.NewRGBA(image.Rect(0, 0, 100, 100)), + ExpectedAspectRatio: image.Point{1, 2}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 2560, 1920)), + OutputSize: image.Point{560, 560}, + MaxImageTiles: 4, + ExpectedImage: image.NewRGBA(image.Rect(0, 0, 1120, 840)), + ExpectedAspectRatio: image.Point{2, 2}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)), + OutputSize: image.Point{560, 560}, + MaxImageTiles: 4, + ExpectedImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)), + ExpectedAspectRatio: image.Point{2, 2}, + }, + } + + for _, c := range cases { + actualImage, actualAspectRatio := ResizeImage(c.TestImage, c.OutputSize, c.MaxImageTiles) + + if actualImage.Bounds() != c.ExpectedImage.Bounds() { + t.Errorf("image size incorrect: '%#v': expected: '%#v'", actualImage.Bounds(), c.ExpectedImage.Bounds()) + } + + if actualAspectRatio != c.ExpectedAspectRatio { + t.Errorf("canvas size incorrect: '%#v': expected: '%#v'", actualAspectRatio, c.ExpectedAspectRatio) + } + } +} + +func TestPad(t *testing.T) { + type PadCase struct { + TestImage image.Image + OutputSize image.Point + AspectRatio image.Point + Expected image.Image + } + + cases := []PadCase{ + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1000, 667)), + OutputSize: image.Point{560, 560}, + AspectRatio: image.Point{2, 2}, + Expected: image.NewRGBA(image.Rect(0, 0, 1120, 1120)), + }, + } + + for _, c := range cases { + actual := PadImage(c.TestImage, c.OutputSize, c.AspectRatio) + + if actual.Bounds() != c.Expected.Bounds() { + t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds()) + } + } +} + +func TestPackImages(t *testing.T) { + type PackCase struct { + TestImage image.Image + AspectRatio image.Point + } + + mean := [3]float32{0.48145466, 0.4578275, 0.40821073} + std := [3]float32{0.26862954, 0.26130258, 0.27577711} + + cases := []PackCase{ + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1120, 1120)), + AspectRatio: image.Point{2, 2}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 560, 560)), + AspectRatio: image.Point{1, 1}, + }, + { + TestImage: image.NewRGBA(image.Rect(0, 0, 1120, 560)), + AspectRatio: image.Point{1, 2}, + }, + } + + for _, c := range cases { + PackImages(c.TestImage, c.AspectRatio, mean, std) + } +} diff --git a/server/prompt.go b/server/prompt.go index be0d4969..88393798 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -7,6 +7,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/server/imageproc" "github.com/ollama/ollama/template" ) @@ -61,14 +62,37 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. return "", nil, err } + preprocess := checkMllamaModelFamily(m) + for _, m := range msgs[n:] { for _, i := range m.Images { - images = append(images, llm.ImageData{ - ID: len(images), - Data: i, - }) + if preprocess { + data, aspectRatioID, err := imageproc.Preprocess(i) + if err != nil { + return "", nil, err + } + images = append(images, llm.ImageData{ + ID: len(images), + ImageData: data, + AspectRatioID: aspectRatioID, + }) + } else { + images = append(images, llm.ImageData{ + ID: len(images), + Data: i, + }) + } } } return b.String(), images, nil } + +func checkMllamaModelFamily(m *Model) bool { + for _, arch := range m.Config.ModelFamilies { + if arch == "mllama" { + return true + } + } + return false +} diff --git a/server/prompt_test.go b/server/prompt_test.go index 5fe3d4c5..26a20027 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -203,7 +203,7 @@ func TestChatPrompt(t *testing.T) { } if !bytes.Equal(images[i].Data, tt.images[i]) { - t.Errorf("expected %q, got %q", tt.images[i], images[i]) + t.Errorf("expected %q, got %q", tt.images[i], images[i].Data) } } })