refactor: DRY up MCPAgent implementation

Refactor the MCPAgent to reduce code duplication.\n\n- Consolidate GetArtistBiographyArgs and GetArtistURLArgs into a single\n  ArtistArgs struct.\n- Extract common logic (initialization check, locking, tool calling,\n  error handling, response validation) into a private callMCPTool helper method.\n- Simplify GetArtistBiography and GetArtistURL to delegate to callMCPTool.\n- Update tests to use the consolidated ArtistArgs struct.\n- Correct mutex locking in ensureClientInitialized to prevent race conditions.
This commit is contained in:
Deluan 2025-04-19 13:04:41 -04:00
parent 6e59060a01
commit 8ebefe4065
2 changed files with 84 additions and 161 deletions

View File

@ -69,28 +69,28 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
// --- Use override if provided (for testing) --- // --- Use override if provided (for testing) ---
if a.ClientOverride != nil { if a.ClientOverride != nil {
a.mu.Lock() a.mu.Lock()
// Only set if not already set (could be set by a concurrent test setup)
if a.client == nil { if a.client == nil {
a.client = a.ClientOverride a.client = a.ClientOverride
log.Debug(ctx, "Using provided MCP client override for testing") log.Debug(ctx, "Using provided MCP client override for testing")
} }
a.mu.Unlock()
return nil // Skip real initialization when override is present
}
a.mu.Lock()
// If client is already initialized and valid, we're done.
if a.client != nil {
a.mu.Unlock() a.mu.Unlock()
return nil return nil
} }
// Unlock after the check, as the rest of the function needs the lock.
a.mu.Unlock()
// --- Attempt initialization/restart --- // --- Check if already initialized (critical section) ---
a.mu.Lock() // Acquire lock *before* checking client
if a.client != nil {
a.mu.Unlock() // Unlock and return if already initialized
return nil
}
// --- Client is nil, proceed with initialization *while holding the lock* ---
// Defer unlock to ensure it's released even on errors during init.
defer a.mu.Unlock()
log.Info(ctx, "Initializing MCP client and starting/restarting server process...", "serverPath", McpServerPath) log.Info(ctx, "Initializing MCP client and starting/restarting server process...", "serverPath", McpServerPath)
// Use background context for the command itself, so it doesn't get cancelled by the request context. // Use background context for the command itself
cmd := exec.CommandContext(context.Background(), McpServerPath) cmd := exec.CommandContext(context.Background(), McpServerPath)
var stdin io.WriteCloser var stdin io.WriteCloser
@ -101,7 +101,7 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
if err != nil { if err != nil {
err = fmt.Errorf("failed to get stdin pipe for MCP server: %w", err) err = fmt.Errorf("failed to get stdin pipe for MCP server: %w", err)
log.Error(ctx, "MCP init/restart failed", "error", err) log.Error(ctx, "MCP init/restart failed", "error", err)
return err // Return error directly return err // defer mu.Unlock() will run
} }
stdout, err = cmd.StdoutPipe() stdout, err = cmd.StdoutPipe()
@ -109,7 +109,7 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
err = fmt.Errorf("failed to get stdout pipe for MCP server: %w", err) err = fmt.Errorf("failed to get stdout pipe for MCP server: %w", err)
_ = stdin.Close() // Clean up stdin pipe _ = stdin.Close() // Clean up stdin pipe
log.Error(ctx, "MCP init/restart failed", "error", err) log.Error(ctx, "MCP init/restart failed", "error", err)
return err return err // defer mu.Unlock() will run
} }
cmd.Stderr = &stderr cmd.Stderr = &stderr
@ -119,7 +119,7 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
_ = stdin.Close() _ = stdin.Close()
_ = stdout.Close() _ = stdout.Close()
log.Error(ctx, "MCP init/restart failed", "error", err) log.Error(ctx, "MCP init/restart failed", "error", err)
return err return err // defer mu.Unlock() will run
} }
currentPid := cmd.Process.Pid currentPid := cmd.Process.Pid
@ -128,28 +128,24 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
// --- Start monitoring goroutine for *this* process --- // --- Start monitoring goroutine for *this* process ---
go func(processCmd *exec.Cmd, processStderr *strings.Builder, processPid int) { go func(processCmd *exec.Cmd, processStderr *strings.Builder, processPid int) {
waitErr := processCmd.Wait() waitErr := processCmd.Wait()
// Lock immediately after Wait returns to update state atomically
a.mu.Lock() a.mu.Lock()
log.Warn("MCP server process exited", "pid", processPid, "error", waitErr, "stderr", processStderr.String()) log.Warn("MCP server process exited", "pid", processPid, "error", waitErr, "stderr", processStderr.String())
// Clean up state only if this is still the *current* process if a.cmd == processCmd { // Check if state belongs to the process that just exited
// (to avoid race condition if a quick restart happened)
if a.cmd == processCmd {
if a.stdin != nil { if a.stdin != nil {
_ = a.stdin.Close() _ = a.stdin.Close()
a.stdin = nil a.stdin = nil
} }
a.client = nil // Mark client as unusable, triggering restart on next call a.client = nil
a.cmd = nil a.cmd = nil
log.Info("MCP agent state cleaned up after process exit", "pid", processPid) log.Info("MCP agent state cleaned up after process exit", "pid", processPid)
} else { } else {
log.Debug("MCP agent process exited, but state already updated by newer process", "exitedPid", processPid) log.Debug("MCP agent process exited, but state already updated by newer process", "exitedPid", processPid)
} }
a.mu.Unlock() a.mu.Unlock()
}(cmd, &stderr, currentPid) // Pass copies/values to the goroutine }(cmd, &stderr, currentPid)
// --- Initialize MCP client --- // --- Initialize MCP client ---
transport := stdio.NewStdioServerTransportWithIO(stdout, stdin) // Use the pipes from this attempt transport := stdio.NewStdioServerTransportWithIO(stdout, stdin)
// Create the *real* mcp.Client, which satisfies our mcpClient interface
clientImpl := mcp.NewClient(transport) clientImpl := mcp.NewClient(transport)
initCtx, cancel := context.WithTimeout(context.Background(), initializationTimeout) initCtx, cancel := context.WithTimeout(context.Background(), initializationTimeout)
@ -157,169 +153,97 @@ func (a *MCPAgent) ensureClientInitialized(ctx context.Context) (err error) {
if _, err = clientImpl.Initialize(initCtx); err != nil { if _, err = clientImpl.Initialize(initCtx); err != nil {
err = fmt.Errorf("failed to initialize MCP client: %w", err) err = fmt.Errorf("failed to initialize MCP client: %w", err)
log.Error(ctx, "MCP client initialization failed after process start", "pid", currentPid, "error", err) log.Error(ctx, "MCP client initialization failed after process start", "pid", currentPid, "error", err)
// Attempt to kill the process we just started, as client init failed
if killErr := cmd.Process.Kill(); killErr != nil { if killErr := cmd.Process.Kill(); killErr != nil {
log.Error(ctx, "Failed to kill MCP server process after init failure", "pid", currentPid, "error", killErr) log.Error(ctx, "Failed to kill MCP server process after init failure", "pid", currentPid, "error", killErr)
} }
return err // Return the initialization error return err // defer mu.Unlock() will run
} }
// --- Initialization successful, update agent state --- // --- Initialization successful, update agent state (still holding lock) ---
a.mu.Lock() // Lock again to update agent state a.cmd = cmd
// Double-check if another goroutine initialized successfully in the meantime a.stdin = stdin
// Although unlikely with the outer lock/check, it's safer. a.client = clientImpl
if a.client != nil {
a.mu.Unlock()
log.Warn(ctx, "MCP client was already initialized by another routine, discarding this attempt", "pid", currentPid)
// Kill the redundant process we started
if killErr := cmd.Process.Kill(); killErr != nil {
log.Error(ctx, "Failed to kill redundant MCP server process", "pid", currentPid, "error", killErr)
}
return nil // Return success as *a* client is available
}
a.cmd = cmd // Store the successfully started command
a.stdin = stdin // Store its stdin
a.client = clientImpl // Store the successfully initialized client (as interface type)
a.mu.Unlock()
log.Info(ctx, "MCP client initialized successfully", "pid", currentPid) log.Info(ctx, "MCP client initialized successfully", "pid", currentPid)
// defer mu.Unlock() runs here
return nil // Success return nil // Success
} }
// GetArtistBiographyArgs defines the structure for the get_artist_biography MCP tool arguments. // ArtistArgs defines the structure for MCP tool arguments requiring artist info.
// Exported for use in tests. // Exported for use in tests.
type GetArtistBiographyArgs struct { type ArtistArgs struct {
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Mbid string `json:"mbid,omitempty"` Mbid string `json:"mbid,omitempty"`
} }
// callMCPTool is a helper to perform the common steps of calling an MCP tool.
func (a *MCPAgent) callMCPTool(ctx context.Context, toolName string, args any) (string, error) {
// Ensure the client is initialized and the server is running (attempts restart if needed)
if err := a.ensureClientInitialized(ctx); err != nil {
log.Error(ctx, "MCP agent initialization/restart failed, cannot call tool", "tool", toolName, "error", err)
return "", fmt.Errorf("MCP agent not ready: %w", err)
}
// Lock to safely access the shared client resource
a.mu.Lock()
// Check if the client is valid *after* ensuring initialization and acquiring lock.
if a.client == nil {
a.mu.Unlock() // Release lock before returning error
log.Error(ctx, "MCP client became invalid after initialization check (server process likely died)", "tool", toolName)
return "", fmt.Errorf("MCP agent process is not running")
}
// Keep a reference to the client while locked
currentClient := a.client
a.mu.Unlock() // *Release lock before* making the potentially blocking MCP call
// Call the tool using the client reference
log.Debug(ctx, "Calling MCP tool", "tool", toolName, "args", args)
response, err := currentClient.CallTool(ctx, toolName, args)
if err != nil {
// Handle potential pipe closures or other communication errors
log.Error(ctx, "Failed to call MCP tool", "tool", toolName, "error", err)
// Check if the error indicates a broken pipe, suggesting the server died
if errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "broken pipe") || strings.Contains(err.Error(), "EOF") {
log.Warn(ctx, "MCP tool call failed, possibly due to server process exit. State will be reset.", "tool", toolName)
// State reset is handled by the monitoring goroutine, just return error
return "", fmt.Errorf("MCP agent process communication error: %w", err)
}
return "", fmt.Errorf("failed to call MCP tool '%s': %w", toolName, err)
}
// Process the response
if response == nil || len(response.Content) == 0 || response.Content[0].TextContent == nil || response.Content[0].TextContent.Text == "" {
log.Warn(ctx, "MCP tool returned empty or invalid response", "tool", toolName)
return "", agents.ErrNotFound
}
// Return the text content
resultText := response.Content[0].TextContent.Text
log.Debug(ctx, "Received response from MCP agent", "tool", toolName, "length", len(resultText))
return resultText, nil
}
// GetArtistBiography retrieves the artist biography by calling the external MCP server. // GetArtistBiography retrieves the artist biography by calling the external MCP server.
func (a *MCPAgent) GetArtistBiography(ctx context.Context, id, name, mbid string) (string, error) { func (a *MCPAgent) GetArtistBiography(ctx context.Context, id, name, mbid string) (string, error) {
// Ensure the client is initialized and the server is running (attempts restart if needed) args := ArtistArgs{
if err := a.ensureClientInitialized(ctx); err != nil {
log.Error(ctx, "MCP agent initialization/restart failed, cannot get biography", "error", err)
return "", fmt.Errorf("MCP agent not ready: %w", err)
}
// Lock to ensure only one request uses the client/pipes at a time
a.mu.Lock()
// Check if the client is still valid *after* ensuring initialization and acquiring lock.
// The monitoring goroutine could have nilled it out if the process died just now.
if a.client == nil {
a.mu.Unlock() // Release lock before returning error
log.Error(ctx, "MCP client became invalid after initialization check (server process likely died)")
return "", fmt.Errorf("MCP agent process is not running")
}
// Keep a reference to the client while locked
currentClient := a.client
a.mu.Unlock() // Release lock before making the potentially blocking MCP call
log.Debug(ctx, "Calling MCP agent GetArtistBiography", "id", id, "name", name, "mbid", mbid)
// Prepare arguments for the tool call
args := GetArtistBiographyArgs{
ID: id, ID: id,
Name: name, Name: name,
Mbid: mbid, Mbid: mbid,
} }
return a.callMCPTool(ctx, McpToolNameGetBio, args)
// Call the tool using the client reference
log.Debug(ctx, "Calling MCP tool", "tool", McpToolNameGetBio, "args", args)
response, err := currentClient.CallTool(ctx, McpToolNameGetBio, args) // Use currentClient
if err != nil {
// Handle potential pipe closures or other communication errors
log.Error(ctx, "Failed to call MCP tool", "tool", McpToolNameGetBio, "error", err)
// Check if the error indicates a broken pipe, suggesting the server died
if errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "broken pipe") || strings.Contains(err.Error(), "EOF") {
log.Warn(ctx, "MCP tool call failed, possibly due to server process exit. State will be reset.")
// State reset is handled by the monitoring goroutine, just return error
return "", fmt.Errorf("MCP agent process communication error: %w", err)
}
return "", fmt.Errorf("failed to call MCP tool '%s': %w", McpToolNameGetBio, err)
}
// Process the response
if response == nil || len(response.Content) == 0 || response.Content[0].TextContent == nil || response.Content[0].TextContent.Text == "" {
log.Warn(ctx, "MCP tool returned empty or invalid response", "tool", McpToolNameGetBio)
return "", agents.ErrNotFound
}
bio := response.Content[0].TextContent.Text
log.Debug(ctx, "Received biography from MCP agent", "tool", McpToolNameGetBio, "bioLength", len(bio))
// Return the biography text
return bio, nil
}
// GetArtistURLArgs defines the structure for the get_artist_url MCP tool arguments.
// Exported for use in tests.
type GetArtistURLArgs struct {
ID string `json:"id"`
Name string `json:"name"`
Mbid string `json:"mbid,omitempty"`
} }
// GetArtistURL retrieves the artist URL by calling the external MCP server. // GetArtistURL retrieves the artist URL by calling the external MCP server.
func (a *MCPAgent) GetArtistURL(ctx context.Context, id, name, mbid string) (string, error) { func (a *MCPAgent) GetArtistURL(ctx context.Context, id, name, mbid string) (string, error) {
// Ensure the client is initialized and the server is running (attempts restart if needed) args := ArtistArgs{
if err := a.ensureClientInitialized(ctx); err != nil {
log.Error(ctx, "MCP agent initialization/restart failed, cannot get URL", "error", err)
return "", fmt.Errorf("MCP agent not ready: %w", err)
}
// Lock to ensure only one request uses the client/pipes at a time
a.mu.Lock()
// Check if the client is still valid *after* ensuring initialization and acquiring lock.
if a.client == nil {
a.mu.Unlock()
log.Error(ctx, "MCP client became invalid after initialization check (server process likely died)")
return "", fmt.Errorf("MCP agent process is not running")
}
// Keep a reference to the client while locked
currentClient := a.client
a.mu.Unlock() // Release lock before making the potentially blocking MCP call
log.Debug(ctx, "Calling MCP agent GetArtistURL", "id", id, "name", name, "mbid", mbid)
// Prepare arguments for the tool call
args := GetArtistURLArgs{
ID: id, ID: id,
Name: name, Name: name,
Mbid: mbid, Mbid: mbid,
} }
return a.callMCPTool(ctx, McpToolNameGetURL, args)
// Call the tool using the client reference
log.Debug(ctx, "Calling MCP tool", "tool", McpToolNameGetURL, "args", args)
response, err := currentClient.CallTool(ctx, McpToolNameGetURL, args) // Use currentClient
if err != nil {
// Handle potential pipe closures or other communication errors
log.Error(ctx, "Failed to call MCP tool", "tool", McpToolNameGetURL, "error", err)
// Check if the error indicates a broken pipe, suggesting the server died
if errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "broken pipe") || strings.Contains(err.Error(), "EOF") {
log.Warn(ctx, "MCP tool call failed, possibly due to server process exit. State will be reset.")
// State reset is handled by the monitoring goroutine, just return error
return "", fmt.Errorf("MCP agent process communication error: %w", err)
}
return "", fmt.Errorf("failed to call MCP tool '%s': %w", McpToolNameGetURL, err)
}
// Process the response
if response == nil || len(response.Content) == 0 || response.Content[0].TextContent == nil || response.Content[0].TextContent.Text == "" {
log.Warn(ctx, "MCP tool returned empty or invalid response", "tool", McpToolNameGetURL)
return "", agents.ErrNotFound
}
url := response.Content[0].TextContent.Text
log.Debug(ctx, "Received URL from MCP agent", "tool", McpToolNameGetURL, "url", url)
// Return the URL text
return url, nil
} }
// Ensure MCPAgent implements the required interfaces // Ensure MCPAgent implements the required interfaces
@ -328,5 +252,4 @@ var _ agents.ArtistURLRetriever = (*MCPAgent)(nil)
func init() { func init() {
agents.Register(McpAgentName, mcpConstructor) agents.Register(McpAgentName, mcpConstructor)
log.Info("Registered MCP Agent")
} }

