"""LangGraph workflow for multi-agent portfolio analysis.

This implements the multi-phase architecture:
    Phase 1:   Data Layer MCPs (Yahoo Finance, FMP, Trading-MCP, FRED, News Sentiment)
    Phase 1.5: Feature Engineering (Feature Extraction MCP)
    Phase 2:   Computation Layer MCPs (Portfolio Optimizer, Risk Analyzer)
    Phase 2.5: ML Predictions (Ensemble Predictor with Chronos)
    Phase 3:   LLM Synthesis (Portfolio Analyst Agent)
"""

import logging
import math
import time
from typing import Dict, Any, List
from datetime import datetime, timezone
from decimal import Decimal

from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableConfig

from backend.models.agent_state import AgentState, MCPCall
from backend.agents.portfolio_analyst import PortfolioAnalystAgent
from backend.agents.personas import create_persona_agent, PersonaType

logger = logging.getLogger(__name__)


def summarize_fred_data(series_data: Dict[str, Any], indicator_name: str) -> Dict[str, Any]:
    """Summarize FRED time series to key statistics for token efficiency.

    Observations whose ``value`` is missing, empty, non-numeric (for example
    a ``"."`` missing-value placeholder) or non-finite are skipped rather
    than aborting the whole summary.

    Args:
        series_data: FRED API response with observations
        indicator_name: Name of the economic indicator

    Returns:
        Summarised statistics dictionary, or an empty dict when there is no
        usable observation.
    """
    if not series_data or 'observations' not in series_data:
        return {}

    observations = series_data['observations']
    if not observations:
        return {}

    # Extract values, tolerating unparsable entries
    values = []
    for obs in observations:
        raw = obs.get('value')
        if raw is None or raw == '':
            continue
        try:
            value = float(raw)
        except (TypeError, ValueError):
            # Placeholder / garbage value: ignore this observation only
            continue
        if math.isfinite(value):
            values.append(value)

    if not values:
        return {}

    # Calculate statistics (the last usable observation is treated as current)
    current = values[-1]
    mean = sum(values) / len(values)

    return {
        'name': indicator_name,
        'current': round(current, 4),
        'mean': round(mean, 4),
        'min': round(min(values), 4),
        'max': round(max(values), 4),
        'observations_count': len(values),
        'trend': 'up' if current > mean else 'down'
    }


