git-subtree-dir: x git-subtree-mainline: 7d05a6ee8f44b314fa697a427439e5fa4d78c3d7 git-subtree-split: a10a11b9d371f36b7c3510da32a1d70b74e27bd1
346 lines
8.4 KiB
Go
346 lines
8.4 KiB
Go
package gguf
|
|
|
|
import (
|
|
"errors"
|
|
"io"
|
|
"strings"
|
|
"testing"
|
|
|
|
"kr.dev/diff"
|
|
)
|
|
|
|
func TestStat(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
data string
|
|
wantInfo Info
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "empty",
|
|
wantErr: ErrBadMagic,
|
|
},
|
|
{
|
|
name: "bad magic",
|
|
data: "\xBB\xAA\xDD\x00",
|
|
wantErr: ErrBadMagic,
|
|
},
|
|
{
|
|
name: "bad version",
|
|
data: string(magicBytes) +
|
|
"\x02\x00\x00\x00" + // version
|
|
"",
|
|
wantErr: ErrUnsupportedVersion,
|
|
},
|
|
{
|
|
name: "valid general.file_type",
|
|
data: string(magicBytes) + // magic
|
|
"\x03\x00\x00\x00" + // version
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
|
|
|
// general.file_type key
|
|
"\x11\x00\x00\x00\x00\x00\x00\x00" + // key length
|
|
"general.file_type" + // key
|
|
"\x04\x00\x00\x00" + // type (uint32)
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value
|
|
"",
|
|
wantInfo: Info{
|
|
Version: 3,
|
|
FileType: 1,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range cases {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
info, err := StatReader(strings.NewReader(tt.data))
|
|
if tt.wantErr != nil {
|
|
if !errors.Is(err, tt.wantErr) {
|
|
t.Fatalf("err = %v; want %q", err, tt.wantErr)
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
diff.Test(t, t.Errorf, info, tt.wantInfo)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestReadInfo(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
data string
|
|
|
|
wantMeta []MetaEntry
|
|
wantTensor []TensorInfo
|
|
wantReadErr error
|
|
wantMetaErr error
|
|
wantTensorErr error
|
|
wantInfo Info
|
|
}{
|
|
{
|
|
name: "empty",
|
|
wantReadErr: io.ErrUnexpectedEOF,
|
|
},
|
|
{
|
|
name: "bad magic",
|
|
data: "\xBB\xAA\xDD\x00",
|
|
wantReadErr: ErrBadMagic,
|
|
},
|
|
{
|
|
name: "bad version",
|
|
data: string(magicBytes) +
|
|
"\x02\x00\x00\x00" + // version
|
|
"",
|
|
wantReadErr: ErrUnsupportedVersion,
|
|
},
|
|
{
|
|
name: "no metadata or tensors",
|
|
data: string(magicBytes) + // magic
|
|
"\x03\x00\x00\x00" + // version
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
|
"",
|
|
wantReadErr: nil,
|
|
},
|
|
{
|
|
name: "good metadata",
|
|
data: string(magicBytes) + // magic
|
|
"\x03\x00\x00\x00" + // version
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
|
"K" + // key
|
|
"\x08\x00\x00\x00" + // type (string)
|
|
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
|
"VV" + // string value
|
|
"",
|
|
wantMeta: []MetaEntry{
|
|
{Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}},
|
|
},
|
|
},
|
|
{
|
|
name: "good metadata with multiple values",
|
|
data: string(magicBytes) + // magic
|
|
"\x03\x00\x00\x00" + // version
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
|
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
|
|
|
// MetaEntry 1
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
|
"x" + // key
|
|
"\x08\x00\x00\x00" + // type (string)
|
|
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
|
"XX" + // string value
|
|
|
|
// MetaEntry 2
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
|
"y" + // key
|
|
"\x04\x00\x00\x00" + // type (uint32)
|
|
"\x99\x88\x77\x66" + // uint32 value
|
|
"",
|
|
wantMeta: []MetaEntry{
|
|
{Key: "x", Type: ValueTypeString, Values: []MetaValue{{
|
|
Type: ValueTypeString,
|
|
Value: []byte("XX"),
|
|
}}},
|
|
{Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{
|
|
Type: ValueTypeUint32,
|
|
Value: []byte{0x99, 0x88, 0x77, 0x66},
|
|
}}},
|
|
},
|
|
},
|
|
{
|
|
name: "negative string length in meta key",
|
|
data: string(magicBytes) + // magic
|
|
"\x03\x00\x00\x00" + // version
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
|
"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length
|
|
"K" + // key
|
|
"\x08\x00\x00\x00" + // type (string)
|
|
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
|
"VV" + // string value
|
|
"",
|
|
wantMetaErr: ErrMangled,
|
|
},
|
|
|
|
// Tensor tests
|
|
{
|
|
name: "good tensor",
|
|
data: string(magicBytes) + // magic
|
|
"\x03\x00\x00\x00" + // version
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
|
|
|
// Tensor 1
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
|
"t" +
|
|
|
|
// dimensions
|
|
"\x01\x00\x00\x00" + // dimensions length
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
|
|
|
"\x03\x00\x00\x00" + // type (i8)
|
|
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
|
|
"",
|
|
wantTensor: []TensorInfo{
|
|
{
|
|
Name: "t",
|
|
Dimensions: []uint64{1},
|
|
Type: TypeQ4_1,
|
|
Offset: 256,
|
|
Size: 256,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "too many dimensions",
|
|
data: string(magicBytes) + // magic
|
|
"\x03\x00\x00\x00" + // version
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
|
|
|
// Tensor 1
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
|
"t" +
|
|
|
|
"\x00\x00\x00\x01" + // dimensions length
|
|
"",
|
|
wantTensorErr: ErrMangled,
|
|
},
|
|
{
|
|
name: "size computed",
|
|
data: string(magicBytes) + // magic
|
|
"\x03\x00\x00\x00" + // version
|
|
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
|
|
|
// Tensor 1
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
|
"t" +
|
|
"\x01\x00\x00\x00" + // dimensions length
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
|
"\x03\x00\x00\x00" + // type (i8)
|
|
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
|
|
|
|
// Tensor 2
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
|
"t" +
|
|
"\x01\x00\x00\x00" + // dimensions length
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
|
"\x03\x00\x00\x00" + // type (i8)
|
|
"\x00\x03\x00\x00\x00\x00\x00\x00" + // offset
|
|
"",
|
|
wantTensor: []TensorInfo{
|
|
{
|
|
Name: "t",
|
|
Dimensions: []uint64{1},
|
|
Type: TypeQ4_1,
|
|
Offset: 256,
|
|
Size: 256,
|
|
},
|
|
{
|
|
Name: "t",
|
|
Dimensions: []uint64{1},
|
|
Type: TypeQ4_1,
|
|
Offset: 768,
|
|
Size: 512,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range cases {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
f, err := ReadFile(strings.NewReader(tt.data))
|
|
if err != nil {
|
|
if !errors.Is(err, tt.wantReadErr) {
|
|
t.Fatalf("unexpected ReadFile error: %v", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
var got []MetaEntry
|
|
for meta, err := range f.Metadata {
|
|
if !errors.Is(err, tt.wantMetaErr) {
|
|
t.Fatalf("err = %v; want %v", err, ErrMangled)
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
got = append(got, meta)
|
|
}
|
|
diff.Test(t, t.Errorf, got, tt.wantMeta)
|
|
|
|
var gotT []TensorInfo
|
|
for tinfo, err := range f.Tensors {
|
|
if !errors.Is(err, tt.wantTensorErr) {
|
|
t.Fatalf("err = %v; want %v", err, tt.wantTensorErr)
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
gotT = append(gotT, tinfo)
|
|
}
|
|
diff.Test(t, t.Errorf, gotT, tt.wantTensor)
|
|
})
|
|
}
|
|
}
|
|
|
|
func FuzzReadInfo(f *testing.F) {
|
|
f.Add(string(magicBytes))
|
|
f.Add(string(magicBytes) +
|
|
"\x03\x00\x00\x00" + // version
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
|
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
|
"")
|
|
f.Add(string(magicBytes) +
|
|
"\x03\x00\x00\x00" + // version
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
|
"K" + // key
|
|
"\x08\x00\x00\x00" + // type (string)
|
|
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
|
"VV" + // string value
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
|
"t" +
|
|
"\x01\x00\x00\x00" + // dimensions length
|
|
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
|
"\x03\x00\x00\x00" + // type (i8)
|
|
"\x05\x00\x00\x00\x00\x00\x00\x00" + // offset
|
|
"")
|
|
|
|
f.Fuzz(func(t *testing.T, data string) {
|
|
gf, err := ReadFile(strings.NewReader(data))
|
|
if err != nil {
|
|
t.Logf("ReadFile error: %v", err)
|
|
t.Skip()
|
|
}
|
|
for _, err := range gf.Metadata {
|
|
if err != nil {
|
|
t.Logf("metadata error: %v", err)
|
|
t.Skip()
|
|
}
|
|
}
|
|
for tinfo, err := range gf.Tensors {
|
|
if err != nil {
|
|
t.Logf("tensor error: %v", err)
|
|
t.Skip()
|
|
}
|
|
if tinfo.Offset <= 0 {
|
|
t.Logf("invalid tensor offset: %+v", t)
|
|
t.Skip()
|
|
}
|
|
if tinfo.Size <= 0 {
|
|
t.Logf("invalid tensor size: %+v", t)
|
|
t.Skip()
|
|
}
|
|
}
|
|
})
|
|
}
|