View File

@ -71,8 +71,8 @@ var _ = Describe("MCPAgent", func() {
expectedBio := "This is the artist bio." expectedBio := "This is the artist bio."
mockClient.CallToolFunc = func(ctx context.Context, toolName string, args any) (*mcp_client.ToolResponse, error) { mockClient.CallToolFunc = func(ctx context.Context, toolName string, args any) (*mcp_client.ToolResponse, error) {
Expect(toolName).To(Equal(mcp.McpToolNameGetBio)) Expect(toolName).To(Equal(mcp.McpToolNameGetBio))
Expect(args).To(BeAssignableToTypeOf(mcp.GetArtistBiographyArgs{})) // Use exported type Expect(args).To(BeAssignableToTypeOf(mcp.ArtistArgs{}))
typedArgs := args.(mcp.GetArtistBiographyArgs) // Use exported type typedArgs := args.(mcp.ArtistArgs)
Expect(typedArgs.ID).To(Equal("id1")) Expect(typedArgs.ID).To(Equal("id1"))
Expect(typedArgs.Name).To(Equal("Artist Name")) Expect(typedArgs.Name).To(Equal("Artist Name"))
Expect(typedArgs.Mbid).To(Equal("mbid1")) Expect(typedArgs.Mbid).To(Equal("mbid1"))
@ -137,8 +137,8 @@ var _ = Describe("MCPAgent", func() {
expectedURL := "http://example.com/artist" expectedURL := "http://example.com/artist"
mockClient.CallToolFunc = func(ctx context.Context, toolName string, args any) (*mcp_client.ToolResponse, error) { mockClient.CallToolFunc = func(ctx context.Context, toolName string, args any) (*mcp_client.ToolResponse, error) {
Expect(toolName).To(Equal(mcp.McpToolNameGetURL)) Expect(toolName).To(Equal(mcp.McpToolNameGetURL))
Expect(args).To(BeAssignableToTypeOf(mcp.GetArtistURLArgs{})) // Use exported type Expect(args).To(BeAssignableToTypeOf(mcp.ArtistArgs{}))
typedArgs := args.(mcp.GetArtistURLArgs) // Use exported type typedArgs := args.(mcp.ArtistArgs)
Expect(typedArgs.ID).To(Equal("id2")) Expect(typedArgs.ID).To(Equal("id2"))
Expect(typedArgs.Name).To(Equal("Another Artist")) Expect(typedArgs.Name).To(Equal("Another Artist"))
Expect(typedArgs.Mbid).To(Equal("mbid2")) Expect(typedArgs.Mbid).To(Equal("mbid2"))