mirror of
https://github.com/navidrome/navidrome.git
synced 2025-05-08 22:31:07 +03:00
refactor: separate native and WASM process handling in MCPAgent
- Moved the native process handling logic from mcp_agent.go to a new file mcp_process_native.go for better organization. - Introduced a new file mcp_host_functions.go to define and register host functions for WASM modules. - Updated MCPAgent to ensure proper initialization and cleanup of both native and WASM processes, enhancing code clarity and maintainability. - Added comments to clarify the purpose of changes and ensure future developers understand the structure.
This commit is contained in:
parent
674129a34b
commit
73da7550d6
@ -27,7 +27,7 @@ import (
|
|||||||
// Exported constants for testing
|
// Exported constants for testing
|
||||||
const (
|
const (
|
||||||
McpAgentName = "mcp"
|
McpAgentName = "mcp"
|
||||||
McpServerPath = "/Users/deluan/Development/navidrome/plugins-mcp/mcp-server"
|
McpServerPath = "/Users/deluan/Development/navidrome/plugins-mcp/mcp-server.wasm"
|
||||||
McpToolNameGetBio = "get_artist_biography"
|
McpToolNameGetBio = "get_artist_biography"
|
||||||
McpToolNameGetURL = "get_artist_url"
|
McpToolNameGetURL = "get_artist_url"
|
||||||
initializationTimeout = 10 * time.Second
|
initializationTimeout = 10 * time.Second
|
||||||
@ -73,7 +73,7 @@ func mcpConstructor(ds model.DataStore) agents.Interface {
|
|||||||
|
|
||||||
a := &MCPAgent{}
|
a := &MCPAgent{}
|
||||||
|
|
||||||
// If it's a WASM path, pre-initialize the shared Wazero runtime and cache.
|
// If it's a WASM path, pre-initialize the shared Wazero runtime, cache, and host functions.
|
||||||
if strings.HasSuffix(McpServerPath, ".wasm") {
|
if strings.HasSuffix(McpServerPath, ".wasm") {
|
||||||
ctx := context.Background() // Use background context for setup
|
ctx := context.Background() // Use background context for setup
|
||||||
cacheDir := filepath.Join(conf.Server.DataFolder, "cache", "wazero")
|
cacheDir := filepath.Join(conf.Server.DataFolder, "cache", "wazero")
|
||||||
@ -84,7 +84,6 @@ func mcpConstructor(ds model.DataStore) agents.Interface {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "Failed to create Wazero compilation cache, WASM caching disabled", "path", cacheDir, "error", err)
|
log.Error(ctx, "Failed to create Wazero compilation cache, WASM caching disabled", "path", cacheDir, "error", err)
|
||||||
} else {
|
} else {
|
||||||
// Store the specific cache type
|
|
||||||
a.wasmCache = cache
|
a.wasmCache = cache
|
||||||
log.Info(ctx, "Wazero compilation cache enabled", "path", cacheDir)
|
log.Info(ctx, "Wazero compilation cache enabled", "path", cacheDir)
|
||||||
}
|
}
|
||||||
@ -93,12 +92,24 @@ func mcpConstructor(ds model.DataStore) agents.Interface {
|
|||||||
// Create runtime config, adding cache if it was created successfully
|
// Create runtime config, adding cache if it was created successfully
|
||||||
runtimeConfig := wazero.NewRuntimeConfig()
|
runtimeConfig := wazero.NewRuntimeConfig()
|
||||||
if a.wasmCache != nil {
|
if a.wasmCache != nil {
|
||||||
// Use the stored cache directly (already correct type)
|
|
||||||
runtimeConfig = runtimeConfig.WithCompilationCache(a.wasmCache)
|
runtimeConfig = runtimeConfig.WithCompilationCache(a.wasmCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the shared runtime
|
// Create the shared runtime
|
||||||
runtime := wazero.NewRuntimeWithConfig(ctx, runtimeConfig)
|
runtime := wazero.NewRuntimeWithConfig(ctx, runtimeConfig)
|
||||||
|
|
||||||
|
// --- Register Host Functions --- Must happen BEFORE WASI instantiation if WASI needs them?
|
||||||
|
// Actually, WASI instantiation is separate from host func instantiation.
|
||||||
|
if err := registerHostFunctions(ctx, runtime); err != nil {
|
||||||
|
// Error already logged by registerHostFunctions
|
||||||
|
_ = runtime.Close(ctx)
|
||||||
|
if a.wasmCache != nil {
|
||||||
|
_ = a.wasmCache.Close(ctx)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// --- End Host Function Registration ---
|
||||||
|
|
||||||
a.wasmRuntime = runtime // Store the runtime closer
|
a.wasmRuntime = runtime // Store the runtime closer
|
||||||
|
|
||||||
// Instantiate WASI onto the shared runtime. If this fails, the agent is unusable for WASM.
|
// Instantiate WASI onto the shared runtime. If this fails, the agent is unusable for WASM.
|
||||||
@ -107,11 +118,11 @@ func mcpConstructor(ds model.DataStore) agents.Interface {
|
|||||||
// Close runtime and cache if WASI fails
|
// Close runtime and cache if WASI fails
|
||||||
_ = runtime.Close(ctx)
|
_ = runtime.Close(ctx)
|
||||||
if a.wasmCache != nil {
|
if a.wasmCache != nil {
|
||||||
_ = a.wasmCache.Close(ctx) // Use Close(ctx)
|
_ = a.wasmCache.Close(ctx)
|
||||||
}
|
}
|
||||||
return nil // Cannot proceed if WASI fails
|
return nil // Cannot proceed if WASI fails
|
||||||
}
|
}
|
||||||
log.Info(ctx, "Shared Wazero runtime and WASI initialized for MCP agent")
|
log.Info(ctx, "Shared Wazero runtime, WASI, and host functions initialized for MCP agent")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("MCP Agent created, server will be started on first request", "serverPath", McpServerPath)
|
log.Info("MCP Agent created, server will be started on first request", "serverPath", McpServerPath)
|
||||||
@ -133,9 +144,10 @@ func (a *MCPAgent) cleanup() {
|
|||||||
// Clean up native process if it exists
|
// Clean up native process if it exists
|
||||||
if a.cmd != nil && a.cmd.Process != nil {
|
if a.cmd != nil && a.cmd.Process != nil {
|
||||||
log.Debug(context.Background(), "Killing native MCP process", "pid", a.cmd.Process.Pid)
|
log.Debug(context.Background(), "Killing native MCP process", "pid", a.cmd.Process.Pid)
|
||||||
if err := a.cmd.Process.Kill(); err != nil {
|
if err := a.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
|
||||||
log.Error(context.Background(), "Failed to kill native process", "pid", a.cmd.Process.Pid, "error", err)
|
log.Error(context.Background(), "Failed to kill native process", "pid", a.cmd.Process.Pid, "error", err)
|
||||||
}
|
}
|
||||||
|
// Wait might return an error if already killed/exited, ignore it.
|
||||||
_ = a.cmd.Wait()
|
_ = a.cmd.Wait()
|
||||||
a.cmd = nil
|
a.cmd = nil
|
||||||
}
|
}
|
||||||
@ -144,7 +156,10 @@ func (a *MCPAgent) cleanup() {
|
|||||||
log.Debug(context.Background(), "Closing WASM module instance")
|
log.Debug(context.Background(), "Closing WASM module instance")
|
||||||
ctxClose, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctxClose, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
if err := a.wasmModule.Close(ctxClose); err != nil {
|
if err := a.wasmModule.Close(ctxClose); err != nil {
|
||||||
log.Error(context.Background(), "Failed to close WASM module instance", "error", err)
|
// Ignore context deadline exceeded as it means close was successful but slow
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
log.Error(context.Background(), "Failed to close WASM module instance", "error", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
a.wasmModule = nil
|
a.wasmModule = nil
|
||||||
@ -152,7 +167,10 @@ func (a *MCPAgent) cleanup() {
|
|||||||
// Clean up compiled module ref for this instance
|
// Clean up compiled module ref for this instance
|
||||||
if a.wasmCompiled != nil {
|
if a.wasmCompiled != nil {
|
||||||
log.Debug(context.Background(), "Closing compiled WASM module ref")
|
log.Debug(context.Background(), "Closing compiled WASM module ref")
|
||||||
_ = a.wasmCompiled.Close(context.Background())
|
// Use background context, Close should be quick
|
||||||
|
if err := a.wasmCompiled.Close(context.Background()); err != nil {
|
||||||
|
log.Error(context.Background(), "Failed to close compiled WASM module ref", "error", err)
|
||||||
|
}
|
||||||
a.wasmCompiled = nil
|
a.wasmCompiled = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,16 +225,14 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
|
|||||||
hostStdinWriter, hostStdoutReader, mod, compiled, startErr = a.startWasmModule(ctx)
|
hostStdinWriter, hostStdoutReader, mod, compiled, startErr = a.startWasmModule(ctx)
|
||||||
if startErr == nil {
|
if startErr == nil {
|
||||||
a.wasmModule = mod
|
a.wasmModule = mod
|
||||||
// wasmRuntime is already set
|
|
||||||
a.wasmCompiled = compiled // Store compiled ref for cleanup
|
a.wasmCompiled = compiled // Store compiled ref for cleanup
|
||||||
} else {
|
} else {
|
||||||
// Ensure potential partial resources from startWasmModule are closed on error
|
// Ensure potential partial resources from startWasmModule are closed on error
|
||||||
|
// startWasmModule's deferred cleanup should handle pipes and compiled module.
|
||||||
|
// Mod instance might need closing if instantiation partially succeeded before erroring.
|
||||||
if mod != nil {
|
if mod != nil {
|
||||||
_ = mod.Close(ctx)
|
_ = mod.Close(ctx)
|
||||||
}
|
}
|
||||||
if compiled != nil {
|
|
||||||
_ = compiled.Close(ctx)
|
|
||||||
}
|
|
||||||
// Do not close shared runtime here
|
// Do not close shared runtime here
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -231,7 +247,7 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
|
|||||||
|
|
||||||
if startErr != nil {
|
if startErr != nil {
|
||||||
log.Error(ctx, "Failed to start MCP server process/module", "isWasm", isWasm, "error", startErr)
|
log.Error(ctx, "Failed to start MCP server process/module", "isWasm", isWasm, "error", startErr)
|
||||||
// Ensure pipes are closed if start failed
|
// Ensure pipes are closed if start failed (start functions might have deferred closes, but belt-and-suspenders)
|
||||||
if hostStdinWriter != nil {
|
if hostStdinWriter != nil {
|
||||||
_ = hostStdinWriter.Close()
|
_ = hostStdinWriter.Close()
|
||||||
}
|
}
|
||||||
@ -242,19 +258,24 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
|
|||||||
return fmt.Errorf("failed to start MCP server: %w", startErr)
|
return fmt.Errorf("failed to start MCP server: %w", startErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Initialize MCP client ---
|
// --- Initialize MCP client --- (Ensure stdio transport import)
|
||||||
transport := stdio.NewStdioServerTransportWithIO(hostStdoutReader, hostStdinWriter)
|
transport := stdio.NewStdioServerTransportWithIO(hostStdoutReader, hostStdinWriter)
|
||||||
clientImpl := mcp.NewClient(transport)
|
clientImpl := mcp.NewClient(transport)
|
||||||
|
|
||||||
initCtx, cancel := context.WithTimeout(context.Background(), initializationTimeout)
|
initCtx, cancel := context.WithTimeout(context.Background(), initializationTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if _, err = clientImpl.Initialize(initCtx); err != nil {
|
if _, initErr := clientImpl.Initialize(initCtx); initErr != nil {
|
||||||
err = fmt.Errorf("failed to initialize MCP client: %w", err)
|
err = fmt.Errorf("failed to initialize MCP client: %w", initErr)
|
||||||
log.Error(ctx, "MCP client initialization failed after process/module start", "isWasm", isWasm, "error", err)
|
log.Error(ctx, "MCP client initialization failed after process/module start", "isWasm", isWasm, "error", err)
|
||||||
// Cleanup the newly started process/module and pipes as init failed
|
// Cleanup the newly started process/module and pipes as init failed
|
||||||
a.cleanup()
|
a.cleanup() // This should handle cmd/wasmModule
|
||||||
_ = hostStdinWriter.Close()
|
// Close the pipes directly as cleanup() doesn't know about them
|
||||||
_ = hostStdoutReader.Close()
|
if hostStdinWriter != nil {
|
||||||
|
_ = hostStdinWriter.Close()
|
||||||
|
}
|
||||||
|
if hostStdoutReader != nil {
|
||||||
|
_ = hostStdoutReader.Close()
|
||||||
|
}
|
||||||
return err // defer mu.Unlock() will run
|
return err // defer mu.Unlock() will run
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -268,52 +289,7 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
|
|||||||
return nil // Success
|
return nil // Success
|
||||||
}
|
}
|
||||||
|
|
||||||
// startNativeProcess starts the MCP server as a native executable.
|
// startNativeProcess was moved to mcp_process_native.go
|
||||||
func (a *MCPAgent) startNativeProcess(ctx context.Context) (stdin io.WriteCloser, stdout io.ReadCloser, cmd *exec.Cmd, err error) {
|
|
||||||
log.Debug(ctx, "Starting native MCP server process", "path", McpServerPath)
|
|
||||||
cmd = exec.CommandContext(context.Background(), McpServerPath)
|
|
||||||
|
|
||||||
stdin, err = cmd.StdinPipe()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, fmt.Errorf("native stdin pipe: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stdout, err = cmd.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
_ = stdin.Close()
|
|
||||||
return nil, nil, nil, fmt.Errorf("native stdout pipe: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var stderr strings.Builder
|
|
||||||
cmd.Stderr = &stderr
|
|
||||||
|
|
||||||
if err = cmd.Start(); err != nil {
|
|
||||||
_ = stdin.Close()
|
|
||||||
_ = stdout.Close()
|
|
||||||
return nil, nil, nil, fmt.Errorf("native start: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
currentPid := cmd.Process.Pid
|
|
||||||
log.Info(ctx, "Native MCP server process started", "pid", currentPid)
|
|
||||||
|
|
||||||
// Start monitoring goroutine
|
|
||||||
go func(processCmd *exec.Cmd, processStderr *strings.Builder, processPid int) {
|
|
||||||
waitErr := processCmd.Wait() // Wait for the process to exit
|
|
||||||
a.mu.Lock()
|
|
||||||
log.Warn("Native MCP server process exited", "pid", processPid, "error", waitErr, "stderr", processStderr.String())
|
|
||||||
// Check if the cmd matches the one we are monitoring before cleaning up
|
|
||||||
if a.cmd == processCmd {
|
|
||||||
a.cleanup() // Use the central cleanup function
|
|
||||||
log.Info("MCP agent state cleaned up after native process exit", "pid", processPid)
|
|
||||||
} else {
|
|
||||||
log.Debug("Native MCP agent process exited, but state already updated or cmd mismatch", "exitedPid", processPid)
|
|
||||||
}
|
|
||||||
a.mu.Unlock()
|
|
||||||
}(cmd, &stderr, currentPid)
|
|
||||||
|
|
||||||
// Return the pipes connected to the process and the Cmd object
|
|
||||||
return stdin, stdout, cmd, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// startWasmModule loads and starts the MCP server as a WASM module using the agent's shared Wazero runtime.
|
// startWasmModule loads and starts the MCP server as a WASM module using the agent's shared Wazero runtime.
|
||||||
func (a *MCPAgent) startWasmModule(ctx context.Context) (hostStdinWriter io.WriteCloser, hostStdoutReader io.ReadCloser, mod api.Module, compiled api.Closer, err error) {
|
func (a *MCPAgent) startWasmModule(ctx context.Context) (hostStdinWriter io.WriteCloser, hostStdoutReader io.ReadCloser, mod api.Module, compiled api.Closer, err error) {
|
||||||
@ -328,15 +304,34 @@ func (a *MCPAgent) startWasmModule(ctx context.Context) (hostStdinWriter io.Writ
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, nil, fmt.Errorf("wasm stdin pipe: %w", err)
|
return nil, nil, nil, nil, fmt.Errorf("wasm stdin pipe: %w", err)
|
||||||
}
|
}
|
||||||
|
// Defer close pipes on error exit
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = wasmStdinReader.Close()
|
||||||
|
_ = hostStdinWriter.Close()
|
||||||
|
// hostStdoutReader and wasmStdoutWriter handled below
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
hostStdoutReader, wasmStdoutWriter, err := os.Pipe()
|
hostStdoutReader, wasmStdoutWriter, err := os.Pipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = wasmStdinReader.Close()
|
_ = wasmStdinReader.Close() // Close previous pipe
|
||||||
_ = hostStdinWriter.Close()
|
_ = hostStdinWriter.Close() // Close previous pipe
|
||||||
return nil, nil, nil, nil, fmt.Errorf("wasm stdout pipe: %w", err)
|
return nil, nil, nil, nil, fmt.Errorf("wasm stdout pipe: %w", err)
|
||||||
}
|
}
|
||||||
|
// Defer close pipes on error exit
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = hostStdoutReader.Close()
|
||||||
|
_ = wasmStdoutWriter.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Use the SHARDED runtime from the agent struct
|
// Use the SHARDED runtime from the agent struct
|
||||||
runtime := a.wasmRuntime.(wazero.Runtime) // Type assert to get underlying Runtime
|
runtime, ok := a.wasmRuntime.(wazero.Runtime)
|
||||||
|
if !ok || runtime == nil {
|
||||||
|
return nil, nil, nil, nil, errors.New("wasmRuntime is not initialized or not a wazero.Runtime")
|
||||||
|
}
|
||||||
// WASI is already instantiated on the shared runtime
|
// WASI is already instantiated on the shared runtime
|
||||||
|
|
||||||
config := wazero.NewModuleConfig().
|
config := wazero.NewModuleConfig().
|
||||||
@ -353,16 +348,13 @@ func (a *MCPAgent) startWasmModule(ctx context.Context) (hostStdinWriter io.Writ
|
|||||||
// Compile module using the shared runtime (which uses the configured cache)
|
// Compile module using the shared runtime (which uses the configured cache)
|
||||||
compiledModule, err := runtime.CompileModule(ctx, wasmBytes)
|
compiledModule, err := runtime.CompileModule(ctx, wasmBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = wasmStdinReader.Close()
|
|
||||||
_ = hostStdinWriter.Close()
|
|
||||||
_ = hostStdoutReader.Close()
|
|
||||||
_ = wasmStdoutWriter.Close()
|
|
||||||
return nil, nil, nil, nil, fmt.Errorf("compile wasm module: %w", err)
|
return nil, nil, nil, nil, fmt.Errorf("compile wasm module: %w", err)
|
||||||
}
|
}
|
||||||
// Defer closing compiled module in case of errors later in this function.
|
// Defer closing compiled module in case of errors later in this function.
|
||||||
shouldCloseOnError := true
|
// Caller (ensureClientInitialized) is responsible for closing on success.
|
||||||
|
shouldCloseCompiledOnError := true
|
||||||
defer func() {
|
defer func() {
|
||||||
if shouldCloseOnError && compiledModule != nil {
|
if shouldCloseCompiledOnError && compiledModule != nil {
|
||||||
_ = compiledModule.Close(context.Background())
|
_ = compiledModule.Close(context.Background())
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -372,6 +364,7 @@ func (a *MCPAgent) startWasmModule(ctx context.Context) (hostStdinWriter io.Writ
|
|||||||
instanceErrChan := make(chan error, 1)
|
instanceErrChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
var instantiateErr error
|
var instantiateErr error
|
||||||
|
// Use context.Background() for the module's main execution context
|
||||||
instance, instantiateErr = runtime.InstantiateModule(context.Background(), compiledModule, config)
|
instance, instantiateErr = runtime.InstantiateModule(context.Background(), compiledModule, config)
|
||||||
instanceErrChan <- instantiateErr
|
instanceErrChan <- instantiateErr
|
||||||
}()
|
}()
|
||||||
@ -381,33 +374,36 @@ func (a *MCPAgent) startWasmModule(ctx context.Context) (hostStdinWriter io.Writ
|
|||||||
case instantiateErr := <-instanceErrChan:
|
case instantiateErr := <-instanceErrChan:
|
||||||
if instantiateErr != nil {
|
if instantiateErr != nil {
|
||||||
log.Error(ctx, "Failed to instantiate WASM module", "error", instantiateErr)
|
log.Error(ctx, "Failed to instantiate WASM module", "error", instantiateErr)
|
||||||
_ = wasmStdinReader.Close()
|
|
||||||
_ = hostStdinWriter.Close()
|
|
||||||
_ = hostStdoutReader.Close()
|
|
||||||
_ = wasmStdoutWriter.Close()
|
|
||||||
// compiledModule closed by defer
|
// compiledModule closed by defer
|
||||||
|
// pipes closed by defer
|
||||||
return nil, nil, nil, nil, fmt.Errorf("instantiate wasm module: %w", instantiateErr)
|
return nil, nil, nil, nil, fmt.Errorf("instantiate wasm module: %w", instantiateErr)
|
||||||
}
|
}
|
||||||
log.Warn(ctx, "WASM module instantiation returned (exited?) unexpectedly quickly.")
|
// If instantiateErr is nil here, the module exited immediately without error. Log it.
|
||||||
|
log.Warn(ctx, "WASM module instantiation returned (exited?) immediately without error.")
|
||||||
|
// Proceed to start monitoring, but return the (already closed) instance
|
||||||
|
// Pipes will be closed by the successful return path.
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
log.Debug(ctx, "WASM module instantiation likely blocking (server running), proceeding...")
|
log.Debug(ctx, "WASM module instantiation likely blocking (server running), proceeding...")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a monitoring goroutine for WASM module exit/error
|
// Start a monitoring goroutine for WASM module exit/error
|
||||||
go func(modToMonitor api.Module, compiledToClose api.Closer, errChan chan error) {
|
go func(modToMonitor api.Module, compiledToClose api.Closer, errChan chan error) {
|
||||||
|
// This will block until the instance created by InstantiateModule exits or errors.
|
||||||
instantiateErr := <-errChan
|
instantiateErr := <-errChan
|
||||||
|
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
log.Warn("WASM module exited/errored", "error", instantiateErr)
|
log.Warn("WASM module exited/errored", "error", instantiateErr)
|
||||||
// Check if the module currently stored in the agent is the one we were monitoring.
|
// Check if the module currently stored in the agent is the one we were monitoring.
|
||||||
// Compare module instance directly. Instance might be nil if instantiation failed.
|
// Use the central cleanup which handles nil checks.
|
||||||
if a.wasmModule != nil && a.wasmModule == modToMonitor {
|
if a.wasmModule == modToMonitor {
|
||||||
a.cleanup() // This will close the module instance and compiled ref
|
a.cleanup() // This will close the module instance and compiled ref
|
||||||
log.Info("MCP agent state cleaned up after WASM module exit/error")
|
log.Info("MCP agent state cleaned up after WASM module exit/error")
|
||||||
} else {
|
} else {
|
||||||
log.Debug("WASM module exited, but state already updated or module mismatch")
|
// This case can happen if cleanup was called manually or if a new instance
|
||||||
|
// was started before the old one finished exiting.
|
||||||
|
log.Debug("WASM module exited, but state already updated or module mismatch. Explicitly closing compiled ref if needed.")
|
||||||
// Manually close the compiled module ref associated with this specific instance
|
// Manually close the compiled module ref associated with this specific instance
|
||||||
// as cleanup() won't if a.wasmModule doesn't match.
|
// as cleanup() won't if a.wasmModule doesn't match or is nil.
|
||||||
if compiledToClose != nil {
|
if compiledToClose != nil {
|
||||||
_ = compiledToClose.Close(context.Background())
|
_ = compiledToClose.Close(context.Background())
|
||||||
}
|
}
|
||||||
@ -415,8 +411,8 @@ func (a *MCPAgent) startWasmModule(ctx context.Context) (hostStdinWriter io.Writ
|
|||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
}(instance, compiledModule, instanceErrChan) // Pass necessary refs
|
}(instance, compiledModule, instanceErrChan) // Pass necessary refs
|
||||||
|
|
||||||
// Success: prevent deferred cleanup, return resources needed by caller
|
// Success: prevent deferred cleanup of compiled module, return resources needed by caller
|
||||||
shouldCloseOnError = false
|
shouldCloseCompiledOnError = false
|
||||||
return hostStdinWriter, hostStdoutReader, instance, compiledModule, nil // Return instance and compiled module
|
return hostStdinWriter, hostStdoutReader, instance, compiledModule, nil // Return instance and compiled module
|
||||||
}
|
}
|
||||||
|
|
||||||
|
189
core/agents/mcp/mcp_host_functions.go
Normal file
189
core/agents/mcp/mcp_host_functions.go
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/navidrome/navidrome/log"
|
||||||
|
"github.com/tetratelabs/wazero"
|
||||||
|
"github.com/tetratelabs/wazero/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// httpClient is a shared HTTP client for host function reuse.
|
||||||
|
var httpClient = &http.Client{
|
||||||
|
// Consider adding a default timeout
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerHostFunctions defines and registers the host functions (e.g., http_fetch)
|
||||||
|
// into the provided Wazero runtime.
|
||||||
|
func registerHostFunctions(ctx context.Context, runtime wazero.Runtime) error {
|
||||||
|
// Define and Instantiate Host Module "env"
|
||||||
|
_, err := runtime.NewHostModuleBuilder("env"). // "env" is the conventional module name
|
||||||
|
NewFunctionBuilder().
|
||||||
|
WithFunc(httpFetch). // Register our Go function
|
||||||
|
Export("http_fetch"). // Export it with the name WASM will use
|
||||||
|
Instantiate(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "Failed to instantiate 'env' host module with httpFetch", "error", err)
|
||||||
|
return fmt.Errorf("instantiate host module 'env': %w", err)
|
||||||
|
}
|
||||||
|
log.Info(ctx, "Instantiated 'env' host module with http_fetch function")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpFetch is the host function exposed to WASM.
|
||||||
|
// ... (full implementation as provided previously) ...
|
||||||
|
// Returns:
|
||||||
|
// - 0 on success (request completed, results written).
|
||||||
|
// - 1 on host-side failure (e.g., memory access error, invalid input).
|
||||||
|
func httpFetch(
|
||||||
|
ctx context.Context, mod api.Module, // Standard Wazero host function params
|
||||||
|
// Request details
|
||||||
|
urlPtr, urlLen uint32,
|
||||||
|
methodPtr, methodLen uint32,
|
||||||
|
bodyPtr, bodyLen uint32,
|
||||||
|
timeoutMillis uint32,
|
||||||
|
// Result pointers
|
||||||
|
resultStatusPtr uint32,
|
||||||
|
resultBodyPtr uint32, resultBodyCapacity uint32, resultBodyLenPtr uint32,
|
||||||
|
resultErrorPtr uint32, resultErrorCapacity uint32, resultErrorLenPtr uint32,
|
||||||
|
) uint32 { // Using uint32 for status code convention (0=success, 1=failure)
|
||||||
|
mem := mod.Memory()
|
||||||
|
|
||||||
|
// --- Read Inputs ---
|
||||||
|
urlBytes, ok := mem.Read(urlPtr, urlLen)
|
||||||
|
if !ok {
|
||||||
|
log.Error(ctx, "httpFetch host error: failed to read URL from WASM memory")
|
||||||
|
// Cannot write error back as we don't have the pointers validated yet
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
url := string(urlBytes)
|
||||||
|
|
||||||
|
methodBytes, ok := mem.Read(methodPtr, methodLen)
|
||||||
|
if !ok {
|
||||||
|
log.Error(ctx, "httpFetch host error: failed to read method from WASM memory", "url", url)
|
||||||
|
return 1 // Bail out
|
||||||
|
}
|
||||||
|
method := string(methodBytes)
|
||||||
|
if method == "" {
|
||||||
|
method = "GET" // Default to GET
|
||||||
|
}
|
||||||
|
|
||||||
|
var reqBody io.Reader
|
||||||
|
if bodyLen > 0 {
|
||||||
|
bodyBytes, ok := mem.Read(bodyPtr, bodyLen)
|
||||||
|
if !ok {
|
||||||
|
log.Error(ctx, "httpFetch host error: failed to read body from WASM memory", "url", url, "method", method)
|
||||||
|
return 1 // Bail out
|
||||||
|
}
|
||||||
|
reqBody = bytes.NewReader(bodyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.Duration(timeoutMillis) * time.Millisecond
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 30 * time.Second // Default timeout matching httpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Prepare and Execute Request ---
|
||||||
|
log.Debug(ctx, "httpFetch executing request", "method", method, "url", url, "timeout", timeout)
|
||||||
|
|
||||||
|
// Use a specific context for the request, derived from the host function's context
|
||||||
|
// but with the specific timeout for this call.
|
||||||
|
reqCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(reqCtx, method, url, reqBody)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Sprintf("failed to create request: %v", err)
|
||||||
|
log.Error(ctx, "httpFetch host error", "url", url, "method", method, "error", errMsg)
|
||||||
|
writeStringResult(mem, resultErrorPtr, resultErrorCapacity, resultErrorLenPtr, errMsg)
|
||||||
|
mem.WriteUint32Le(resultStatusPtr, 0) // Write 0 status on creation error
|
||||||
|
mem.WriteUint32Le(resultBodyLenPtr, 0) // No body
|
||||||
|
return 0 // Indicate results (including error) were written
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Consider adding a User-Agent?
|
||||||
|
// req.Header.Set("User-Agent", "Navidrome/MCP-Agent-Host")
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
// Handle client-side errors (network, DNS, timeout)
|
||||||
|
errMsg := fmt.Sprintf("failed to execute request: %v", err)
|
||||||
|
log.Error(ctx, "httpFetch host error", "url", url, "method", method, "error", errMsg)
|
||||||
|
writeStringResult(mem, resultErrorPtr, resultErrorCapacity, resultErrorLenPtr, errMsg)
|
||||||
|
mem.WriteUint32Le(resultStatusPtr, 0) // Write 0 status on transport error
|
||||||
|
mem.WriteUint32Le(resultBodyLenPtr, 0)
|
||||||
|
return 0 // Indicate results written
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// --- Process Response ---
|
||||||
|
statusCode := uint32(resp.StatusCode)
|
||||||
|
responseBodyBytes, readErr := io.ReadAll(resp.Body)
|
||||||
|
if readErr != nil {
|
||||||
|
errMsg := fmt.Sprintf("failed to read response body: %v", readErr)
|
||||||
|
log.Error(ctx, "httpFetch host error", "url", url, "method", method, "status", statusCode, "error", errMsg)
|
||||||
|
writeStringResult(mem, resultErrorPtr, resultErrorCapacity, resultErrorLenPtr, errMsg)
|
||||||
|
mem.WriteUint32Le(resultStatusPtr, statusCode) // Write actual status code
|
||||||
|
mem.WriteUint32Le(resultBodyLenPtr, 0)
|
||||||
|
return 0 // Indicate results written
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Write Results Back to WASM Memory ---
|
||||||
|
log.Debug(ctx, "httpFetch writing results", "url", url, "method", method, "status", statusCode, "bodyLen", len(responseBodyBytes))
|
||||||
|
|
||||||
|
// Write status code
|
||||||
|
if !mem.WriteUint32Le(resultStatusPtr, statusCode) {
|
||||||
|
log.Error(ctx, "httpFetch host error: failed to write status code to WASM memory")
|
||||||
|
return 1 // Host error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write response body (checking capacity)
|
||||||
|
if !writeBytesResult(mem, resultBodyPtr, resultBodyCapacity, resultBodyLenPtr, responseBodyBytes) {
|
||||||
|
// If body write fails (likely due to capacity), write an error message instead.
|
||||||
|
errMsg := fmt.Sprintf("response body size (%d) exceeds buffer capacity (%d)", len(responseBodyBytes), resultBodyCapacity)
|
||||||
|
log.Error(ctx, "httpFetch host error", "url", url, "method", method, "status", statusCode, "error", errMsg)
|
||||||
|
writeStringResult(mem, resultErrorPtr, resultErrorCapacity, resultErrorLenPtr, errMsg)
|
||||||
|
mem.WriteUint32Le(resultBodyLenPtr, 0) // Ensure body length is 0 if we wrote an error
|
||||||
|
} else {
|
||||||
|
// Write empty error string if body write was successful
|
||||||
|
mem.WriteUint32Le(resultErrorLenPtr, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0 // Success
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to write string results, respecting capacity. Returns true on success.
|
||||||
|
func writeStringResult(mem api.Memory, ptr, capacity, lenPtr uint32, result string) bool {
|
||||||
|
bytes := []byte(result)
|
||||||
|
return writeBytesResult(mem, ptr, capacity, lenPtr, bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to write byte results, respecting capacity. Returns true on success.
|
||||||
|
func writeBytesResult(mem api.Memory, ptr, capacity, lenPtr uint32, result []byte) bool {
|
||||||
|
resultLen := uint32(len(result))
|
||||||
|
writeLen := resultLen
|
||||||
|
if writeLen > capacity {
|
||||||
|
log.Warn(context.Background(), "WASM host write truncated", "requested", resultLen, "capacity", capacity)
|
||||||
|
writeLen = capacity // Truncate if too large for buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
if writeLen > 0 {
|
||||||
|
if !mem.Write(ptr, result[:writeLen]) {
|
||||||
|
log.Error(context.Background(), "WASM host memory write failed", "ptr", ptr, "len", writeLen)
|
||||||
|
return false // Memory write failed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the *original* length of the data (even if truncated) so the WASM side knows.
|
||||||
|
if !mem.WriteUint32Le(lenPtr, resultLen) {
|
||||||
|
log.Error(context.Background(), "WASM host memory length write failed", "lenPtr", lenPtr, "len", resultLen)
|
||||||
|
return false // Memory write failed
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
59
core/agents/mcp/mcp_process_native.go
Normal file
59
core/agents/mcp/mcp_process_native.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/navidrome/navidrome/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startNativeProcess starts the MCP server as a native executable.
|
||||||
|
func (a *MCPAgent) startNativeProcess(ctx context.Context) (stdin io.WriteCloser, stdout io.ReadCloser, cmd *exec.Cmd, err error) {
|
||||||
|
log.Debug(ctx, "Starting native MCP server process", "path", McpServerPath)
|
||||||
|
cmd = exec.CommandContext(context.Background(), McpServerPath) // Use Background context for long-running process
|
||||||
|
|
||||||
|
stdin, err = cmd.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("native stdin pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdout, err = cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
_ = stdin.Close()
|
||||||
|
return nil, nil, nil, fmt.Errorf("native stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var stderr strings.Builder
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
if err = cmd.Start(); err != nil {
|
||||||
|
_ = stdin.Close()
|
||||||
|
_ = stdout.Close()
|
||||||
|
return nil, nil, nil, fmt.Errorf("native start: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentPid := cmd.Process.Pid
|
||||||
|
log.Info(ctx, "Native MCP server process started", "pid", currentPid)
|
||||||
|
|
||||||
|
// Start monitoring goroutine
|
||||||
|
go func(processCmd *exec.Cmd, processStderr *strings.Builder, processPid int) {
|
||||||
|
waitErr := processCmd.Wait() // Wait for the process to exit
|
||||||
|
a.mu.Lock()
|
||||||
|
log.Warn("Native MCP server process exited", "pid", processPid, "error", waitErr, "stderr", processStderr.String())
|
||||||
|
// Check if the cmd matches the one we are monitoring before cleaning up
|
||||||
|
// Use the central cleanup function which handles nil checks.
|
||||||
|
if a.cmd == processCmd {
|
||||||
|
a.cleanup() // Use the central cleanup function
|
||||||
|
log.Info("MCP agent state cleaned up after native process exit", "pid", processPid)
|
||||||
|
} else {
|
||||||
|
log.Debug("Native MCP agent process exited, but state already updated or cmd mismatch", "exitedPid", processPid)
|
||||||
|
}
|
||||||
|
a.mu.Unlock()
|
||||||
|
}(cmd, &stderr, currentPid)
|
||||||
|
|
||||||
|
// Return the pipes connected to the process and the Cmd object
|
||||||
|
return stdin, stdout, cmd, nil
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user