//go:build integration

package integration

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"math/rand"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/app/lifecycle"
	"github.com/stretchr/testify/assert"
)

// FindPort returns, as a decimal string, a TCP port for a test server to use.
//
// It asks the OS for a free port by listening on localhost:0, records the port
// that was assigned, and closes the listener again before returning. If the
// address cannot be resolved or the listen fails, it falls back to a random
// port in the ephemeral range instead.
func FindPort() string {
	const (
		ephemeralStart = 49152
		ephemeralEnd   = 65535
	)

	assigned := 0
	addr, resolveErr := net.ResolveTCPAddr("tcp", "localhost:0")
	if resolveErr == nil {
		if listener, listenErr := net.ListenTCP("tcp", addr); listenErr == nil {
			assigned = listener.Addr().(*net.TCPAddr).Port
			listener.Close()
		}
	}
	if assigned != 0 {
		return strconv.Itoa(assigned)
	}

	// get a random port in the ephemeral range
	return strconv.Itoa(rand.Intn(ephemeralEnd-ephemeralStart) + ephemeralStart)
}

// GetTestEndpoint derives the scheme and "host:port" address the integration
// tests should talk to from the OLLAMA_HOST environment variable.
//
// OLLAMA_HOST may carry an optional "scheme://" prefix (the scheme defaults to
// "http") and trailing slashes are ignored. When the remainder cannot be split
// into host and port, the port defaults to 11434 and the host to 127.0.0.1,
// unless the remainder is an IP literal or any other non-empty string, in
// which case that is used as the host. When OLLAMA_TEST_EXISTING is empty and
// the resulting port is the default one, a port from FindPort is used instead.
//
// The address is assembled with net.JoinHostPort, so an IPv6 host comes back
// bracketed (for example "[::1]:11434") and can be placed in a URL as-is.
func GetTestEndpoint() (string, string) {
	defaultPort := "11434"
	ollamaHost := os.Getenv("OLLAMA_HOST")

	scheme, hostport, ok := strings.Cut(ollamaHost, "://")
	if !ok {
		scheme, hostport = "http", ollamaHost
	}

	// trim trailing slashes
	hostport = strings.TrimRight(hostport, "/")

	host, port, err := net.SplitHostPort(hostport)
	if err != nil {
		// No usable "host:port" pair: start from the defaults and keep
		// whatever host information the value does carry.
		host, port = "127.0.0.1", defaultPort
		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
			host = ip.String()
		} else if hostport != "" {
			host = hostport
		}
	}

	if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort {
		port = FindPort()
	}

	// JoinHostPort (rather than "%s:%s") adds the brackets an IPv6 host needs.
	endpoint := net.JoinHostPort(host, port)
	slog.Info("server connection", "url", endpoint)
	return scheme, endpoint
}

// TODO make fancier, grab logs, etc.
var (
	// serverMutex guards serverReady; StartServer holds it while starting
	// the server, and the shutdown watcher holds it while the server exits.
	serverMutex sync.Mutex

	// serverReady is set once StartServer has spawned the server and is
	// cleared again after the server has exited.
	serverReady bool
)

// StartServer spawns the ollama binary found at ../ollama unless a server
// started by an earlier call is still marked ready, in which case it returns
// nil immediately.
//
// The binary must already exist (with an ".exe" suffix on Windows); if it does
// not, an error is returned before anything is started. OLLAMA_HOST is set to
// ollamaHost in this process's environment before lifecycle.SpawnServer is
// called (presumably the spawned server picks it up from there). A background
// goroutine waits for ctx to be cancelled, then for the server's exit code,
// and finally clears the ready flag.
//
// StartServer returns after a fixed 500ms pause rather than probing the
// server, so a nil error does not guarantee that the server is already
// accepting requests.
func StartServer(ctx context.Context, ollamaHost string) error {
	// Make sure the server has been built
	CLIName, err := filepath.Abs("../ollama")
	if err != nil {
		return err
	}

	if runtime.GOOS == "windows" {
		CLIName += ".exe"
	}
	_, err = os.Stat(CLIName)
	if err != nil {
		return fmt.Errorf("CLI missing, did you forget to build first?  %w", err)
	}
	serverMutex.Lock()
	defer serverMutex.Unlock()
	if serverReady {
		return nil
	}

	if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
		slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
		// Starting the server with a stale OLLAMA_HOST would make every
		// later request go to the wrong address, so do not ignore this.
		if err := os.Setenv("OLLAMA_HOST", ollamaHost); err != nil {
			return fmt.Errorf("failed to set OLLAMA_HOST: %w", err)
		}
	}

	slog.Info("starting server", "url", ollamaHost)
	done, err := lifecycle.SpawnServer(ctx, "../ollama")
	if err != nil {
		return fmt.Errorf("failed to start server: %w", err)
	}

	go func() {
		<-ctx.Done()
		// Hold the mutex while the server winds down so that a concurrent
		// StartServer call blocks until the ready flag has been cleared.
		serverMutex.Lock()
		defer serverMutex.Unlock()
		exitCode := <-done
		if exitCode > 0 {
			slog.Warn("server failure", "exit", exitCode)
		}
		serverReady = false
	}()

	// TODO wait only long enough for the server to be responsive...
	time.Sleep(500 * time.Millisecond)

	serverReady = true
	return nil
}

