package convert

import (
	"bytes"
	"crypto/sha256"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log/slog"
	"math"
	"os"
	"path/filepath"
	"slices"
	"testing"

	"github.com/ollama/ollama/llm"
	"golang.org/x/exp/maps"
)

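// convertFull converts the model directory d to a temporary GGUF file, then
// reopens and decodes the result so callers can inspect the file alongside
// its KV metadata and tensor index.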
func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
	t.Helper()

	f, err := os.CreateTemp(t.TempDir(), "f16")
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	if err := Convert(d, f); err != nil {
		t.Fatal(err)
	}

	r, err := os.Open(f.Name())
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { r.Close() })

	m, _, err := llm.DecodeGGML(r, math.MaxInt)
	if err != nil {
		t.Fatal(err)
	}

	if _, err := r.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}

	return r, m.KV(), m.Tensors()
}

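// TestMain wires a -level flag to the slog log level before running the tests.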
func TestMain(m *testing.M) {
	var level slog.Level
	flag.TextVar(&level, "level", slog.LevelInfo, "log level")
	flag.Parse()
	slog.SetLogLoggerLevel(level)
	os.Exit(m.Run())
}

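// TestConvertFull converts each reference model under testdata and compares
// its KV metadata and per-tensor SHA-256 digests against the golden values in
// testdata/<model>.json. Models that are not present on disk are skipped.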
func TestConvertFull(t *testing.T) {
	cases := []string{
		"Meta-Llama-3-8B-Instruct",
		"Mistral-7B-Instruct-v0.2",
		"Mixtral-8x7B-Instruct-v0.1",
		"gemma-2b-it",
	}

	for i := range cases {
		tt := cases[i]
		t.Run(tt, func(t *testing.T) {
			t.Parallel()

			p := filepath.Join("testdata", tt)
			if testing.Short() {
				t.Skip("skipping in short mode")
			} else if _, err := os.Stat(p); err != nil {
				t.Skipf("%s not found", p)
			}

			f, kv, tensors := convertFull(t, p)
			actual := make(map[string]string)

			// Record KV metadata: hash values that marshal to JSON (e.g.
			// token lists), stringify everything else.
			for k, v := range kv {
				if s, ok := v.(json.Marshaler); !ok {
					actual[k] = fmt.Sprintf("%v", v)
				} else {
					bts, err := json.Marshal(s)
					if err != nil {
						t.Fatal(err)
					}

					actual[k] = fmt.Sprintf("%x", sha256.Sum256(bts))
				}
			}

			// Hash each tensor's data as laid out in the converted file.
			for _, tensor := range tensors.Items {
				sha256sum := sha256.New()
				sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
				if _, err := io.Copy(sha256sum, sr); err != nil {
					t.Fatal(err)
				}

				actual[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
			}

			expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
			if err != nil {
				t.Fatal(err)
			}

			var expect map[string]string
			if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
				t.Fatal(err)
			}

			keys := maps.Keys(expect)
			slices.Sort(keys)
			for _, k := range keys {
				if v, ok := actual[k]; !ok {
					t.Errorf("missing %s", k)
				} else if v != expect[k] {
					t.Errorf("unexpected %s: want %s, got %s", k, expect[k], v)
				}
			}
		})
	}
}

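// TestConvertNPZ parses LoRA adapter tensors from an NPZ file under testdata,
// writes them out as GGLA into an in-memory buffer, and verifies that the
// round-tripped GGML reports the expected rank, alpha, and tensor count.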
func TestConvertNPZ(t *testing.T) {
	cases := []string{
		"adapters.npz",
	}

	for _, fn := range cases {
		ts, err := parseNPZ(filepath.Join("testdata", fn))
		if err != nil {
			t.Fatal(err)
		}
		if len(ts) != 16*2*2 {
			t.Errorf("got: %d want: %d total layers", len(ts), 16*2*2)
		}

		a := adapter{}

		for _, m := range ts {
			at := m.(adapterTensor)
			if at.path != filepath.Join("testdata", fn) {
				t.Errorf("got: %s want: %s", at.path, filepath.Join("testdata", fn))
			}
			if at.dtype != "F32" {
				t.Errorf("got: %s but only F32s are currently supported", at.dtype)
			}
			if len(at.tensorBase.shape) != 2 {
				t.Errorf("got: %d want: %d tensor shape dimensions", len(at.tensorBase.shape), 2)
			}
		}

		// Write the adapter out as GGLA into an in-memory buffer, then decode
		// it back as GGML to check the round trip.
		var ws io.WriteSeeker = &memWriter{}
		err = llm.WriteGGLA(ws, a.KV(nil), a.Tensors(ts))
		if err != nil {
			t.Fatal(err)
		}

		mw := ws.(*memWriter)
		slog.Info(fmt.Sprintf("buffer len = %d", len(mw.buf)))
		if len(mw.buf) == 0 {
			t.Errorf("ggla layer not written correctly")
		}
		rs := bytes.NewReader(mw.buf)
		ggml, _, err := llm.DecodeGGML(rs, len(mw.buf))
		if err != nil {
			t.Fatal(err)
		}
		if ggml == nil {
			t.Fatalf("ggla didn't convert to ggml correctly")
		}

		kv := ggml.KV()
		if kv == nil {
			t.Fatalf("no lora KVs were set")
		}

		r, ok := kv["r"]
		if !ok || r != uint32(8) {
			t.Errorf("lora rank was not set correctly")
		}

		alpha, ok := kv["alpha"]
		if !ok || alpha != uint32(160) {
			t.Errorf("lora alpha was not set correctly")
		}

		gts := ggml.Tensors()
		if len(ts) != len(gts.Items) {
			t.Fatalf("got: %d want: %d tensors in ggla", len(gts.Items), len(ts))
		}
	}
}

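// memWriter is an in-memory io.WriteSeeker used to capture GGLA output
// without touching the filesystem.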
type memWriter struct {
	buf []byte
	pos int
}

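// Write copies p into the buffer at the current position, growing the buffer
// as needed.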
func (m *memWriter) Write(p []byte) (n int, err error) {
	minCap := m.pos + len(p)
	if minCap > cap(m.buf) {
		buf2 := make([]byte, len(m.buf), minCap+len(p)) // add some extra
		copy(buf2, m.buf)
		m.buf = buf2
	}
	if minCap > len(m.buf) {
		m.buf = m.buf[:minCap]
	}
	copy(m.buf[m.pos:], p)
	m.pos += len(p)
	return len(p), nil
}

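// Seek moves the write position relative to the start, current position, or
// end of the buffer.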
func (m *memWriter) Seek(offset int64, whence int) (int64, error) {
	newPos, offs := 0, int(offset)
	switch whence {
	case io.SeekStart:
		newPos = offs
	case io.SeekCurrent:
		newPos = m.pos + offs
	case io.SeekEnd:
		newPos = len(m.buf) + offs
	default:
		// Reject unknown whence values rather than silently seeking to 0.
		return 0, errors.New("invalid whence")
	}
	if newPos < 0 {
		return 0, errors.New("negative result pos")
	}
	m.pos = newPos
	return int64(newPos), nil
}