ollama/x/encoding/gguf/gguf_test.go
Blake Mizerany adc23d5f96 Add 'x/' from commit 'a10a11b9d371f36b7c3510da32a1d70b74e27bd1'
git-subtree-dir: x
git-subtree-mainline: 7d05a6ee8f44b314fa697a427439e5fa4d78c3d7
git-subtree-split: a10a11b9d371f36b7c3510da32a1d70b74e27bd1
2024-04-03 10:40:23 -07:00

346 lines
8.4 KiB
Go

package gguf
import (
"errors"
"io"
"strings"
"testing"
"kr.dev/diff"
)
func TestStat(t *testing.T) {
cases := []struct {
name string
data string
wantInfo Info
wantErr error
}{
{
name: "empty",
wantErr: ErrBadMagic,
},
{
name: "bad magic",
data: "\xBB\xAA\xDD\x00",
wantErr: ErrBadMagic,
},
{
name: "bad version",
data: string(magicBytes) +
"\x02\x00\x00\x00" + // version
"",
wantErr: ErrUnsupportedVersion,
},
{
name: "valid general.file_type",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// general.file_type key
"\x11\x00\x00\x00\x00\x00\x00\x00" + // key length
"general.file_type" + // key
"\x04\x00\x00\x00" + // type (uint32)
"\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value
"",
wantInfo: Info{
Version: 3,
FileType: 1,
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
info, err := StatReader(strings.NewReader(tt.data))
if tt.wantErr != nil {
if !errors.Is(err, tt.wantErr) {
t.Fatalf("err = %v; want %q", err, tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
diff.Test(t, t.Errorf, info, tt.wantInfo)
})
}
}
func TestReadInfo(t *testing.T) {
cases := []struct {
name string
data string
wantMeta []MetaEntry
wantTensor []TensorInfo
wantReadErr error
wantMetaErr error
wantTensorErr error
wantInfo Info
}{
{
name: "empty",
wantReadErr: io.ErrUnexpectedEOF,
},
{
name: "bad magic",
data: "\xBB\xAA\xDD\x00",
wantReadErr: ErrBadMagic,
},
{
name: "bad version",
data: string(magicBytes) +
"\x02\x00\x00\x00" + // version
"",
wantReadErr: ErrUnsupportedVersion,
},
{
name: "no metadata or tensors",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"",
wantReadErr: nil,
},
{
name: "good metadata",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
"K" + // key
"\x08\x00\x00\x00" + // type (string)
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
"VV" + // string value
"",
wantMeta: []MetaEntry{
{Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}},
},
},
{
name: "good metadata with multiple values",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// MetaEntry 1
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
"x" + // key
"\x08\x00\x00\x00" + // type (string)
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
"XX" + // string value
// MetaEntry 2
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
"y" + // key
"\x04\x00\x00\x00" + // type (uint32)
"\x99\x88\x77\x66" + // uint32 value
"",
wantMeta: []MetaEntry{
{Key: "x", Type: ValueTypeString, Values: []MetaValue{{
Type: ValueTypeString,
Value: []byte("XX"),
}}},
{Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{
Type: ValueTypeUint32,
Value: []byte{0x99, 0x88, 0x77, 0x66},
}}},
},
},
{
name: "negative string length in meta key",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length
"K" + // key
"\x08\x00\x00\x00" + // type (string)
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
"VV" + // string value
"",
wantMetaErr: ErrMangled,
},
// Tensor tests
{
name: "good tensor",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// Tensor 1
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
"t" +
// dimensions
"\x01\x00\x00\x00" + // dimensions length
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
"\x03\x00\x00\x00" + // type (i8)
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
"",
wantTensor: []TensorInfo{
{
Name: "t",
Dimensions: []uint64{1},
Type: TypeQ4_1,
Offset: 256,
Size: 256,
},
},
},
{
name: "too many dimensions",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// Tensor 1
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
"t" +
"\x00\x00\x00\x01" + // dimensions length
"",
wantTensorErr: ErrMangled,
},
{
name: "size computed",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// Tensor 1
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
"t" +
"\x01\x00\x00\x00" + // dimensions length
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
"\x03\x00\x00\x00" + // type (i8)
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
// Tensor 2
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
"t" +
"\x01\x00\x00\x00" + // dimensions length
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
"\x03\x00\x00\x00" + // type (i8)
"\x00\x03\x00\x00\x00\x00\x00\x00" + // offset
"",
wantTensor: []TensorInfo{
{
Name: "t",
Dimensions: []uint64{1},
Type: TypeQ4_1,
Offset: 256,
Size: 256,
},
{
Name: "t",
Dimensions: []uint64{1},
Type: TypeQ4_1,
Offset: 768,
Size: 512,
},
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
f, err := ReadFile(strings.NewReader(tt.data))
if err != nil {
if !errors.Is(err, tt.wantReadErr) {
t.Fatalf("unexpected ReadFile error: %v", err)
}
return
}
var got []MetaEntry
for meta, err := range f.Metadata {
if !errors.Is(err, tt.wantMetaErr) {
t.Fatalf("err = %v; want %v", err, ErrMangled)
}
if err != nil {
return
}
got = append(got, meta)
}
diff.Test(t, t.Errorf, got, tt.wantMeta)
var gotT []TensorInfo
for tinfo, err := range f.Tensors {
if !errors.Is(err, tt.wantTensorErr) {
t.Fatalf("err = %v; want %v", err, tt.wantTensorErr)
}
if err != nil {
return
}
gotT = append(gotT, tinfo)
}
diff.Test(t, t.Errorf, gotT, tt.wantTensor)
})
}
}
func FuzzReadInfo(f *testing.F) {
f.Add(string(magicBytes))
f.Add(string(magicBytes) +
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"")
f.Add(string(magicBytes) +
"\x03\x00\x00\x00" + // version
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
"K" + // key
"\x08\x00\x00\x00" + // type (string)
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
"VV" + // string value
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
"t" +
"\x01\x00\x00\x00" + // dimensions length
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
"\x03\x00\x00\x00" + // type (i8)
"\x05\x00\x00\x00\x00\x00\x00\x00" + // offset
"")
f.Fuzz(func(t *testing.T, data string) {
gf, err := ReadFile(strings.NewReader(data))
if err != nil {
t.Logf("ReadFile error: %v", err)
t.Skip()
}
for _, err := range gf.Metadata {
if err != nil {
t.Logf("metadata error: %v", err)
t.Skip()
}
}
for tinfo, err := range gf.Tensors {
if err != nil {
t.Logf("tensor error: %v", err)
t.Skip()
}
if tinfo.Offset <= 0 {
t.Logf("invalid tensor offset: %+v", t)
t.Skip()
}
if tinfo.Size <= 0 {
t.Logf("invalid tensor size: %+v", t)
t.Skip()
}
}
})
}