perf: pre-compile WASM module for MCP agent

Modified the MCP agent constructor to pre-compile the WASM module when detected. This shifts the costly compilation step out of the first API request path.

The `MCPWasm` implementation now stores the `wazero.CompiledModule` provided by the constructor and uses it directly for instantiation via `runtime.InstantiateModule()` when the agent is first used or restarted. This significantly speeds up the initialization during the first request.

Updated tests and cleanup logic to accommodate the shared nature of the pre-compiled module.
This commit is contained in:
Deluan 2025-04-19 19:23:23 -04:00
parent 8660cb4fff
commit 97b101685e
2 changed files with 64 additions and 68 deletions

View File

@ -107,8 +107,29 @@ func mcpConstructor(ds model.DataStore) agents.Interface {
return nil // Fatal error: WASI required
}
agentImpl = newMCPWasm(runtime, cache)
log.Info(ctx, "Shared Wazero runtime, WASI, cache, and host functions initialized for MCP agent")
// Compile the module
log.Debug(ctx, "Pre-compiling WASM module...", "path", McpServerPath)
wasmBytes, err := os.ReadFile(McpServerPath)
if err != nil {
log.Error(ctx, "Failed to read WASM module file, disabling agent", "path", McpServerPath, "error", err)
_ = runtime.Close(ctx)
if cache != nil {
_ = cache.Close(ctx)
}
return nil
}
compiledModule, err := runtime.CompileModule(ctx, wasmBytes)
if err != nil {
log.Error(ctx, "Failed to pre-compile WASM module, disabling agent", "path", McpServerPath, "error", err)
_ = runtime.Close(ctx)
if cache != nil {
_ = cache.Close(ctx)
}
return nil
}
agentImpl = newMCPWasm(runtime, cache, compiledModule)
log.Info(ctx, "Shared Wazero runtime, WASI, cache, host functions initialized, and module pre-compiled for MCP agent")
} else {
log.Info("Configuring MCP agent for native execution", "path", McpServerPath)
@ -132,7 +153,7 @@ func NewAgentForTesting(mockClient mcpClient) agents.Interface {
// For WASM testing, we might not need the full runtime setup,
// just the struct. Pass nil for shared resources for now.
// We rely on the mockClient being used before any real WASM interaction.
wasmImpl := newMCPWasm(nil, nil)
wasmImpl := newMCPWasm(nil, nil, nil)
wasmImpl.ClientOverride = mockClient
agentImpl = wasmImpl
} else {

View File

@ -27,18 +27,21 @@ type MCPWasm struct {
client mcpClient
// Shared resources (passed in, not owned by this struct)
wasmRuntime api.Closer // Closer for the shared Wazero Runtime
wasmCache wazero.CompilationCache // Shared Compilation Cache (can be nil)
wasmRuntime api.Closer // Shared Wazero Runtime
wasmCache wazero.CompilationCache // Shared Compilation Cache (can be nil)
preCompiledModule wazero.CompiledModule // Pre-compiled module from constructor
// ClientOverride allows injecting a mock client for testing this specific implementation.
ClientOverride mcpClient // TODO: Consider if this is the best way to test
}
// newMCPWasm creates a new instance of the WASM MCP agent implementation.
func newMCPWasm(runtime api.Closer, cache wazero.CompilationCache) *MCPWasm {
// It stores the shared runtime, cache, and the pre-compiled module.
func newMCPWasm(runtime api.Closer, cache wazero.CompilationCache, compiledModule wazero.CompiledModule) *MCPWasm {
return &MCPWasm{
wasmRuntime: runtime,
wasmCache: cache,
wasmRuntime: runtime,
wasmCache: cache,
preCompiledModule: compiledModule,
}
}
@ -81,9 +84,9 @@ func (w *MCPWasm) ensureClientInitialized_locked(ctx context.Context) error {
return errors.New("shared Wazero runtime not initialized for MCPWasm")
}
hostStdinWriter, hostStdoutReader, mod, compiled, startErr := w.startModule_locked(ctx)
if startErr != nil {
log.Error(ctx, "Failed to start WASM MCP server module", "error", startErr)
hostStdinWriter, hostStdoutReader, mod, err := w.startModule_locked(ctx)
if err != nil {
log.Error(ctx, "Failed to start WASM MCP server module", "error", err)
// Ensure pipes are closed if start failed
if hostStdinWriter != nil {
_ = hostStdinWriter.Close()
@ -92,7 +95,7 @@ func (w *MCPWasm) ensureClientInitialized_locked(ctx context.Context) error {
_ = hostStdoutReader.Close()
}
// startModule_locked handles cleanup of mod/compiled on error
return fmt.Errorf("failed to start WASM MCP server: %w", startErr)
return fmt.Errorf("failed to start WASM MCP server: %w", err)
}
transport := stdio.NewStdioServerTransportWithIO(hostStdoutReader, hostStdinWriter)
@ -104,8 +107,8 @@ func (w *MCPWasm) ensureClientInitialized_locked(ctx context.Context) error {
err := fmt.Errorf("failed to initialize WASM MCP client: %w", initErr)
log.Error(ctx, "WASM MCP client initialization failed", "error", err)
// Cleanup the newly started module and close pipes
w.wasmModule = mod // Temporarily set so cleanup can close it
w.wasmCompiled = compiled // Temporarily set so cleanup can close it
w.wasmModule = mod // Temporarily set so cleanup can close it
w.wasmCompiled = nil // We don't store the compiled instance ref anymore, just the module instance
w.cleanupResources_locked()
if hostStdinWriter != nil {
_ = hostStdinWriter.Close()
@ -117,7 +120,7 @@ func (w *MCPWasm) ensureClientInitialized_locked(ctx context.Context) error {
}
w.wasmModule = mod
w.wasmCompiled = compiled
w.wasmCompiled = nil // We don't store the compiled instance ref anymore, just the module instance
w.stdin = hostStdinWriter
w.client = clientImpl
@ -190,32 +193,25 @@ func (w *MCPWasm) cleanupResources_locked() {
cancel()
w.wasmModule = nil
}
// Close the compiled module reference for this instance
if w.wasmCompiled != nil {
log.Debug(context.Background(), "Closing compiled WASM module ref")
if err := w.wasmCompiled.Close(context.Background()); err != nil {
log.Error(context.Background(), "Failed to close compiled WASM module ref", "error", err)
}
w.wasmCompiled = nil
}
// Mark client as invalid
w.client = nil
// DO NOT close w.wasmCompiled (instance ref)
// DO NOT close w.preCompiledModule (shared pre-compiled code)
// DO NOT CLOSE w.wasmRuntime or w.wasmCache here!
w.client = nil
}
// startModule loads and starts the MCP server as a WASM module.
// It now uses the pre-compiled module.
// MUST be called with the mutex HELD.
func (w *MCPWasm) startModule_locked(ctx context.Context) (hostStdinWriter io.WriteCloser, hostStdoutReader io.ReadCloser, mod api.Module, compiled api.Closer, err error) {
log.Debug(ctx, "Loading WASM MCP server module", "path", McpServerPath)
wasmBytes, err := os.ReadFile(McpServerPath)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("read wasm file: %w", err)
func (w *MCPWasm) startModule_locked(ctx context.Context) (hostStdinWriter io.WriteCloser, hostStdoutReader io.ReadCloser, mod api.Module, err error) {
// Check for pre-compiled module
if w.preCompiledModule == nil {
return nil, nil, nil, errors.New("pre-compiled WASM module is nil")
}
// Create pipes for stdio redirection
wasmStdinReader, hostStdinWriter, err := os.Pipe()
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("wasm stdin pipe: %w", err)
return nil, nil, nil, fmt.Errorf("wasm stdin pipe: %w", err)
}
// Defer close pipes on error exit
shouldClosePipesOnError := true
@ -233,7 +229,7 @@ func (w *MCPWasm) startModule_locked(ctx context.Context) (hostStdinWriter io.Wr
hostStdoutReader, wasmStdoutWriter, err := os.Pipe()
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("wasm stdout pipe: %w", err)
return nil, nil, nil, fmt.Errorf("wasm stdout pipe: %w", err)
}
// Defer close pipes on error exit
defer func() {
@ -250,11 +246,10 @@ func (w *MCPWasm) startModule_locked(ctx context.Context) (hostStdinWriter io.Wr
// Use the SHARDED runtime from the agent struct
runtime, ok := w.wasmRuntime.(wazero.Runtime)
if !ok || runtime == nil {
return nil, nil, nil, nil, errors.New("wasmRuntime is not initialized or not a wazero.Runtime")
return nil, nil, nil, errors.New("wasmRuntime is not initialized or not a wazero.Runtime")
}
// Prepare module configuration
// Host functions and WASI are already part of the shared runtime
// Prepare module configuration (host funcs/WASI already instantiated on runtime)
config := wazero.NewModuleConfig().
WithStdin(wasmStdinReader).
WithStdout(wasmStdoutWriter).
@ -262,27 +257,13 @@ func (w *MCPWasm) startModule_locked(ctx context.Context) (hostStdinWriter io.Wr
WithArgs(McpServerPath).
WithFS(os.DirFS("/")) // Keep FS access for now
log.Debug(ctx, "Compiling WASM module (using cache if enabled)...")
// Compile module using the shared runtime
compiledModule, err := runtime.CompileModule(ctx, wasmBytes)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("compile wasm module: %w", err)
}
// Defer closing compiled module only if an error occurs later in this function.
shouldCloseCompiledOnError := true
defer func() {
if shouldCloseCompiledOnError && compiledModule != nil {
_ = compiledModule.Close(context.Background())
}
}()
log.Info(ctx, "Instantiating WASM module (will run _start)...")
var instance api.Module
log.Info(ctx, "Instantiating pre-compiled WASM module (will run _start)...")
var moduleInstance api.Module
instanceErrChan := make(chan error, 1)
go func() {
var instantiateErr error
// Use context.Background() for the module's main execution context
instance, instantiateErr = runtime.InstantiateModule(context.Background(), compiledModule, config)
moduleInstance, instantiateErr = runtime.InstantiateModule(context.Background(), w.preCompiledModule, config)
instanceErrChan <- instantiateErr
}()
@ -290,19 +271,17 @@ func (w *MCPWasm) startModule_locked(ctx context.Context) (hostStdinWriter io.Wr
select {
case instantiateErr := <-instanceErrChan:
if instantiateErr != nil {
log.Error(ctx, "Failed to instantiate WASM module", "error", instantiateErr)
// compiledModule closed by defer
// pipes closed by defer
return nil, nil, nil, nil, fmt.Errorf("instantiate wasm module: %w", instantiateErr)
log.Error(ctx, "Failed to instantiate pre-compiled WASM module", "error", instantiateErr)
// Pipes closed by defer
return nil, nil, nil, fmt.Errorf("instantiate wasm module: %w", instantiateErr)
}
log.Warn(ctx, "WASM module instantiation returned (exited?) immediately without error.")
case <-time.After(2 * time.Second):
case <-time.After(1 * time.Second): // Shorter wait now, instantiation should be faster
log.Debug(ctx, "WASM module instantiation likely blocking (server running), proceeding...")
}
// Start a monitoring goroutine for WASM module exit/error
// Pass required values to the goroutine closure
go func(instanceToMonitor api.Module, compiledToClose api.Closer, errChan chan error) {
go func(instanceToMonitor api.Module, errChan chan error) {
// This blocks until the instance created by InstantiateModule exits or errors.
instantiateErr := <-errChan
@ -314,17 +293,13 @@ func (w *MCPWasm) startModule_locked(ctx context.Context) (hostStdinWriter io.Wr
w.cleanupResources_locked() // Use the locked version
log.Info("MCP WASM agent state cleaned up after module exit/error")
} else {
log.Debug("WASM module exited, but state already updated/module mismatch. Explicitly closing this instance's compiled ref.")
// Manually close the compiled module ref associated with *this specific instance*
if compiledToClose != nil {
_ = compiledToClose.Close(context.Background())
}
log.Debug("WASM module exited, but state already updated/module mismatch. No cleanup needed by this goroutine.")
// No need to close anything here, the pre-compiled module is shared
}
w.mu.Unlock()
}(instance, compiledModule, instanceErrChan)
}(moduleInstance, instanceErrChan)
// Success: prevent deferred cleanup of pipes and compiled module
// Success: prevent deferred cleanup of pipes
shouldClosePipesOnError = false
shouldCloseCompiledOnError = false
return hostStdinWriter, hostStdoutReader, instance, compiledModule, nil
return hostStdinWriter, hostStdoutReader, moduleInstance, nil
}