class PortfolioAnalysisWorkflow:
    """LangGraph workflow for portfolio analysis."""

    def __init__(
        self,
        mcp_router,
        roast_mode: bool = False,
        persona: PersonaType | str | None = None
    ):
        """Initialise the workflow with MCP router.

        Args:
            mcp_router: MCP router instance for calling MCP servers
            roast_mode: If True, use brutal honesty mode for analysis
            persona: Optional investor persona (e.g., 'warren_buffett',
                'cathie_wood', 'ray_dalio'). When given, it takes precedence
                over ``roast_mode``.
        """
        self.mcp_router = mcp_router
        self.roast_mode = roast_mode
        self.persona = persona

        # Initialise the appropriate analyst agent
        if persona:
            # Use persona-based analysis
            self.analyst_agent = create_persona_agent(persona)
            logger.info(f"Using persona agent: {persona}")
        else:
            # Use standard analysis (with optional roast mode)
            self.analyst_agent = PortfolioAnalystAgent(roast_mode=roast_mode)
            logger.info(f"Using standard analyst (roast_mode={roast_mode})")

        # Build the workflow graph
        self.workflow = self._build_workflow()

    def _build_workflow(self) -> StateGraph:
        """Build the LangGraph workflow with parallel feature extraction.

        Returns:
            The result of ``StateGraph.compile()`` for the five-phase graph.
        """
        workflow = StateGraph(AgentState)

        # Add nodes for each phase
        workflow.add_node("phase_1_data_layer", self._phase_1_data_layer)
        workflow.add_node("phase_1_5_feature_engineering", self._phase_1_5_feature_engineering)
        workflow.add_node("phase_2_computation", self._phase_2_computation)
        workflow.add_node("phase_2_5_ml_predictions", self._phase_2_5_ml_predictions)
        workflow.add_node("phase_3_synthesis", self._phase_3_synthesis)

        # Define the flow with parallel execution (fan-out/fan-in pattern)
        # Phase 1.5 (features) and Phase 2 (computation) run in parallel after Phase 1
        workflow.set_entry_point("phase_1_data_layer")

        # Fan-out: Both branches start after Phase 1
        # NOTE(review): both branch nodes return the full state dict. Whether
        # that is accepted by LangGraph depends on the reducers declared on
        # AgentState (not visible in this file) - confirm against the model.
        workflow.add_edge("phase_1_data_layer", "phase_1_5_feature_engineering")
        workflow.add_edge("phase_1_data_layer", "phase_2_computation")

        # Fan-in: Both branches merge before Phase 2.5
        workflow.add_edge("phase_1_5_feature_engineering", "phase_2_5_ml_predictions")
        workflow.add_edge("phase_2_computation", "phase_2_5_ml_predictions")

        # Continue to synthesis
        workflow.add_edge("phase_2_5_ml_predictions", "phase_3_synthesis")
        workflow.add_edge("phase_3_synthesis", END)

        return workflow.compile()

    @staticmethod
    def _enrich_holdings(holdings, market_data):
        """Attach current price, market value and weight to each holding.

        Market value is ``quantity * current_price`` when both are available,
        otherwise the holding's ``dollar_amount``, otherwise 0. Weights are
        each holding's share of the total; if the total is 0 every holding
        gets an equal weight.

        Args:
            holdings: Iterable of holding dicts (not mutated; copies are made)
            market_data: Quotes keyed by ticker

        Returns:
            Tuple of (enriched holdings list, total portfolio value)
        """
        enriched_holdings = []
        for holding in holdings:
            ticker = holding.get("ticker")
            # `or 0` also covers keys that are present but set to None
            quantity = holding.get("quantity") or 0
            dollar_amount = holding.get("dollar_amount") or 0

            # Get current price from realtime data
            current_price = None
            if ticker in market_data:
                price_data = market_data[ticker]
                current_price = price_data.get("price", 0) or price_data.get("regularMarketPrice", 0)

            # Calculate market value (Decimal via str to avoid float artefacts)
            if quantity > 0 and current_price:
                market_value = Decimal(str(quantity)) * Decimal(str(current_price))
            elif dollar_amount > 0:
                market_value = Decimal(str(dollar_amount))
            else:
                market_value = Decimal("0")

            # Create enriched holding (immutable pattern)
            enriched_holdings.append({
                **holding,
                "current_price": current_price,
                "market_value": float(market_value)
            })

        # Calculate total portfolio value
        total_portfolio_value = sum(h["market_value"] for h in enriched_holdings)

        # Calculate portfolio weights
        for holding in enriched_holdings:
            if total_portfolio_value > 0:
                holding["weight"] = holding["market_value"] / total_portfolio_value
            else:
                # Edge case: equal weights if total is 0
                holding["weight"] = 1.0 / len(enriched_holdings)

        return enriched_holdings, total_portfolio_value

    async def _phase_1_data_layer(self, state: AgentState) -> AgentState:
        """Phase 1: Fetch all data from data layer MCPs.

        MCPs called:
        - Yahoo Finance: Real-time quotes and historical data
        - FMP: Company fundamentals
        - Trading-MCP: Technical indicators
        - FRED: Economic indicators
        - News Sentiment: Per-ticker news sentiment (failures are tolerated
          per ticker and replaced by a neutral placeholder)

        Any other failure is logged and appended to ``state["errors"]``; the
        state is returned either way.
        """
        logger.info("PHASE 1: Fetching data from Data Layer MCPs")
        phase_start = time.perf_counter()

        tickers = [h["ticker"] for h in state["holdings"]]

        try:
            # Fetch market data (Yahoo Finance)
            logger.debug(f"Fetching market data for {len(tickers)} tickers")
            market_data_list = await self.mcp_router.call_yahoo_finance_mcp("get_quote", {"tickers": tickers})

            # Transform list to dict keyed by ticker
            market_data = {}
            for quote in market_data_list:
                ticker = quote.get("ticker") or quote.get("symbol")
                if ticker:
                    market_data[ticker] = quote

            # Fetch historical data for each ticker
            historical_data = {}
            for ticker in tickers:
                hist = await self.mcp_router.call_yahoo_finance_mcp(
                    "get_historical_data",
                    {"ticker": ticker, "period": "1y", "interval": "1d"}
                )
                historical_data[ticker] = hist

            # Fetch fundamentals (FMP)
            logger.debug("Fetching company fundamentals")
            fundamentals = {}
            for ticker in tickers:
                fund = await self.mcp_router.call_fmp_mcp("get_company_profile", {"ticker": ticker})
                fundamentals[ticker] = fund

            # Fetch technical indicators (Trading-MCP)
            logger.debug("Calculating technical indicators")
            technical_indicators = {}
            for ticker in tickers:
                tech = await self.mcp_router.call_trading_mcp(
                    "get_technical_indicators",
                    {"ticker": ticker, "period": "3mo"}
                )
                technical_indicators[ticker] = tech

            # Fetch economic data (FRED)
            logger.debug("Fetching economic indicators")
            economic_data = {}
            for series_id in ["GDP", "UNRATE", "DFF"]:
                econ = await self.mcp_router.call_fred_mcp("get_economic_series", {"series_id": series_id})
                economic_data[series_id] = summarize_fred_data(econ, series_id)

            # Fetch news sentiment (Enhancement #3 - News Sentiment MCP)
            logger.debug("Fetching news sentiment for all holdings")
            sentiment_data = {}
            for ticker in tickers:
                try:
                    sentiment = await self.mcp_router.call_news_sentiment_mcp(
                        "get_news_with_sentiment",
                        {"ticker": ticker, "days_back": 7}
                    )
                    sentiment_data[ticker] = sentiment
                    logger.debug(f"{ticker} sentiment: {sentiment.get('overall_sentiment', 0):.2f}")
                except Exception as e:
                    logger.warning(f"Failed to fetch sentiment for {ticker}: {e}")
                    # Continue with empty sentiment on error
                    sentiment_data[ticker] = {
                        "ticker": ticker,
                        "overall_sentiment": 0.0,
                        "confidence": 0.0,
                        "article_count": 0,
                        "articles": [],
                        "error": str(e)
                    }

            # Enrich holdings with market values and weights based on realtime data
            enriched_holdings, total_portfolio_value = self._enrich_holdings(
                state["holdings"], market_data
            )

            # Log for verification
            logger.info(f"Portfolio total value: ${total_portfolio_value:,.2f}, weights sum: {sum(h['weight'] for h in enriched_holdings):.4f}")

            # Update state with enriched holdings
            state["holdings"] = enriched_holdings

            # Update state
            state["historical_prices"] = historical_data
            state["fundamentals"] = fundamentals
            state["realtime_data"] = market_data
            state["technical_indicators"] = technical_indicators
            state["economic_data"] = economic_data
            state["sentiment_data"] = sentiment_data  # Enhancement #3
            state["current_step"] = "phase_1_complete"

            # Log MCP calls (one entry per tool, not per invocation)
            state["mcp_calls"].extend([
                MCPCall.model_validate({"mcp": "yahoo_finance", "tool": "get_quote"}).model_dump(),
                MCPCall.model_validate({"mcp": "yahoo_finance", "tool": "get_historical_data"}).model_dump(),
                MCPCall.model_validate({"mcp": "fmp", "tool": "get_company_profile"}).model_dump(),
                MCPCall.model_validate({"mcp": "trading_mcp", "tool": "get_technical_indicators"}).model_dump(),
                MCPCall.model_validate({"mcp": "fred", "tool": "get_economic_series"}).model_dump(),
                MCPCall.model_validate({"mcp": "news_sentiment", "tool": "get_news_with_sentiment"}).model_dump(),
            ])

            # Track phase duration
            phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
            state["phase_1_duration_ms"] = phase_duration_ms

            logger.info(f"PHASE 1 COMPLETE: Fetched data for {len(tickers)} assets ({phase_duration_ms}ms)")

        except Exception as e:
            # Phase boundary: record the failure and let the workflow continue
            logger.exception(f"Error in Phase 1: {e}")
            state["errors"].append(f"Phase 1 error: {str(e)}")

        return state

    async def _phase_1_5_feature_engineering(self, state: AgentState) -> AgentState:
        """Phase 1.5: Extract and select features from raw data.

        MCP called:
        - Feature Extraction: Technical indicators, normalisation, selection

        Tickers with fewer than 20 closing prices are skipped. A failure for
        one ticker is recorded in ``state["errors"]`` and does not stop the
        remaining tickers.
        """
        logger.info("PHASE 1.5: Feature Engineering")
        phase_start = time.perf_counter()

        state["reasoning_steps"].append("Phase 1.5: Extracting and selecting features")

        tickers = [h["ticker"] for h in state["holdings"]]
        feature_vectors = {}

        for ticker in tickers:
            try:
                # Get historical prices from Phase 1 data
                hist_data = state.get("historical_prices", {}).get(ticker, {})
                prices = hist_data.get("close_prices", [])

                if len(prices) < 20:
                    logger.warning(f"Insufficient price data for {ticker} ({len(prices)} < 20)")
                    continue

                # Extract technical features
                tech_features = await self.mcp_router.call_feature_extraction_mcp(
                    "extract_technical_features",
                    {
                        "ticker": ticker,
                        "prices": prices,
                        "include_momentum": True,
                        "include_volatility": True,
                        "include_trend": True,
                    }
                )

                # Get fundamental and sentiment features from state
                fundamentals = state.get("fundamentals", {}).get(ticker, {})
                sentiment = state.get("sentiment_data", {}).get(ticker, {})

                # Compute combined feature vector
                feature_vector = await self.mcp_router.call_feature_extraction_mcp(
                    "compute_feature_vector",
                    {
                        "ticker": ticker,
                        "technical_features": tech_features.get("features", {}),
                        "fundamental_features": fundamentals,
                        "sentiment_features": sentiment,
                        "max_features": 30,
                        "selection_method": "pca",
                    }
                )

                feature_vectors[ticker] = feature_vector
                logger.info(
                    f"Extracted {feature_vector.get('feature_count', 0)} features for {ticker}"
                )

            except Exception as e:
                logger.exception(f"Error extracting features for {ticker}: {e}")
                state["errors"].append(f"Feature extraction error for {ticker}: {str(e)}")

        state["feature_vectors"] = feature_vectors
        state["reasoning_steps"].append(
            f"Extracted feature vectors for {len(feature_vectors)} tickers"
        )

        # Log MCP calls
        state["mcp_calls"].extend([
            MCPCall.model_validate({
                "mcp": "feature_extraction",
                "tool": "extract_technical_features"
            }).model_dump(),
            MCPCall.model_validate({
                "mcp": "feature_extraction",
                "tool": "compute_feature_vector"
            }).model_dump(),
        ])

        # Track phase duration (logged only; not stored in state for this phase)
        phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
        logger.info(
            f"PHASE 1.5 COMPLETE: Extracted features for {len(feature_vectors)} assets ({phase_duration_ms}ms)"
        )

        return state

    async def _phase_2_computation(self, state: AgentState) -> AgentState:
        """Phase 2: Run computational models with data from Phase 1.

        MCPs called:
        - Portfolio Optimizer: HRP, Black-Litterman, Mean-Variance
          (skipped for a single-asset portfolio, which gets a 100% fallback)
        - Risk Analyzer: VaR, CVaR, Monte Carlo

        Failures, including missing Phase 1 price data, are logged and
        appended to ``state["errors"]``; the state is returned either way.
        """
        logger.info("PHASE 2: Running Computation Layer MCPs")
        phase_start = time.perf_counter()

        try:
            historical_prices = state.get("historical_prices") or {}
            if not historical_prices:
                # Explicit message instead of an opaque KeyError/IndexError
                raise ValueError("No historical price data available for optimisation")

            # Prepare market data for optimization
            market_data_list = [
                {
                    "ticker": ticker,
                    "prices": hist_data.get("close_prices", []),
                    "dates": hist_data.get("dates", []),
                }
                for ticker, hist_data in historical_prices.items()
            ]

            # Check for single-asset portfolio (requires minimum 2 assets for optimization)
            is_single_asset = len(historical_prices) < 2

            # Only MCP calls that are actually issued get recorded
            executed_calls = []

            # Run portfolio optimizations
            logger.debug("Running portfolio optimizations")

            if is_single_asset:
                # Single-asset fallback: cannot optimise, show 100% allocation
                single_ticker = next(iter(historical_prices))
                logger.info(f"Single-asset portfolio detected ({single_ticker}) - skipping optimization, showing 100% allocation")

                def build_fallback() -> Dict[str, Any]:
                    # Fresh dict per result so the three results never alias
                    return {
                        "weights": {single_ticker: 1.0},
                        "expected_return": 0.0,
                        "volatility": 0.0,
                        "sharpe": 0.0,
                    }

                hrp_result = build_fallback()
                bl_result = build_fallback()
                mv_result = {
                    **build_fallback(),
                    "message": "Portfolio optimization requires minimum 2 assets. Showing current 100% allocation."
                }
            else:
                # Multiple assets: proceed with normal optimization
                # HRP
                hrp_result = await self.mcp_router.call_portfolio_optimizer_mcp(
                    "optimize_hrp",
                    {
                        "market_data": market_data_list,
                        "method": "hrp",
                        "risk_tolerance": state["risk_tolerance"],
                    }
                )

                # Black-Litterman
                bl_result = await self.mcp_router.call_portfolio_optimizer_mcp(
                    "optimize_black_litterman",
                    {
                        "market_data": market_data_list,
                        "method": "black_litterman",
                        "risk_tolerance": state["risk_tolerance"],
                    }
                )

                # Mean-Variance
                mv_result = await self.mcp_router.call_portfolio_optimizer_mcp(
                    "optimize_mean_variance",
                    {
                        "market_data": market_data_list,
                        "method": "mean_variance",
                        "risk_tolerance": state["risk_tolerance"],
                    }
                )

                executed_calls.extend([
                    MCPCall.model_validate({"mcp": "portfolio_optimizer_mcp", "tool": "optimize_hrp"}).model_dump(),
                    MCPCall.model_validate({"mcp": "portfolio_optimizer_mcp", "tool": "optimize_black_litterman"}).model_dump(),
                    MCPCall.model_validate({"mcp": "portfolio_optimizer_mcp", "tool": "optimize_mean_variance"}).model_dump(),
                ])

            # Run risk analysis
            logger.debug("Running risk analysis")
            portfolio_input = [
                {
                    "ticker": holding["ticker"],
                    "weight": holding.get("weight", 0),
                    "prices": historical_prices.get(holding["ticker"], {}).get("close_prices", []),
                }
                for holding in state["holdings"]
            ]

            risk_result = await self.mcp_router.call_risk_analyzer_mcp(
                "analyze_risk",
                {
                    "portfolio": portfolio_input,
                    "portfolio_value": sum(h.get("market_value", 0) for h in state["holdings"]),
                    "confidence_level": 0.95,
                    "method": "monte_carlo",
                    "num_simulations": 10000,
                }
            )
            executed_calls.append(
                MCPCall.model_validate({"mcp": "risk_analyzer_mcp", "tool": "analyze_risk"}).model_dump()
            )

            # Update state
            state["optimisation_results"] = {
                "hrp": hrp_result,
                "black_litterman": bl_result,
                "mean_variance": mv_result,
            }
            state["risk_analysis"] = risk_result
            state["current_step"] = "phase_2_complete"

            # Log MCP calls
            state["mcp_calls"].extend(executed_calls)

            # Track phase duration
            phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
            state["phase_2_duration_ms"] = phase_duration_ms

            logger.info(f"PHASE 2 COMPLETE: Optimizations and risk analysis done ({phase_duration_ms}ms)")

        except Exception as e:
            # Phase boundary: record the failure and let the workflow continue
            logger.exception(f"Error in Phase 2: {e}")
            state["errors"].append(f"Phase 2 error: {str(e)}")

        return state

    async def _phase_2_5_ml_predictions(self, state: AgentState) -> AgentState:
        """Phase 2.5: Generate ML-based price forecasts using Ensemble Predictor.

        MCP called:
        - Ensemble Predictor: Chronos + statistical models for price forecasting

        Holdings with fewer than 10 closing prices, or whose forecast call
        fails, are skipped. If the phase itself fails, ``ensemble_forecasts``
        is set to an empty dict so later phases can still run.
        """
        logger.info("PHASE 2.5: Generating ML predictions")
        phase_start = time.perf_counter()

        try:
            # Generate forecasts for each holding
            logger.debug("Running ensemble forecasts for portfolio holdings")
            ensemble_forecasts = {}

            for holding in state["holdings"]:
                ticker = holding["ticker"]

                # Get historical prices from Phase 1 data
                hist_data = state["historical_prices"].get(ticker, {})
                prices = hist_data.get("close_prices", [])

                if not prices or len(prices) < 10:
                    logger.warning(f"Insufficient price data for {ticker}, skipping forecast")
                    continue

                try:
                    # Call ensemble predictor
                    forecast_result = await self.mcp_router.call_ensemble_predictor_mcp(
                        "forecast_ensemble",
                        {
                            "ticker": ticker,
                            "prices": prices,
                            "forecast_horizon": 30,  # 30-day forecast
                            "confidence_level": 0.95,
                            "use_returns": True,  # Forecast returns for stability
                            "ensemble_method": "mean",  # Simple averaging
                        }
                    )

                    ensemble_forecasts[ticker] = forecast_result
                    logger.debug(f"Generated forecast for {ticker} using {len(forecast_result.get('models_used', []))} models")

                except Exception as e:
                    # Best effort per ticker: one failed forecast must not sink the phase
                    logger.warning(f"Forecast failed for {ticker}: {e}")
                    continue

            # Update state
            state["ensemble_forecasts"] = ensemble_forecasts
            state["current_step"] = "phase_2_5_complete"

            # Log MCP calls
            state["mcp_calls"].extend([
                MCPCall.model_validate({
                    "mcp": "ensemble_predictor",
                    "tool": "forecast_ensemble"
                }).model_dump(),
            ])

            # Track phase duration
            phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
            state["phase_2_5_duration_ms"] = phase_duration_ms

            logger.info(
                f"PHASE 2.5 COMPLETE: Generated forecasts for {len(ensemble_forecasts)} assets ({phase_duration_ms}ms)"
            )

        except Exception as e:
            logger.exception(f"Error in Phase 2.5: {e}")
            state["errors"].append(f"Phase 2.5 error: {str(e)}")
            # Set empty forecasts to allow workflow to continue
            state["ensemble_forecasts"] = {}

        return state

    async def _phase_3_synthesis(self, state: AgentState) -> AgentState:
        """Phase 3: LLM synthesis of all data into actionable insights.

        Passes the data gathered by earlier phases to the analyst agent and
        stores its summary, recommendations, reasoning and token usage in the
        state. Failures are logged and appended to ``state["errors"]``.
        """
        logger.info("PHASE 3: LLM Synthesis")
        phase_start = time.perf_counter()

        try:
            # Prepare data for analyst agent
            portfolio_data = {
                "holdings": state["holdings"],
                "portfolio_id": state.get("portfolio_id", "unknown"),
                "risk_tolerance": state["risk_tolerance"],
            }

            # Call analyst agent (returns AgentResult with usage metrics)
            result = await self.analyst_agent.analyze_portfolio(
                portfolio_data=portfolio_data,
                market_data=state.get("realtime_data", {}),
                fundamentals=state.get("fundamentals", {}),
                technical_indicators=state.get("technical_indicators", {}),
                economic_data=state.get("economic_data", {}),
                optimization_results=state.get("optimisation_results", {}),
                risk_analysis=state.get("risk_analysis", {}),
                ensemble_forecasts=state.get("ensemble_forecasts", {}),
                sentiment_data=state.get("sentiment_data", {}),
                risk_tolerance=state["risk_tolerance"],
            )

            # Extract analysis output and usage metrics
            analysis = result.output

            # Update state with analysis results
            state["ai_synthesis"] = analysis.summary
            state["recommendations"] = analysis.recommendations
            state["reasoning_steps"].extend(analysis.reasoning)
            state["current_step"] = "complete"

            # Track LLM usage metrics
            state["llm_input_tokens"] = result.input_tokens
            state["llm_output_tokens"] = result.output_tokens
            state["llm_total_tokens"] = result.total_tokens
            state["llm_request_count"] = result.request_count

            # Track phase duration
            phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
            state["phase_3_duration_ms"] = phase_duration_ms

            logger.info(
                f"PHASE 3 COMPLETE: Analysis generated (health score: {analysis.health_score}, "
                f"{result.total_tokens} tokens, {phase_duration_ms}ms)"
            )

        except Exception as e:
            logger.exception(f"Error in Phase 3: {e}")
            state["errors"].append(f"Phase 3 error: {str(e)}")

        return state

    async def run(self, initial_state: AgentState) -> AgentState:
        """Run the complete workflow.

        Args:
            initial_state: Initial state with portfolio and query

        Returns:
            Final state with complete analysis
        """
        logger.info(f"Starting portfolio analysis workflow for {len(initial_state['holdings'])} holdings")

        result = await self.workflow.ainvoke(
            initial_state,
            config=RunnableConfig(configurable={}, recursion_limit=25)
        )

        logger.info("Workflow complete")
        return result

    async def stream(self, initial_state: AgentState):
        """Stream workflow execution with progress updates.

        Yields progress events as each node completes, allowing real-time
        updates in the UI. A final ``"complete"`` event carries the output of
        the last phase node that finished.

        Args:
            initial_state: Initial state with portfolio and query

        Yields:
            Dict with 'event' (node name), 'progress' (0-1),
            'message' (description) and 'state' (that node's output)
        """
        logger.info(f"Streaming portfolio analysis for {len(initial_state['holdings'])} holdings")

        # Phases 1.5 and 2 share the same progress value because they run in parallel
        phase_info = {
            "phase_1_data_layer": {
                "progress": 0.2,
                "message": "Phase 1: Fetching market data from Yahoo Finance, FMP, FRED..."
            },
            "phase_1_5_feature_engineering": {
                "progress": 0.4,
                "message": "Phase 1.5: Extracting technical features and indicators..."
            },
            "phase_2_computation": {
                "progress": 0.4,
                "message": "Phase 2: Running portfolio optimisation and risk analysis..."
            },
            "phase_2_5_ml_predictions": {
                "progress": 0.7,
                "message": "Phase 2.5: Generating ensemble ML predictions..."
            },
            "phase_3_synthesis": {
                "progress": 0.9,
                "message": "Phase 3: Synthesising AI insights with Claude..."
            }
        }

        final_state = None
        config = RunnableConfig(configurable={}, recursion_limit=25)

        async for event in self.workflow.astream_events(initial_state, version="v2", config=config):
            # Only end-of-node events for known phase nodes are surfaced
            if event["event"] == "on_chain_end":
                node_name = event.get("name", "")
                if node_name in phase_info:
                    info = phase_info[node_name]
                    yield {
                        "event": node_name,
                        "progress": info["progress"],
                        "message": info["message"],
                        "state": event.get("data", {}).get("output")
                    }
                    final_state = event.get("data", {}).get("output")

        yield {
            "event": "complete",
            "progress": 1.0,
            "message": "Analysis complete!",
            "state": final_state
        }