// PullIfMissing makes sure modelName is available on the server reachable at
// scheme://testEndpoint, pulling it when it is not.
//
// It first POSTs an api.ShowRequest to /api/show; a 200 response means the
// model is already present and nothing else is done. Any other status is
// treated as "missing", and an api.PullRequest is POSTed to /api/pull. Both
// requests are bound to ctx and sent with client.
//
// An error is returned when a request cannot be built or sent, or when the
// pull answers with a non-200 status; in the latter case the error carries the
// status code and the beginning of the response body.
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
	baseURL := scheme + "://" + testEndpoint

	slog.Info("checking status of model", "model", modelName)
	showReq := &api.ShowRequest{Name: modelName}
	requestJSON, err := json.Marshal(showReq)
	if err != nil {
		return err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/api/show", bytes.NewReader(requestJSON))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	// Make the request with the HTTP client
	response, err := client.Do(req)
	if err != nil {
		return err
	}
	showStatus := response.StatusCode
	// Only the status matters here. Drain and close the body right away so
	// the connection is not held open for the whole duration of the pull.
	// Best effort: a failed drain merely prevents connection reuse.
	_, _ = io.Copy(io.Discard, response.Body)
	response.Body.Close()
	if showStatus == http.StatusOK {
		slog.Info("model already present", "model", modelName)
		return nil
	}
	slog.Info("model missing", "status", showStatus)

	// stream is a package-level variable declared outside this function.
	pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
	requestJSON, err = json.Marshal(pullReq)
	if err != nil {
		return err
	}

	req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/api/pull", bytes.NewReader(requestJSON))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	slog.Info("pulling", "model", modelName)

	response, err = client.Do(req)
	if err != nil {
		return err
	}
	defer response.Body.Close()
	if response.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message,
		// so a read failure is not reported separately.
		detail, _ := io.ReadAll(io.LimitReader(response.Body, 4096))
		return fmt.Errorf("failed to pull model %q: status %d: %s", modelName, response.StatusCode, bytes.TrimSpace(detail))
	}
	slog.Info("model pulled", "model", modelName)
	return nil
}

// serverProcMutex is held by GenerateTestHelper for the whole of a run that
// starts its own server (OLLAMA_TEST_EXISTING empty), so such runs do not
// overlap.
var serverProcMutex sync.Mutex

// GenerateTestHelper performs one /api/generate round trip and checks the
// answer.
//
// When OLLAMA_TEST_EXISTING is empty the helper manages its own server: it
// takes serverProcMutex, points lifecycle.ServerLogFile at a fresh temporary
// file and calls StartServer. On return it dumps that log file to stderr if
// the test has failed, removes the file and releases the mutex. Otherwise it
// talks to whatever server GetTestEndpoint resolves to.
//
// The model named in genReq is pulled first if the server does not have it.
// The test fails if the request cannot be made, if the status is not 200, if
// the body is not a valid api.GenerateResponse, or if the response text does
// not contain at least one entry of anyResp. The response text is lowercased
// before matching, so the entries of anyResp need to be lowercase to match.
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
	// TODO maybe stuff in an init routine?
	lifecycle.InitLogging()

	requestJSON, err := json.Marshal(genReq)
	if err != nil {
		t.Fatalf("Error serializing request: %v", err)
	}

	scheme, testEndpoint := GetTestEndpoint()

	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
		serverProcMutex.Lock()
		// Register the cleanup only once the lock is held so that the
		// unlock below is always paired with the lock above. It also runs
		// when the test is aborted through t.Fatalf.
		defer func() {
			defer serverProcMutex.Unlock()
			if t.Failed() {
				// ReadFile opens, reads and closes the file, so no handle is
				// left open when the file is removed below.
				data, readErr := os.ReadFile(lifecycle.ServerLogFile)
				if readErr != nil {
					slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", readErr)
				} else {
					slog.Warn("SERVER LOG FOLLOWS")
					os.Stderr.Write(data)
					slog.Warn("END OF SERVER")
				}
			}
			if rmErr := os.Remove(lifecycle.ServerLogFile); rmErr != nil && !os.IsNotExist(rmErr) {
				slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", rmErr)
			}
		}()

		fp, err := os.CreateTemp("", "ollama-server-*.log")
		if err != nil {
			t.Fatalf("failed to generate log file: %s", err)
		}
		lifecycle.ServerLogFile = fp.Name()
		fp.Close()

		// Nothing below can succeed without a server, so stop right here.
		if err := StartServer(ctx, testEndpoint); err != nil {
			t.Fatalf("Error starting server: %v", err)
		}
	}

	err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
	if err != nil {
		t.Fatalf("Error pulling model: %v", err)
	}

	// Make the request and get the response
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
	if err != nil {
		t.Fatalf("Error creating request: %v", err)
	}

	// Set the content type for the request
	req.Header.Set("Content-Type", "application/json")

	// Make the request with the HTTP client
	response, err := client.Do(req)
	if err != nil {
		t.Fatalf("Error making request: %v", err)
	}
	defer response.Body.Close()
	body, err := io.ReadAll(response.Body)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, response.StatusCode, string(body))

	// Verify the response is valid JSON
	var payload api.GenerateResponse
	err = json.Unmarshal(body, &payload)
	assert.NoError(t, err, string(body))

	// Verify the response contains the expected data
	lowered := strings.ToLower(payload.Response)
	atLeastOne := false
	for _, resp := range anyResp {
		if strings.Contains(lowered, resp) {
			atLeastOne = true
			break
		}
	}
	assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
}