package convert

import (
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

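// createTokenizerFS writes each reader in files to a file of the same name
// under dir and returns a filesystem rooted at dir, suitable for passing to
// parseTokenizer.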
func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
	t.Helper()

	for k, v := range files {
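		// Write each file inside a closure so the deferred Close runs once
		// per file rather than accumulating until the helper returns.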
		if err := func() error {
			f, err := os.Create(filepath.Join(dir, k))
			if err != nil {
				return err
			}
			defer f.Close()

			if _, err := io.Copy(f, v); err != nil {
				return err
			}

			return nil
		}(); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	}

	return os.DirFS(dir)
}

func TestParseTokenizer(t *testing.T) {
	cases := []struct {
		name              string
		fsys              fs.FS
		specialTokenTypes []string
		want              *Tokenizer
	}{
		{
			name: "string chat template",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{}`),
				"tokenizer_config.json": strings.NewReader(`{
					"chat_template": "<default template>"
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{Model: "gpt2"},
				Pre:        "default",
				Template:   "<default template>",
			},
		},
		{
			name: "list chat template",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{}`),
				"tokenizer_config.json": strings.NewReader(`{
					"chat_template": [
						{
							"name": "default",
							"template": "<default template>"
						},
						{
							"name": "tools",
							"template": "<tools template>"
						}
					]
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{Model: "gpt2"},
				Pre:        "default",
				Template:   "<default template>",
			},
		},
		{
			name: "added tokens",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"added_tokens": [
						{
							"id": 999,
							"content": "<unused999>",
							"special": false
						}
					]
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model:  "gpt2",
					Tokens: []string{"<unused999>"},
					Scores: []float32{999}, // added tokens carry their ID as the score
					Types:  []int32{4},     // 4: user-defined token
				},
				Pre: "default",
			},
		},
		{
			name: "added tokens overlap vocab",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"added_tokens": [
						{
							"id": 0,
							"content": "<pad>",
							"special": true
						}
					],
					"model": {
						"vocab": {
							"<pad>": 0
						}
					}
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model:  "gpt2",
					Tokens: []string{"<pad>"},
					Scores: []float32{0},
					Types:  []int32{3}, // 3: control (special) token
				},
				Pre: "default",
			},
		},
		{
			name: "special token types",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"added_tokens": [
						{
							"id": 0,
							"content": "<pad>",
							"special": true
						},
						{
							"id": 1,
							"content": "<eos>",
							"special": true
						},
						{
							"id": 2,
							"content": "<bos>",
							"special": true
						},
						{
							"id": 3,
							"content": "<unk>",
							"special": true
						}
					],
					"model": {
						"vocab": {
							"<pad>": 0,
							"<eos>": 1,
							"<bos>": 2,
							"<unk>": 3
						}
					}
				}`),
				"tokenizer_config.json": strings.NewReader(`{
					"add_bos_token": true,
					"add_eos_token": false,
					"bos_token": "<bos>",
					"eos_token": "<eos>",
					"pad_token": "<pad>",
					"unk_token": "<unk>"
				}`),
			}),
			specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model:  "gpt2",
					Tokens: []string{"<pad>", "<eos>", "<bos>", "<unk>"},
					Scores: []float32{0, 1, 2, 3},
					Types:  []int32{3, 3, 3, 3},
				},
				SpecialVocabulary: []*SpecialVocabulary{
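					// AddToken mirrors add_bos_token/add_eos_token from
					// tokenizer_config.json; tokens without a flag stay false.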
					{Type: "pad", Content: "<pad>", ID: 0, AddToken: false},
					{Type: "eos", Content: "<eos>", ID: 1, AddToken: false},
					{Type: "bos", Content: "<bos>", ID: 2, AddToken: true},
					{Type: "unk", Content: "<unk>", ID: 3, AddToken: false},
				},
				Pre: "default",
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
				t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
			}
		})
	}
}