mirror of
https://github.com/navidrome/navidrome.git
synced 2025-05-05 21:01:08 +03:00
refactor: DRY up MCPAgent implementation
Refactor the MCPAgent to reduce code duplication.\n\n- Consolidate GetArtistBiographyArgs and GetArtistURLArgs into a single\n ArtistArgs struct.\n- Extract common logic (initialization check, locking, tool calling,\n error handling, response validation) into a private callMCPTool helper method.\n- Simplify GetArtistBiography and GetArtistURL to delegate to callMCPTool.\n- Update tests to use the consolidated ArtistArgs struct.\n- Correct mutex locking in ensureClientInitialized to prevent race conditions.
This commit is contained in:
parent
6e59060a01
commit
8ebefe4065
@ -69,28 +69,28 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
|
||||
// --- Use override if provided (for testing) ---
|
||||
if a.ClientOverride != nil {
|
||||
a.mu.Lock()
|
||||
// Only set if not already set (could be set by a concurrent test setup)
|
||||
if a.client == nil {
|
||||
a.client = a.ClientOverride
|
||||
log.Debug(ctx, "Using provided MCP client override for testing")
|
||||
}
|
||||
a.mu.Unlock()
|
||||
return nil // Skip real initialization when override is present
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
// If client is already initialized and valid, we're done.
|
||||
if a.client != nil {
|
||||
a.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
// Unlock after the check, as the rest of the function needs the lock.
|
||||
a.mu.Unlock()
|
||||
|
||||
// --- Attempt initialization/restart ---
|
||||
// --- Check if already initialized (critical section) ---
|
||||
a.mu.Lock() // Acquire lock *before* checking client
|
||||
if a.client != nil {
|
||||
a.mu.Unlock() // Unlock and return if already initialized
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Client is nil, proceed with initialization *while holding the lock* ---
|
||||
// Defer unlock to ensure it's released even on errors during init.
|
||||
defer a.mu.Unlock()
|
||||
|
||||
log.Info(ctx, "Initializing MCP client and starting/restarting server process...", "serverPath", McpServerPath)
|
||||
|
||||
// Use background context for the command itself, so it doesn't get cancelled by the request context.
|
||||
// Use background context for the command itself
|
||||
cmd := exec.CommandContext(context.Background(), McpServerPath)
|
||||
|
||||
var stdin io.WriteCloser
|
||||
@ -101,7 +101,7 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get stdin pipe for MCP server: %w", err)
|
||||
log.Error(ctx, "MCP init/restart failed", "error", err)
|
||||
return err // Return error directly
|
||||
return err // defer mu.Unlock() will run
|
||||
}
|
||||
|
||||
stdout, err = cmd.StdoutPipe()
|
||||
@ -109,7 +109,7 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
|
||||
err = fmt.Errorf("failed to get stdout pipe for MCP server: %w", err)
|
||||
_ = stdin.Close() // Clean up stdin pipe
|
||||
log.Error(ctx, "MCP init/restart failed", "error", err)
|
||||
return err
|
||||
return err // defer mu.Unlock() will run
|
||||
}
|
||||
|
||||
cmd.Stderr = &stderr
|
||||
@ -119,7 +119,7 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
|
||||
_ = stdin.Close()
|
||||
_ = stdout.Close()
|
||||
log.Error(ctx, "MCP init/restart failed", "error", err)
|
||||
return err
|
||||
return err // defer mu.Unlock() will run
|
||||
}
|
||||
|
||||
currentPid := cmd.Process.Pid
|
||||
@ -128,28 +128,24 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
|
||||
// --- Start monitoring goroutine for *this* process ---
|
||||
go func(processCmd *exec.Cmd, processStderr *strings.Builder, processPid int) {
|
||||
waitErr := processCmd.Wait()
|
||||
// Lock immediately after Wait returns to update state atomically
|
||||
a.mu.Lock()
|
||||
log.Warn("MCP server process exited", "pid", processPid, "error", waitErr, "stderr", processStderr.String())
|
||||
// Clean up state only if this is still the *current* process
|
||||
// (to avoid race condition if a quick restart happened)
|
||||
if a.cmd == processCmd {
|
||||
if a.cmd == processCmd { // Check if state belongs to the process that just exited
|
||||
if a.stdin != nil {
|
||||
_ = a.stdin.Close()
|
||||
a.stdin = nil
|
||||
}
|
||||
a.client = nil // Mark client as unusable, triggering restart on next call
|
||||
a.client = nil
|
||||
a.cmd = nil
|
||||
log.Info("MCP agent state cleaned up after process exit", "pid", processPid)
|
||||
} else {
|
||||
log.Debug("MCP agent process exited, but state already updated by newer process", "exitedPid", processPid)
|
||||
}
|
||||
a.mu.Unlock()
|
||||
}(cmd, &stderr, currentPid) // Pass copies/values to the goroutine
|
||||
}(cmd, &stderr, currentPid)
|
||||
|
||||
// --- Initialize MCP client ---
|
||||
transport := stdio.NewStdioServerTransportWithIO(stdout, stdin) // Use the pipes from this attempt
|
||||
// Create the *real* mcp.Client, which satisfies our mcpClient interface
|
||||
transport := stdio.NewStdioServerTransportWithIO(stdout, stdin)
|
||||
clientImpl := mcp.NewClient(transport)
|
||||
|
||||
initCtx, cancel := context.WithTimeout(context.Background(), initializationTimeout)
|
||||
@ -157,169 +153,97 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
|
||||
if _, err = clientImpl.Initialize(initCtx); err != nil {
|
||||
err = fmt.Errorf("failed to initialize MCP client: %w", err)
|
||||
log.Error(ctx, "MCP client initialization failed after process start", "pid", currentPid, "error", err)
|
||||
// Attempt to kill the process we just started, as client init failed
|
||||
if killErr := cmd.Process.Kill(); killErr != nil {
|
||||
log.Error(ctx, "Failed to kill MCP server process after init failure", "pid", currentPid, "error", killErr)
|
||||
}
|
||||
return err // Return the initialization error
|
||||
return err // defer mu.Unlock() will run
|
||||
}
|
||||
|
||||
// --- Initialization successful, update agent state ---
|
||||
a.mu.Lock() // Lock again to update agent state
|
||||
// Double-check if another goroutine initialized successfully in the meantime
|
||||
// Although unlikely with the outer lock/check, it's safer.
|
||||
if a.client != nil {
|
||||
a.mu.Unlock()
|
||||
log.Warn(ctx, "MCP client was already initialized by another routine, discarding this attempt", "pid", currentPid)
|
||||
// Kill the redundant process we started
|
||||
if killErr := cmd.Process.Kill(); killErr != nil {
|
||||
log.Error(ctx, "Failed to kill redundant MCP server process", "pid", currentPid, "error", killErr)
|
||||
}
|
||||
return nil // Return success as *a* client is available
|
||||
}
|
||||
|
||||
a.cmd = cmd // Store the successfully started command
|
||||
a.stdin = stdin // Store its stdin
|
||||
a.client = clientImpl // Store the successfully initialized client (as interface type)
|
||||
a.mu.Unlock()
|
||||
// --- Initialization successful, update agent state (still holding lock) ---
|
||||
a.cmd = cmd
|
||||
a.stdin = stdin
|
||||
a.client = clientImpl
|
||||
|
||||
log.Info(ctx, "MCP client initialized successfully", "pid", currentPid)
|
||||
// defer mu.Unlock() runs here
|
||||
return nil // Success
|
||||
}
|
||||
|
||||
// GetArtistBiographyArgs defines the structure for the get_artist_biography MCP tool arguments.
|
||||
// ArtistArgs defines the structure for MCP tool arguments requiring artist info.
|
||||
// Exported for use in tests.
|
||||
type GetArtistBiographyArgs struct {
|
||||
type ArtistArgs struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Mbid string `json:"mbid,omitempty"`
|
||||
}
|
||||
|
||||
// callMCPTool is a helper to perform the common steps of calling an MCP tool.
|
||||
func (a *MCPAgent) callMCPTool(ctx context.Context, toolName string, args any) (string, error) {
|
||||
// Ensure the client is initialized and the server is running (attempts restart if needed)
|
||||
if err := a.ensureClientInitialized(ctx); err != nil {
|
||||
log.Error(ctx, "MCP agent initialization/restart failed, cannot call tool", "tool", toolName, "error", err)
|
||||
return "", fmt.Errorf("MCP agent not ready: %w", err)
|
||||
}
|
||||
|
||||
// Lock to safely access the shared client resource
|
||||
a.mu.Lock()
|
||||
|
||||
// Check if the client is valid *after* ensuring initialization and acquiring lock.
|
||||
if a.client == nil {
|
||||
a.mu.Unlock() // Release lock before returning error
|
||||
log.Error(ctx, "MCP client became invalid after initialization check (server process likely died)", "tool", toolName)
|
||||
return "", fmt.Errorf("MCP agent process is not running")
|
||||
}
|
||||
|
||||
// Keep a reference to the client while locked
|
||||
currentClient := a.client
|
||||
a.mu.Unlock() // *Release lock before* making the potentially blocking MCP call
|
||||
|
||||
// Call the tool using the client reference
|
||||
log.Debug(ctx, "Calling MCP tool", "tool", toolName, "args", args)
|
||||
response, err := currentClient.CallTool(ctx, toolName, args)
|
||||
if err != nil {
|
||||
// Handle potential pipe closures or other communication errors
|
||||
log.Error(ctx, "Failed to call MCP tool", "tool", toolName, "error", err)
|
||||
// Check if the error indicates a broken pipe, suggesting the server died
|
||||
if errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "broken pipe") || strings.Contains(err.Error(), "EOF") {
|
||||
log.Warn(ctx, "MCP tool call failed, possibly due to server process exit. State will be reset.", "tool", toolName)
|
||||
// State reset is handled by the monitoring goroutine, just return error
|
||||
return "", fmt.Errorf("MCP agent process communication error: %w", err)
|
||||
}
|
||||
return "", fmt.Errorf("failed to call MCP tool '%s': %w", toolName, err)
|
||||
}
|
||||
|
||||
// Process the response
|
||||
if response == nil || len(response.Content) == 0 || response.Content[0].TextContent == nil || response.Content[0].TextContent.Text == "" {
|
||||
log.Warn(ctx, "MCP tool returned empty or invalid response", "tool", toolName)
|
||||
return "", agents.ErrNotFound
|
||||
}
|
||||
|
||||
// Return the text content
|
||||
resultText := response.Content[0].TextContent.Text
|
||||
log.Debug(ctx, "Received response from MCP agent", "tool", toolName, "length", len(resultText))
|
||||
return resultText, nil
|
||||
}
|
||||
|
||||
// GetArtistBiography retrieves the artist biography by calling the external MCP server.
|
||||
func (a *MCPAgent) GetArtistBiography(ctx context.Context, id, name, mbid string) (string, error) {
|
||||
// Ensure the client is initialized and the server is running (attempts restart if needed)
|
||||
if err := a.ensureClientInitialized(ctx); err != nil {
|
||||
log.Error(ctx, "MCP agent initialization/restart failed, cannot get biography", "error", err)
|
||||
return "", fmt.Errorf("MCP agent not ready: %w", err)
|
||||
}
|
||||
|
||||
// Lock to ensure only one request uses the client/pipes at a time
|
||||
a.mu.Lock()
|
||||
|
||||
// Check if the client is still valid *after* ensuring initialization and acquiring lock.
|
||||
// The monitoring goroutine could have nilled it out if the process died just now.
|
||||
if a.client == nil {
|
||||
a.mu.Unlock() // Release lock before returning error
|
||||
log.Error(ctx, "MCP client became invalid after initialization check (server process likely died)")
|
||||
return "", fmt.Errorf("MCP agent process is not running")
|
||||
}
|
||||
|
||||
// Keep a reference to the client while locked
|
||||
currentClient := a.client
|
||||
a.mu.Unlock() // Release lock before making the potentially blocking MCP call
|
||||
|
||||
log.Debug(ctx, "Calling MCP agent GetArtistBiography", "id", id, "name", name, "mbid", mbid)
|
||||
|
||||
// Prepare arguments for the tool call
|
||||
args := GetArtistBiographyArgs{
|
||||
args := ArtistArgs{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Mbid: mbid,
|
||||
}
|
||||
|
||||
// Call the tool using the client reference
|
||||
log.Debug(ctx, "Calling MCP tool", "tool", McpToolNameGetBio, "args", args)
|
||||
response, err := currentClient.CallTool(ctx, McpToolNameGetBio, args) // Use currentClient
|
||||
if err != nil {
|
||||
// Handle potential pipe closures or other communication errors
|
||||
log.Error(ctx, "Failed to call MCP tool", "tool", McpToolNameGetBio, "error", err)
|
||||
// Check if the error indicates a broken pipe, suggesting the server died
|
||||
if errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "broken pipe") || strings.Contains(err.Error(), "EOF") {
|
||||
log.Warn(ctx, "MCP tool call failed, possibly due to server process exit. State will be reset.")
|
||||
// State reset is handled by the monitoring goroutine, just return error
|
||||
return "", fmt.Errorf("MCP agent process communication error: %w", err)
|
||||
}
|
||||
return "", fmt.Errorf("failed to call MCP tool '%s': %w", McpToolNameGetBio, err)
|
||||
}
|
||||
|
||||
// Process the response
|
||||
if response == nil || len(response.Content) == 0 || response.Content[0].TextContent == nil || response.Content[0].TextContent.Text == "" {
|
||||
log.Warn(ctx, "MCP tool returned empty or invalid response", "tool", McpToolNameGetBio)
|
||||
return "", agents.ErrNotFound
|
||||
}
|
||||
|
||||
bio := response.Content[0].TextContent.Text
|
||||
log.Debug(ctx, "Received biography from MCP agent", "tool", McpToolNameGetBio, "bioLength", len(bio))
|
||||
|
||||
// Return the biography text
|
||||
return bio, nil
|
||||
}
|
||||
|
||||
// GetArtistURLArgs defines the structure for the get_artist_url MCP tool arguments.
|
||||
// Exported for use in tests.
|
||||
type GetArtistURLArgs struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Mbid string `json:"mbid,omitempty"`
|
||||
return a.callMCPTool(ctx, McpToolNameGetBio, args)
|
||||
}
|
||||
|
||||
// GetArtistURL retrieves the artist URL by calling the external MCP server.
|
||||
func (a *MCPAgent) GetArtistURL(ctx context.Context, id, name, mbid string) (string, error) {
|
||||
// Ensure the client is initialized and the server is running (attempts restart if needed)
|
||||
if err := a.ensureClientInitialized(ctx); err != nil {
|
||||
log.Error(ctx, "MCP agent initialization/restart failed, cannot get URL", "error", err)
|
||||
return "", fmt.Errorf("MCP agent not ready: %w", err)
|
||||
}
|
||||
|
||||
// Lock to ensure only one request uses the client/pipes at a time
|
||||
a.mu.Lock()
|
||||
|
||||
// Check if the client is still valid *after* ensuring initialization and acquiring lock.
|
||||
if a.client == nil {
|
||||
a.mu.Unlock()
|
||||
log.Error(ctx, "MCP client became invalid after initialization check (server process likely died)")
|
||||
return "", fmt.Errorf("MCP agent process is not running")
|
||||
}
|
||||
|
||||
// Keep a reference to the client while locked
|
||||
currentClient := a.client
|
||||
a.mu.Unlock() // Release lock before making the potentially blocking MCP call
|
||||
|
||||
log.Debug(ctx, "Calling MCP agent GetArtistURL", "id", id, "name", name, "mbid", mbid)
|
||||
|
||||
// Prepare arguments for the tool call
|
||||
args := GetArtistURLArgs{
|
||||
args := ArtistArgs{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Mbid: mbid,
|
||||
}
|
||||
|
||||
// Call the tool using the client reference
|
||||
log.Debug(ctx, "Calling MCP tool", "tool", McpToolNameGetURL, "args", args)
|
||||
response, err := currentClient.CallTool(ctx, McpToolNameGetURL, args) // Use currentClient
|
||||
if err != nil {
|
||||
// Handle potential pipe closures or other communication errors
|
||||
log.Error(ctx, "Failed to call MCP tool", "tool", McpToolNameGetURL, "error", err)
|
||||
// Check if the error indicates a broken pipe, suggesting the server died
|
||||
if errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "broken pipe") || strings.Contains(err.Error(), "EOF") {
|
||||
log.Warn(ctx, "MCP tool call failed, possibly due to server process exit. State will be reset.")
|
||||
// State reset is handled by the monitoring goroutine, just return error
|
||||
return "", fmt.Errorf("MCP agent process communication error: %w", err)
|
||||
}
|
||||
return "", fmt.Errorf("failed to call MCP tool '%s': %w", McpToolNameGetURL, err)
|
||||
}
|
||||
|
||||
// Process the response
|
||||
if response == nil || len(response.Content) == 0 || response.Content[0].TextContent == nil || response.Content[0].TextContent.Text == "" {
|
||||
log.Warn(ctx, "MCP tool returned empty or invalid response", "tool", McpToolNameGetURL)
|
||||
return "", agents.ErrNotFound
|
||||
}
|
||||
|
||||
url := response.Content[0].TextContent.Text
|
||||
log.Debug(ctx, "Received URL from MCP agent", "tool", McpToolNameGetURL, "url", url)
|
||||
|
||||
// Return the URL text
|
||||
return url, nil
|
||||
return a.callMCPTool(ctx, McpToolNameGetURL, args)
|
||||
}
|
||||
|
||||
// Ensure MCPAgent implements the required interfaces
|
||||
@ -328,5 +252,4 @@ var _ agents.ArtistURLRetriever = (*MCPAgent)(nil)
|
||||
|
||||
func init() {
|
||||
agents.Register(McpAgentName, mcpConstructor)
|
||||
log.Info("Registered MCP Agent")
|
||||
}
|
||||
|
@ -71,8 +71,8 @@ var _ = Describe("MCPAgent", func() {
|
||||
expectedBio := "This is the artist bio."
|
||||
mockClient.CallToolFunc = func(ctx context.Context, toolName string, args any) (*mcp_client.ToolResponse, error) {
|
||||
Expect(toolName).To(Equal(mcp.McpToolNameGetBio))
|
||||
Expect(args).To(BeAssignableToTypeOf(mcp.GetArtistBiographyArgs{})) // Use exported type
|
||||
typedArgs := args.(mcp.GetArtistBiographyArgs) // Use exported type
|
||||
Expect(args).To(BeAssignableToTypeOf(mcp.ArtistArgs{}))
|
||||
typedArgs := args.(mcp.ArtistArgs)
|
||||
Expect(typedArgs.ID).To(Equal("id1"))
|
||||
Expect(typedArgs.Name).To(Equal("Artist Name"))
|
||||
Expect(typedArgs.Mbid).To(Equal("mbid1"))
|
||||
@ -137,8 +137,8 @@ var _ = Describe("MCPAgent", func() {
|
||||
expectedURL := "http://example.com/artist"
|
||||
mockClient.CallToolFunc = func(ctx context.Context, toolName string, args any) (*mcp_client.ToolResponse, error) {
|
||||
Expect(toolName).To(Equal(mcp.McpToolNameGetURL))
|
||||
Expect(args).To(BeAssignableToTypeOf(mcp.GetArtistURLArgs{})) // Use exported type
|
||||
typedArgs := args.(mcp.GetArtistURLArgs) // Use exported type
|
||||
Expect(args).To(BeAssignableToTypeOf(mcp.ArtistArgs{}))
|
||||
typedArgs := args.(mcp.ArtistArgs)
|
||||
Expect(typedArgs.ID).To(Equal("id2"))
|
||||
Expect(typedArgs.Name).To(Equal("Another Artist"))
|
||||
Expect(typedArgs.Mbid).To(Equal("mbid2"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user