From 6b89f7ab63ad209b0c99b32f10220788c5ea9b22 Mon Sep 17 00:00:00 2001 From: Deluan Date: Sat, 19 Apr 2025 14:28:55 -0400 Subject: [PATCH] feat: integrate Wazero for WASM support in MCPAgent Enhance MCPAgent to support both native executables and WASM modules using Wazero. This includes: - Adding Wazero dependencies in go.mod and go.sum. - Modifying MCPAgent to initialize a shared Wazero runtime and compile/load WASM modules. - Implementing cleanup logic for WASM resources. - Updating the process initialization to handle both native and WASM paths. This change improves the flexibility of the MCPAgent in handling different server types. --- core/agents/mcp/mcp_agent.go | 387 ++++++++++++++++++++++++++++------- go.mod | 3 +- go.sum | 2 + 3 files changed, 321 insertions(+), 71 deletions(-) diff --git a/core/agents/mcp/mcp_agent.go b/core/agents/mcp/mcp_agent.go index e25251d3c..9634a8bb7 100644 --- a/core/agents/mcp/mcp_agent.go +++ b/core/agents/mcp/mcp_agent.go @@ -7,13 +7,18 @@ import ( "io" "os" "os/exec" + "path/filepath" "strings" "sync" "time" mcp "github.com/metoro-io/mcp-golang" "github.com/metoro-io/mcp-golang/transport/stdio" + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" + "github.com/navidrome/navidrome/conf" "github.com/navidrome/navidrome/core/agents" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" @@ -22,7 +27,7 @@ import ( // Exported constants for testing const ( McpAgentName = "mcp" - McpServerPath = "/Users/deluan/Development/navidrome/plugins-mcp/mcp-server" + McpServerPath = "/Users/deluan/Development/navidrome/plugins-mcp/mcp-server.wasm" McpToolNameGetBio = "get_artist_biography" McpToolNameGetURL = "get_artist_url" initializationTimeout = 10 * time.Second @@ -37,12 +42,22 @@ type mcpClient interface { // MCPAgent interacts with an external MCP server for metadata retrieval. 
// It keeps a single instance of the server process running and attempts restart on failure. +// Supports both native executables and WASM modules (via Wazero). type MCPAgent struct { mu sync.Mutex - cmd *exec.Cmd - stdin io.WriteCloser - client mcpClient // Use the interface type here + // Runtime state + stdin io.WriteCloser + client mcpClient + cmd *exec.Cmd // Stores the native process command + wasmModule api.Module // Stores the instantiated WASM module + + // Shared Wazero resources (created once) + wasmRuntime api.Closer // Shared Wazero Runtime (implements Close(context.Context)) + wasmCache wazero.CompilationCache // Shared Compilation Cache (implements Close(context.Context)) + + // WASM resources per instance (cleaned up by monitoring goroutine) + wasmCompiled api.Closer // Stores the compiled WASM module for closing // ClientOverride allows injecting a mock client for testing. // This field should ONLY be set in test code. @@ -51,20 +66,104 @@ type MCPAgent struct { func mcpConstructor(ds model.DataStore) agents.Interface { // Check if the MCP server executable exists - if _, err := os.Stat(McpServerPath); os.IsNotExist(err) { // Use exported constant - log.Warn("MCP server executable not found, disabling agent", "path", McpServerPath, "error", err) + if _, err := os.Stat(McpServerPath); os.IsNotExist(err) { + log.Warn("MCP server executable/WASM not found, disabling agent", "path", McpServerPath, "error", err) return nil } + + a := &MCPAgent{} + + // If it's a WASM path, pre-initialize the shared Wazero runtime and cache. 
+ if strings.HasSuffix(McpServerPath, ".wasm") { + ctx := context.Background() // Use background context for setup + cacheDir := filepath.Join(conf.Server.DataFolder, "cache", "wazero") + if err := os.MkdirAll(cacheDir, 0755); err != nil { + log.Error(ctx, "Failed to create Wazero cache directory, WASM caching disabled", "path", cacheDir, "error", err) + } else { + cache, err := wazero.NewCompilationCacheWithDir(cacheDir) + if err != nil { + log.Error(ctx, "Failed to create Wazero compilation cache, WASM caching disabled", "path", cacheDir, "error", err) + } else { + // Store the specific cache type + a.wasmCache = cache + log.Info(ctx, "Wazero compilation cache enabled", "path", cacheDir) + } + } + + // Create runtime config, adding cache if it was created successfully + runtimeConfig := wazero.NewRuntimeConfig() + if a.wasmCache != nil { + // Use the stored cache directly (already correct type) + runtimeConfig = runtimeConfig.WithCompilationCache(a.wasmCache) + } + + // Create the shared runtime + runtime := wazero.NewRuntimeWithConfig(ctx, runtimeConfig) + a.wasmRuntime = runtime // Store the runtime closer + + // Instantiate WASI onto the shared runtime. If this fails, the agent is unusable for WASM. 
+ if _, err := wasi_snapshot_preview1.Instantiate(ctx, runtime); err != nil { + log.Error(ctx, "Failed to instantiate WASI on shared Wazero runtime, MCP WASM agent disabled", "error", err) + // Close runtime and cache if WASI fails + _ = runtime.Close(ctx) + if a.wasmCache != nil { + _ = a.wasmCache.Close(ctx) // Use Close(ctx) + } + return nil // Cannot proceed if WASI fails + } + log.Info(ctx, "Shared Wazero runtime and WASI initialized for MCP agent") + } + log.Info("MCP Agent created, server will be started on first request", "serverPath", McpServerPath) - return &MCPAgent{} + return a } func (a *MCPAgent) AgentName() string { - return McpAgentName // Use exported constant + return McpAgentName } -// ensureClientInitialized starts the MCP server process and initializes the client if needed. -// It now attempts restart if the client is found to be nil. +// cleanup closes existing resources (stdin, server process/module). +// MUST be called while holding the mutex. +func (a *MCPAgent) cleanup() { + log.Debug(context.Background(), "Cleaning up MCP agent instance resources...") + if a.stdin != nil { + _ = a.stdin.Close() + a.stdin = nil + } + // Clean up native process if it exists + if a.cmd != nil && a.cmd.Process != nil { + log.Debug(context.Background(), "Killing native MCP process", "pid", a.cmd.Process.Pid) + if err := a.cmd.Process.Kill(); err != nil { + log.Error(context.Background(), "Failed to kill native process", "pid", a.cmd.Process.Pid, "error", err) + } + _ = a.cmd.Wait() + a.cmd = nil + } + // Clean up WASM module instance if it exists + if a.wasmModule != nil { + log.Debug(context.Background(), "Closing WASM module instance") + ctxClose, cancel := context.WithTimeout(context.Background(), 2*time.Second) + if err := a.wasmModule.Close(ctxClose); err != nil { + log.Error(context.Background(), "Failed to close WASM module instance", "error", err) + } + cancel() + a.wasmModule = nil + } + // Clean up compiled module ref for this instance + if 
a.wasmCompiled != nil { + log.Debug(context.Background(), "Closing compiled WASM module ref") + _ = a.wasmCompiled.Close(context.Background()) + a.wasmCompiled = nil + } + + // DO NOT close shared wasmRuntime or wasmCache here. + + // Mark client as invalid + a.client = nil +} + +// ensureClientInitialized starts the MCP server process (native or WASM) +// and initializes the client if needed. Attempts restart on failure. func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) { // --- Use override if provided (for testing) --- if a.ClientOverride != nil { @@ -78,97 +177,245 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) { } // --- Check if already initialized (critical section) --- - a.mu.Lock() // Acquire lock *before* checking client + a.mu.Lock() if a.client != nil { - a.mu.Unlock() // Unlock and return if already initialized + a.mu.Unlock() return nil } // --- Client is nil, proceed with initialization *while holding the lock* --- - // Defer unlock to ensure it's released even on errors during init. 
defer a.mu.Unlock() log.Info(ctx, "Initializing MCP client and starting/restarting server process...", "serverPath", McpServerPath) - // Use background context for the command itself - cmd := exec.CommandContext(context.Background(), McpServerPath) + // Clean up any old resources *before* starting new ones + a.cleanup() - var stdin io.WriteCloser - var stdout io.ReadCloser - var stderr strings.Builder + var hostStdinWriter io.WriteCloser + var hostStdoutReader io.ReadCloser + var startErr error + var isWasm bool - stdin, err = cmd.StdinPipe() - if err != nil { - err = fmt.Errorf("failed to get stdin pipe for MCP server: %w", err) - log.Error(ctx, "MCP init/restart failed", "error", err) - return err // defer mu.Unlock() will run - } - - stdout, err = cmd.StdoutPipe() - if err != nil { - err = fmt.Errorf("failed to get stdout pipe for MCP server: %w", err) - _ = stdin.Close() // Clean up stdin pipe - log.Error(ctx, "MCP init/restart failed", "error", err) - return err // defer mu.Unlock() will run - } - - cmd.Stderr = &stderr - - if err = cmd.Start(); err != nil { - err = fmt.Errorf("failed to start MCP server process: %w", err) - _ = stdin.Close() - _ = stdout.Close() - log.Error(ctx, "MCP init/restart failed", "error", err) - return err // defer mu.Unlock() will run - } - - currentPid := cmd.Process.Pid - log.Info(ctx, "MCP server process started/restarted", "pid", currentPid) - - // --- Start monitoring goroutine for *this* process --- - go func(processCmd *exec.Cmd, processStderr *strings.Builder, processPid int) { - waitErr := processCmd.Wait() - a.mu.Lock() - log.Warn("MCP server process exited", "pid", processPid, "error", waitErr, "stderr", processStderr.String()) - if a.cmd == processCmd { // Check if state belongs to the process that just exited - if a.stdin != nil { - _ = a.stdin.Close() - a.stdin = nil - } - a.client = nil - a.cmd = nil - log.Info("MCP agent state cleaned up after process exit", "pid", processPid) + if strings.HasSuffix(McpServerPath, 
".wasm") { + isWasm = true + // Check if shared runtime exists (it should if constructor succeeded for WASM) + if a.wasmRuntime == nil { + startErr = errors.New("shared Wazero runtime not initialized") } else { - log.Debug("MCP agent process exited, but state already updated by newer process", "exitedPid", processPid) + var mod api.Module + var compiled api.Closer // Store compiled module ref per instance + hostStdinWriter, hostStdoutReader, mod, compiled, startErr = a.startWasmModule(ctx) + if startErr == nil { + a.wasmModule = mod + // wasmRuntime is already set + a.wasmCompiled = compiled // Store compiled ref for cleanup + } else { + // Ensure potential partial resources from startWasmModule are closed on error + if mod != nil { + _ = mod.Close(ctx) + } + if compiled != nil { + _ = compiled.Close(ctx) + } + // Do not close shared runtime here + } } - a.mu.Unlock() - }(cmd, &stderr, currentPid) + } else { + isWasm = false + var nativeCmd *exec.Cmd + hostStdinWriter, hostStdoutReader, nativeCmd, startErr = a.startNativeProcess(ctx) + if startErr == nil { + a.cmd = nativeCmd + } + } + + if startErr != nil { + log.Error(ctx, "Failed to start MCP server process/module", "isWasm", isWasm, "error", startErr) + // Ensure pipes are closed if start failed + if hostStdinWriter != nil { + _ = hostStdinWriter.Close() + } + if hostStdoutReader != nil { + _ = hostStdoutReader.Close() + } + // a.cleanup() was already called, specific resources (cmd/wasmModule) are nil + return fmt.Errorf("failed to start MCP server: %w", startErr) + } // --- Initialize MCP client --- - transport := stdio.NewStdioServerTransportWithIO(stdout, stdin) + transport := stdio.NewStdioServerTransportWithIO(hostStdoutReader, hostStdinWriter) clientImpl := mcp.NewClient(transport) initCtx, cancel := context.WithTimeout(context.Background(), initializationTimeout) defer cancel() if _, err = clientImpl.Initialize(initCtx); err != nil { err = fmt.Errorf("failed to initialize MCP client: %w", err) - 
log.Error(ctx, "MCP client initialization failed after process start", "pid", currentPid, "error", err) - if killErr := cmd.Process.Kill(); killErr != nil { - log.Error(ctx, "Failed to kill MCP server process after init failure", "pid", currentPid, "error", killErr) - } + log.Error(ctx, "MCP client initialization failed after process/module start", "isWasm", isWasm, "error", err) + // Cleanup the newly started process/module and pipes as init failed + a.cleanup() + _ = hostStdinWriter.Close() + _ = hostStdoutReader.Close() return err // defer mu.Unlock() will run } // --- Initialization successful, update agent state (still holding lock) --- - a.cmd = cmd - a.stdin = stdin + a.stdin = hostStdinWriter // This is the pipe the agent writes to a.client = clientImpl + // cmd or wasmModule/Runtime/Compiled are already set by the start helpers - log.Info(ctx, "MCP client initialized successfully", "pid", currentPid) + log.Info(ctx, "MCP client initialized successfully", "isWasm", isWasm) // defer mu.Unlock() runs here return nil // Success } +// startNativeProcess starts the MCP server as a native executable. 
+func (a *MCPAgent) startNativeProcess(ctx context.Context) (stdin io.WriteCloser, stdout io.ReadCloser, cmd *exec.Cmd, err error) { + log.Debug(ctx, "Starting native MCP server process", "path", McpServerPath) + cmd = exec.CommandContext(context.Background(), McpServerPath) + + stdin, err = cmd.StdinPipe() + if err != nil { + return nil, nil, nil, fmt.Errorf("native stdin pipe: %w", err) + } + + stdout, err = cmd.StdoutPipe() + if err != nil { + _ = stdin.Close() + return nil, nil, nil, fmt.Errorf("native stdout pipe: %w", err) + } + + var stderr strings.Builder + cmd.Stderr = &stderr + + if err = cmd.Start(); err != nil { + _ = stdin.Close() + _ = stdout.Close() + return nil, nil, nil, fmt.Errorf("native start: %w", err) + } + + currentPid := cmd.Process.Pid + log.Info(ctx, "Native MCP server process started", "pid", currentPid) + + // Start monitoring goroutine + go func(processCmd *exec.Cmd, processStderr *strings.Builder, processPid int) { + waitErr := processCmd.Wait() // Wait for the process to exit + a.mu.Lock() + log.Warn("Native MCP server process exited", "pid", processPid, "error", waitErr, "stderr", processStderr.String()) + // Check if the cmd matches the one we are monitoring before cleaning up + if a.cmd == processCmd { + a.cleanup() // Use the central cleanup function + log.Info("MCP agent state cleaned up after native process exit", "pid", processPid) + } else { + log.Debug("Native MCP agent process exited, but state already updated or cmd mismatch", "exitedPid", processPid) + } + a.mu.Unlock() + }(cmd, &stderr, currentPid) + + // Return the pipes connected to the process and the Cmd object + return stdin, stdout, cmd, nil +} + +// startWasmModule loads and starts the MCP server as a WASM module using the agent's shared Wazero runtime. 
+func (a *MCPAgent) startWasmModule(ctx context.Context) (hostStdinWriter io.WriteCloser, hostStdoutReader io.ReadCloser, mod api.Module, compiled api.Closer, err error) { + log.Debug(ctx, "Loading WASM MCP server module", "path", McpServerPath) + wasmBytes, err := os.ReadFile(McpServerPath) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("read wasm file: %w", err) + } + + // Create pipes for stdio redirection + wasmStdinReader, hostStdinWriter, err := os.Pipe() + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("wasm stdin pipe: %w", err) + } + hostStdoutReader, wasmStdoutWriter, err := os.Pipe() + if err != nil { + _ = wasmStdinReader.Close() + _ = hostStdinWriter.Close() + return nil, nil, nil, nil, fmt.Errorf("wasm stdout pipe: %w", err) + } + + // Use the SHARED runtime from the agent struct + runtime := a.wasmRuntime.(wazero.Runtime) // Type assert to get underlying Runtime + // WASI is already instantiated on the shared runtime + + config := wazero.NewModuleConfig(). + WithStdin(wasmStdinReader). + WithStdout(wasmStdoutWriter). + WithStderr(os.Stderr). + WithArgs(McpServerPath) + + log.Debug(ctx, "Compiling WASM module (using cache if enabled)...") + // Compile module using the shared runtime (which uses the configured cache) + compiledModule, err := runtime.CompileModule(ctx, wasmBytes) + if err != nil { + _ = wasmStdinReader.Close() + _ = hostStdinWriter.Close() + _ = hostStdoutReader.Close() + _ = wasmStdoutWriter.Close() + return nil, nil, nil, nil, fmt.Errorf("compile wasm module: %w", err) + } + // Defer closing compiled module in case of errors later in this function. 
+ shouldCloseOnError := true + defer func() { + if shouldCloseOnError && compiledModule != nil { + _ = compiledModule.Close(context.Background()) + } + }() + + log.Info(ctx, "Instantiating WASM module (will run _start)...") + var instance api.Module + instanceErrChan := make(chan error, 1) + go func() { + var instantiateErr error + instance, instantiateErr = runtime.InstantiateModule(context.Background(), compiledModule, config) + instanceErrChan <- instantiateErr + }() + + // Wait briefly for immediate instantiation errors + select { + case instantiateErr := <-instanceErrChan: + if instantiateErr != nil { + log.Error(ctx, "Failed to instantiate WASM module", "error", instantiateErr) + _ = wasmStdinReader.Close() + _ = hostStdinWriter.Close() + _ = hostStdoutReader.Close() + _ = wasmStdoutWriter.Close() + // compiledModule closed by defer + return nil, nil, nil, nil, fmt.Errorf("instantiate wasm module: %w", instantiateErr) + } + log.Warn(ctx, "WASM module instantiation returned (exited?) unexpectedly quickly.") + case <-time.After(2 * time.Second): + log.Debug(ctx, "WASM module instantiation likely blocking (server running), proceeding...") + } + + // Start a monitoring goroutine for WASM module exit/error + go func(modToMonitor api.Module, compiledToClose api.Closer, errChan chan error) { + instantiateErr := <-errChan + + a.mu.Lock() + log.Warn("WASM module exited/errored", "error", instantiateErr) + // Check if the module currently stored in the agent is the one we were monitoring. + // Compare module instance directly. Instance might be nil if instantiation failed. 
+ if a.wasmModule != nil && a.wasmModule == modToMonitor { + a.cleanup() // This will close the module instance and compiled ref + log.Info("MCP agent state cleaned up after WASM module exit/error") + } else { + log.Debug("WASM module exited, but state already updated or module mismatch") + // Manually close the compiled module ref associated with this specific instance + // as cleanup() won't if a.wasmModule doesn't match. + if compiledToClose != nil { + _ = compiledToClose.Close(context.Background()) + } + } + a.mu.Unlock() + }(instance, compiledModule, instanceErrChan) // Pass necessary refs + + // Success: prevent deferred cleanup, return resources needed by caller + shouldCloseOnError = false + return hostStdinWriter, hostStdoutReader, instance, compiledModule, nil // Return instance and compiled module +} + // ArtistArgs defines the structure for MCP tool arguments requiring artist info. // Exported for use in tests. type ArtistArgs struct { diff --git a/go.mod b/go.mod index 7229553f6..88b148453 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/lestrrat-go/jwx/v2 v2.1.4 github.com/matoous/go-nanoid/v2 v2.1.0 github.com/mattn/go-sqlite3 v1.14.27 + github.com/metoro-io/mcp-golang v0.11.0 github.com/microcosm-cc/bluemonday v1.0.27 github.com/mileusna/useragent v1.3.5 github.com/onsi/ginkgo/v2 v2.23.4 @@ -53,6 +54,7 @@ require ( github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.1 github.com/stretchr/testify v1.10.0 + github.com/tetratelabs/wazero v1.9.0 github.com/unrolled/secure v1.17.0 github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 go.uber.org/goleak v1.3.0 @@ -100,7 +102,6 @@ require ( github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/option v1.0.1 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/metoro-io/mcp-golang v0.11.0 // indirect github.com/mfridman/interpolate v0.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 
github.com/ogier/pflag v0.0.1 // indirect diff --git a/go.sum b/go.sum index ffc354925..c8e85a204 100644 --- a/go.sum +++ b/go.sum @@ -247,6 +247,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= +github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=