Spaces:
Running
on
Zero
Running
on
Zero
"""LangGraph workflow for multi-agent portfolio analysis.

This implements the multi-phase architecture:
Phase 1: Data Layer MCPs (Yahoo Finance, FMP, Trading-MCP, FRED, News Sentiment)
Phase 1.5: Feature Engineering (Feature Extraction MCP), runs in parallel with Phase 2
Phase 2: Computation Layer MCPs (Portfolio Optimizer, Risk Analyzer)
Phase 2.5: ML Predictions (Ensemble Predictor with Chronos)
Phase 3: LLM Synthesis (Portfolio Analyst Agent)
"""
| import logging | |
| import time | |
| from typing import Dict, Any, List | |
| from datetime import datetime, timezone | |
| from decimal import Decimal | |
| from langgraph.graph import StateGraph, END | |
| from langchain_core.runnables import RunnableConfig | |
| from backend.models.agent_state import AgentState, MCPCall | |
| from backend.agents.portfolio_analyst import PortfolioAnalystAgent | |
| from backend.agents.personas import create_persona_agent, PersonaType | |
# Module-level logger named after this module, so its level and handlers can
# be configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
def summarize_fred_data(series_data: Dict[str, Any], indicator_name: str) -> Dict[str, Any]:
    """Summarise a FRED time series into key statistics for token efficiency.

    FRED encodes missing observations as the string ``"."``. Such entries, and
    any other value that cannot be parsed as a finite float, are skipped
    instead of aborting the whole summary.

    Args:
        series_data: FRED API response; expected to contain an
            ``"observations"`` list whose items carry a ``"value"`` field.
        indicator_name: Name of the economic indicator, copied into the result.

    Returns:
        Dictionary with ``name``, ``current`` (latest usable value), ``mean``,
        ``min`` and ``max`` (all rounded to 4 decimal places),
        ``observations_count`` (number of usable values) and ``trend``
        (``"up"`` if the latest value exceeds the mean, otherwise ``"down"``).
        Returns ``{}`` when there is no usable data.
    """
    import math  # local import: only needed for the finiteness check below

    if not series_data or 'observations' not in series_data:
        return {}
    observations = series_data['observations']
    if not observations:
        return {}

    # Extract numeric values, skipping missing/non-numeric/non-finite entries.
    values: List[float] = []
    for obs in observations:
        raw = obs.get('value')
        if raw is None or raw == '':
            continue
        try:
            value = float(raw)
        except (TypeError, ValueError):
            # e.g. FRED's "." placeholder for a missing data point
            continue
        if math.isfinite(value):  # NaN/inf would corrupt mean/min/max
            values.append(value)

    if not values:
        return {}

    # Calculate statistics
    current = values[-1]
    mean = sum(values) / len(values)
    return {
        'name': indicator_name,
        'current': round(current, 4),
        'mean': round(mean, 4),
        'min': round(min(values), 4),
        'max': round(max(values), 4),
        'observations_count': len(values),
        'trend': 'up' if current > mean else 'down'
    }
class PortfolioAnalysisWorkflow:
    """LangGraph workflow for multi-phase portfolio analysis.

    Graph layout (built in ``_build_workflow``)::

        phase_1_data_layer
            -> phase_1_5_feature_engineering   (fan-out)
            -> phase_2_computation             (fan-out)
        -> phase_2_5_ml_predictions            (fan-in)
        -> phase_3_synthesis -> END

    Phase nodes catch MCP failures and append them to ``state["errors"]``
    rather than raising, so later phases still run on whatever data exists.
    """

    # Recursion limit passed to LangGraph by both ``run`` and ``stream``.
    RECURSION_LIMIT = 25

    def __init__(
        self,
        mcp_router,
        roast_mode: bool = False,
        persona: PersonaType | str | None = None
    ):
        """Initialise the workflow with an MCP router and an analyst agent.

        Args:
            mcp_router: MCP router instance exposing the ``call_*_mcp``
                coroutines used by each phase.
            roast_mode: If True, use brutal honesty mode for analysis. Only
                passed to the standard analyst, not to persona agents.
            persona: Optional investor persona (e.g., 'warren_buffett',
                'cathie_wood', 'ray_dalio'). When set, a persona agent is used
                instead of the standard analyst.
        """
        self.mcp_router = mcp_router
        self.roast_mode = roast_mode
        self.persona = persona

        if persona:
            self.analyst_agent = create_persona_agent(persona)
            logger.info(f"Using persona agent: {persona}")
        else:
            self.analyst_agent = PortfolioAnalystAgent(roast_mode=roast_mode)
            logger.info(f"Using standard analyst (roast_mode={roast_mode})")

        # Compiled graph, invoked by ``run`` and ``stream``.
        self.workflow = self._build_workflow()

    def _build_workflow(self) -> Any:
        """Build and compile the LangGraph workflow.

        Phase 1.5 (features) and Phase 2 (computation) are both scheduled
        after Phase 1 (fan-out) and both feed Phase 2.5 (fan-in).

        Returns:
            The compiled graph produced by ``StateGraph.compile()``, not the
            ``StateGraph`` builder itself.
        """
        workflow = StateGraph(AgentState)

        workflow.add_node("phase_1_data_layer", self._phase_1_data_layer)
        workflow.add_node("phase_1_5_feature_engineering", self._phase_1_5_feature_engineering)
        workflow.add_node("phase_2_computation", self._phase_2_computation)
        workflow.add_node("phase_2_5_ml_predictions", self._phase_2_5_ml_predictions)
        workflow.add_node("phase_3_synthesis", self._phase_3_synthesis)

        workflow.set_entry_point("phase_1_data_layer")

        # Fan-out: both branches start after Phase 1.
        # NOTE(review): both branch nodes return the full (mutated) state dict
        # in the same step. This relies on AgentState declaring reducers for
        # keys they both write (e.g. mcp_calls, reasoning_steps, errors).
        # Confirm against backend.models.agent_state.
        workflow.add_edge("phase_1_data_layer", "phase_1_5_feature_engineering")
        workflow.add_edge("phase_1_data_layer", "phase_2_computation")

        # Fan-in: both branches merge before Phase 2.5.
        workflow.add_edge("phase_1_5_feature_engineering", "phase_2_5_ml_predictions")
        workflow.add_edge("phase_2_computation", "phase_2_5_ml_predictions")

        workflow.add_edge("phase_2_5_ml_predictions", "phase_3_synthesis")
        workflow.add_edge("phase_3_synthesis", END)

        return workflow.compile()

    async def _phase_1_data_layer(self, state: AgentState) -> AgentState:
        """Phase 1: Fetch all data from the data layer MCPs.

        MCPs called (awaited one after another):
        - Yahoo Finance: one batched quote call, plus 1y/1d history per ticker
        - FMP: company profile per ticker
        - Trading-MCP: technical indicators (3mo) per ticker
        - FRED: GDP, UNRATE and DFF, summarised with ``summarize_fred_data``
        - News Sentiment: 7-day news sentiment per ticker (best effort)

        Holdings are then replaced by enriched copies carrying
        ``current_price``, ``market_value`` and ``weight``. An uncaught
        exception aborts the phase and is appended to ``state["errors"]``;
        the phase outputs may then be missing from the state.
        """
        logger.info("PHASE 1: Fetching data from Data Layer MCPs")
        phase_start = time.perf_counter()
        tickers = [h["ticker"] for h in state["holdings"]]
        try:
            # Real-time quotes (Yahoo Finance): one batched call returning a
            # list of quote dicts, re-keyed here by ticker symbol.
            logger.debug(f"Fetching market data for {len(tickers)} tickers")
            market_data_list = await self.mcp_router.call_yahoo_finance_mcp("get_quote", {"tickers": tickers})
            market_data = {}
            for quote in market_data_list or []:
                ticker = quote.get("ticker") or quote.get("symbol")
                if ticker:
                    market_data[ticker] = quote

            # One year of daily history per ticker.
            historical_data = {}
            for ticker in tickers:
                historical_data[ticker] = await self.mcp_router.call_yahoo_finance_mcp(
                    "get_historical_data",
                    {"ticker": ticker, "period": "1y", "interval": "1d"}
                )

            # Company fundamentals (FMP).
            logger.debug("Fetching company fundamentals")
            fundamentals = {}
            for ticker in tickers:
                fundamentals[ticker] = await self.mcp_router.call_fmp_mcp(
                    "get_company_profile", {"ticker": ticker}
                )

            # Technical indicators (Trading-MCP).
            logger.debug("Calculating technical indicators")
            technical_indicators = {}
            for ticker in tickers:
                technical_indicators[ticker] = await self.mcp_router.call_trading_mcp(
                    "get_technical_indicators",
                    {"ticker": ticker, "period": "3mo"}
                )

            # Economic data (FRED), summarised to keep the LLM prompt small.
            logger.debug("Fetching economic indicators")
            economic_data = {}
            for series_id in ["GDP", "UNRATE", "DFF"]:
                econ = await self.mcp_router.call_fred_mcp("get_economic_series", {"series_id": series_id})
                economic_data[series_id] = summarize_fred_data(econ, series_id)

            # News sentiment (Enhancement #3 - News Sentiment MCP).
            logger.debug("Fetching news sentiment for all holdings")
            sentiment_data = await self._fetch_news_sentiment(tickers)

            # Enrich holdings with market values and weights from realtime data.
            enriched_holdings = self._enrich_holdings(state["holdings"], market_data)
            total_portfolio_value = sum(h["market_value"] for h in enriched_holdings)
            logger.info(f"Portfolio total value: ${total_portfolio_value:,.2f}, weights sum: {sum(h['weight'] for h in enriched_holdings):.4f}")

            state["holdings"] = enriched_holdings
            state["historical_prices"] = historical_data
            state["fundamentals"] = fundamentals
            state["realtime_data"] = market_data
            state["technical_indicators"] = technical_indicators
            state["economic_data"] = economic_data
            state["sentiment_data"] = sentiment_data  # Enhancement #3
            state["current_step"] = "phase_1_complete"

            state["mcp_calls"].extend([
                MCPCall.model_validate({"mcp": "yahoo_finance", "tool": "get_quote"}).model_dump(),
                MCPCall.model_validate({"mcp": "yahoo_finance", "tool": "get_historical_data"}).model_dump(),
                MCPCall.model_validate({"mcp": "fmp", "tool": "get_company_profile"}).model_dump(),
                MCPCall.model_validate({"mcp": "trading_mcp", "tool": "get_technical_indicators"}).model_dump(),
                MCPCall.model_validate({"mcp": "fred", "tool": "get_economic_series"}).model_dump(),
                MCPCall.model_validate({"mcp": "news_sentiment", "tool": "get_news_with_sentiment"}).model_dump(),
            ])

            phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
            state["phase_1_duration_ms"] = phase_duration_ms
            logger.info(f"PHASE 1 COMPLETE: Fetched data for {len(tickers)} assets ({phase_duration_ms}ms)")
        except Exception as e:
            logger.exception(f"Error in Phase 1: {e}")
            state["errors"].append(f"Phase 1 error: {str(e)}")
        return state

    async def _fetch_news_sentiment(self, tickers: List[str]) -> Dict[str, Any]:
        """Fetch 7-day news sentiment per ticker, tolerating per-ticker failures.

        A failed request for one ticker does not abort the others. That ticker
        gets a neutral placeholder (sentiment 0.0, no articles) carrying the
        error message under ``"error"``.
        """
        sentiment_data = {}
        for ticker in tickers:
            try:
                sentiment = await self.mcp_router.call_news_sentiment_mcp(
                    "get_news_with_sentiment",
                    {"ticker": ticker, "days_back": 7}
                )
                sentiment_data[ticker] = sentiment
                # Lazy %-formatting: a missing/None score must not raise here,
                # which would otherwise discard the data just fetched.
                logger.debug("%s sentiment: %s", ticker, sentiment.get("overall_sentiment", 0))
            except Exception as e:
                logger.warning(f"Failed to fetch sentiment for {ticker}: {e}")
                sentiment_data[ticker] = {
                    "ticker": ticker,
                    "overall_sentiment": 0.0,
                    "confidence": 0.0,
                    "article_count": 0,
                    "articles": [],
                    "error": str(e)
                }
        return sentiment_data

    @staticmethod
    def _enrich_holdings(
        holdings: List[Dict[str, Any]],
        market_data: Dict[str, Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Return copies of ``holdings`` with ``current_price``, ``market_value`` and ``weight``.

        Market value resolution:
        1. ``quantity * current_price``, when quantity > 0 and a price is known.
        2. Otherwise ``dollar_amount``, when it is > 0.
        3. Otherwise 0.

        Weights are market values divided by the portfolio total. If the
        total is zero, every holding gets an equal weight.
        """
        enriched_holdings = []
        for holding in holdings:
            ticker = holding.get("ticker")
            # ``or 0`` also covers keys that are present but explicitly None.
            quantity = holding.get("quantity") or 0
            dollar_amount = holding.get("dollar_amount") or 0

            current_price = None
            if ticker in market_data:
                price_data = market_data[ticker]
                current_price = price_data.get("price", 0) or price_data.get("regularMarketPrice", 0)

            # Decimal arithmetic avoids float error in quantity * price.
            if quantity > 0 and current_price:
                market_value = Decimal(str(quantity)) * Decimal(str(current_price))
            elif dollar_amount > 0:
                market_value = Decimal(str(dollar_amount))
            else:
                market_value = Decimal("0")

            enriched_holdings.append({
                **holding,
                "current_price": current_price,
                "market_value": float(market_value)
            })

        total_value = sum(h["market_value"] for h in enriched_holdings)
        for holding in enriched_holdings:
            if total_value > 0:
                holding["weight"] = holding["market_value"] / total_value
            else:
                # Edge case: equal weights when the total is 0 (list is non-empty here).
                holding["weight"] = 1.0 / len(enriched_holdings)
        return enriched_holdings

    async def _phase_1_5_feature_engineering(self, state: AgentState) -> AgentState:
        """Phase 1.5: Extract and select features from raw Phase 1 data.

        MCP called:
        - Feature Extraction: technical features, then a combined feature
          vector (PCA selection, up to 30 features) per ticker.

        Tickers with fewer than 20 close prices are skipped. Per-ticker
        failures are appended to ``state["errors"]``. Writes
        ``state["feature_vectors"]``.
        """
        logger.info("PHASE 1.5: Feature Engineering")
        phase_start = time.perf_counter()
        state["reasoning_steps"].append("Phase 1.5: Extracting and selecting features")

        tickers = [h["ticker"] for h in state["holdings"]]
        feature_vectors = {}
        for ticker in tickers:
            try:
                hist_data = state.get("historical_prices", {}).get(ticker, {})
                prices = hist_data.get("close_prices", [])
                if len(prices) < 20:
                    logger.warning(f"Insufficient price data for {ticker} ({len(prices)} < 20)")
                    continue

                tech_features = await self.mcp_router.call_feature_extraction_mcp(
                    "extract_technical_features",
                    {
                        "ticker": ticker,
                        "prices": prices,
                        "include_momentum": True,
                        "include_volatility": True,
                        "include_trend": True,
                    }
                )

                # Combine technical features with Phase 1 fundamentals and sentiment.
                fundamentals = state.get("fundamentals", {}).get(ticker, {})
                sentiment = state.get("sentiment_data", {}).get(ticker, {})
                feature_vector = await self.mcp_router.call_feature_extraction_mcp(
                    "compute_feature_vector",
                    {
                        "ticker": ticker,
                        "technical_features": tech_features.get("features", {}),
                        "fundamental_features": fundamentals,
                        "sentiment_features": sentiment,
                        "max_features": 30,
                        "selection_method": "pca",
                    }
                )
                feature_vectors[ticker] = feature_vector
                logger.info(
                    f"Extracted {feature_vector.get('feature_count', 0)} features for {ticker}"
                )
            except Exception as e:
                logger.error(f"Error extracting features for {ticker}: {e}")
                state["errors"].append(f"Feature extraction error for {ticker}: {str(e)}")

        state["feature_vectors"] = feature_vectors
        state["reasoning_steps"].append(
            f"Extracted feature vectors for {len(feature_vectors)} tickers"
        )

        state["mcp_calls"].extend([
            MCPCall.model_validate({
                "mcp": "feature_extraction",
                "tool": "extract_technical_features"
            }).model_dump(),
            MCPCall.model_validate({
                "mcp": "feature_extraction",
                "tool": "compute_feature_vector"
            }).model_dump(),
        ])

        phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
        logger.info(
            f"PHASE 1.5 COMPLETE: Extracted features for {len(feature_vectors)} assets ({phase_duration_ms}ms)"
        )
        return state

    @staticmethod
    def _single_asset_allocation(ticker: str) -> Dict[str, Any]:
        """Return a fresh placeholder optimisation result allocating 100% to ``ticker``."""
        return {
            "weights": {ticker: 1.0},
            "expected_return": 0.0,
            "volatility": 0.0,
            "sharpe": 0.0,
        }

    async def _phase_2_computation(self, state: AgentState) -> AgentState:
        """Phase 2: Run computational models on the Phase 1 data.

        MCPs called:
        - Portfolio Optimizer: HRP, Black-Litterman, Mean-Variance. These are
          skipped for single-asset portfolios, which get a 100% allocation
          placeholder instead.
        - Risk Analyzer: VaR, CVaR, Monte Carlo (10,000 simulations, 95%).

        Writes ``optimisation_results`` and ``risk_analysis``. Failures,
        including missing price history, are appended to ``state["errors"]``.
        """
        logger.info("PHASE 2: Running Computation Layer MCPs")
        phase_start = time.perf_counter()
        try:
            historical_prices = state.get("historical_prices") or {}
            if not historical_prices:
                # Nothing to optimise (e.g. Phase 1 failed). Report this clearly
                # instead of an opaque IndexError further down.
                raise ValueError("no historical price data available for optimisation")

            market_data_list = [
                {
                    "ticker": ticker,
                    "prices": hist_data.get("close_prices", []),
                    "dates": hist_data.get("dates", []),
                }
                for ticker, hist_data in historical_prices.items()
            ]

            logger.debug("Running portfolio optimizations")
            # Optimisation requires at least 2 assets (dict keys are unique tickers).
            if len(historical_prices) < 2:
                single_ticker = next(iter(historical_prices))
                logger.info(f"Single-asset portfolio detected ({single_ticker}) - skipping optimization, showing 100% allocation")
                # Build a separate dict per method so that mutating one result
                # downstream cannot leak into the others.
                hrp_result = self._single_asset_allocation(single_ticker)
                bl_result = self._single_asset_allocation(single_ticker)
                mv_result = {
                    **self._single_asset_allocation(single_ticker),
                    "message": "Portfolio optimization requires minimum 2 assets. Showing current 100% allocation."
                }
                optimizer_tools = []  # no optimizer MCP calls were made
            else:
                hrp_result = await self.mcp_router.call_portfolio_optimizer_mcp(
                    "optimize_hrp",
                    {
                        "market_data": market_data_list,
                        "method": "hrp",
                        "risk_tolerance": state["risk_tolerance"],
                    }
                )
                bl_result = await self.mcp_router.call_portfolio_optimizer_mcp(
                    "optimize_black_litterman",
                    {
                        "market_data": market_data_list,
                        "method": "black_litterman",
                        "risk_tolerance": state["risk_tolerance"],
                    }
                )
                mv_result = await self.mcp_router.call_portfolio_optimizer_mcp(
                    "optimize_mean_variance",
                    {
                        "market_data": market_data_list,
                        "method": "mean_variance",
                        "risk_tolerance": state["risk_tolerance"],
                    }
                )
                optimizer_tools = ["optimize_hrp", "optimize_black_litterman", "optimize_mean_variance"]

            # Risk analysis on the current (Phase 1) weights.
            logger.debug("Running risk analysis")
            portfolio_input = [
                {
                    "ticker": holding["ticker"],
                    "weight": holding.get("weight", 0),
                    "prices": historical_prices.get(holding["ticker"], {}).get("close_prices", []),
                }
                for holding in state["holdings"]
            ]
            risk_result = await self.mcp_router.call_risk_analyzer_mcp(
                "analyze_risk",
                {
                    "portfolio": portfolio_input,
                    "portfolio_value": sum(h.get("market_value", 0) for h in state["holdings"]),
                    "confidence_level": 0.95,
                    "method": "monte_carlo",
                    "num_simulations": 10000,
                }
            )

            state["optimisation_results"] = {
                "hrp": hrp_result,
                "black_litterman": bl_result,
                "mean_variance": mv_result,
            }
            state["risk_analysis"] = risk_result
            state["current_step"] = "phase_2_complete"

            # Log only the MCP calls that were actually made.
            mcp_calls = [
                MCPCall.model_validate({"mcp": "portfolio_optimizer_mcp", "tool": tool}).model_dump()
                for tool in optimizer_tools
            ]
            mcp_calls.append(
                MCPCall.model_validate({"mcp": "risk_analyzer_mcp", "tool": "analyze_risk"}).model_dump()
            )
            state["mcp_calls"].extend(mcp_calls)

            phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
            state["phase_2_duration_ms"] = phase_duration_ms
            logger.info(f"PHASE 2 COMPLETE: Optimizations and risk analysis done ({phase_duration_ms}ms)")
        except Exception as e:
            logger.exception(f"Error in Phase 2: {e}")
            state["errors"].append(f"Phase 2 error: {str(e)}")
        return state

    async def _phase_2_5_ml_predictions(self, state: AgentState) -> AgentState:
        """Phase 2.5: Generate ML-based price forecasts using the Ensemble Predictor.

        MCP called:
        - Ensemble Predictor: Chronos + statistical models. Produces a 30-day
          forecast on returns with a 95% confidence level and mean ensembling.

        Holdings with fewer than 10 close prices are skipped. Per-ticker
        failures are logged and skipped. On an unexpected phase-level error,
        ``ensemble_forecasts`` is set to ``{}`` so the workflow can continue.
        """
        logger.info("PHASE 2.5: Generating ML predictions")
        phase_start = time.perf_counter()
        try:
            logger.debug("Running ensemble forecasts for portfolio holdings")
            historical_prices = state.get("historical_prices") or {}
            ensemble_forecasts = {}
            for holding in state["holdings"]:
                ticker = holding["ticker"]
                prices = historical_prices.get(ticker, {}).get("close_prices", [])
                if not prices or len(prices) < 10:
                    logger.warning(f"Insufficient price data for {ticker}, skipping forecast")
                    continue

                try:
                    forecast_result = await self.mcp_router.call_ensemble_predictor_mcp(
                        "forecast_ensemble",
                        {
                            "ticker": ticker,
                            "prices": prices,
                            "forecast_horizon": 30,  # 30-day forecast
                            "confidence_level": 0.95,
                            "use_returns": True,  # Forecast returns for stability
                            "ensemble_method": "mean",  # Simple averaging
                        }
                    )
                except Exception as e:
                    logger.warning(f"Forecast failed for {ticker}: {e}")
                    continue
                ensemble_forecasts[ticker] = forecast_result
                logger.debug(f"Generated forecast for {ticker} using {len(forecast_result.get('models_used', []))} models")

            state["ensemble_forecasts"] = ensemble_forecasts
            state["current_step"] = "phase_2_5_complete"

            state["mcp_calls"].extend([
                MCPCall.model_validate({
                    "mcp": "ensemble_predictor",
                    "tool": "forecast_ensemble"
                }).model_dump(),
            ])

            phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
            state["phase_2_5_duration_ms"] = phase_duration_ms
            logger.info(
                f"PHASE 2.5 COMPLETE: Generated forecasts for {len(ensemble_forecasts)} assets ({phase_duration_ms}ms)"
            )
        except Exception as e:
            logger.exception(f"Error in Phase 2.5: {e}")
            state["errors"].append(f"Phase 2.5 error: {str(e)}")
            # Set empty forecasts to allow workflow to continue
            state["ensemble_forecasts"] = {}
        return state

    async def _phase_3_synthesis(self, state: AgentState) -> AgentState:
        """Phase 3: LLM synthesis of all gathered data into actionable insights.

        Calls ``self.analyst_agent.analyze_portfolio`` with every upstream
        result. Writes ``ai_synthesis``, ``recommendations`` and reasoning
        steps, plus LLM token/request usage metrics. Failures are appended to
        ``state["errors"]``.
        """
        logger.info("PHASE 3: LLM Synthesis")
        phase_start = time.perf_counter()
        try:
            portfolio_data = {
                "holdings": state["holdings"],
                "portfolio_id": state.get("portfolio_id", "unknown"),
                "risk_tolerance": state["risk_tolerance"],
            }

            # The agent returns a result wrapper: ``.output`` holds the analysis
            # and the remaining attributes hold usage metrics.
            result = await self.analyst_agent.analyze_portfolio(
                portfolio_data=portfolio_data,
                market_data=state.get("realtime_data", {}),
                fundamentals=state.get("fundamentals", {}),
                technical_indicators=state.get("technical_indicators", {}),
                economic_data=state.get("economic_data", {}),
                optimization_results=state.get("optimisation_results", {}),
                risk_analysis=state.get("risk_analysis", {}),
                ensemble_forecasts=state.get("ensemble_forecasts", {}),
                sentiment_data=state.get("sentiment_data", {}),
                risk_tolerance=state["risk_tolerance"],
            )
            analysis = result.output

            state["ai_synthesis"] = analysis.summary
            state["recommendations"] = analysis.recommendations
            state["reasoning_steps"].extend(analysis.reasoning)
            state["current_step"] = "complete"

            state["llm_input_tokens"] = result.input_tokens
            state["llm_output_tokens"] = result.output_tokens
            state["llm_total_tokens"] = result.total_tokens
            state["llm_request_count"] = result.request_count

            phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
            state["phase_3_duration_ms"] = phase_duration_ms
            logger.info(
                f"PHASE 3 COMPLETE: Analysis generated (health score: {analysis.health_score}, "
                f"{result.total_tokens} tokens, {phase_duration_ms}ms)"
            )
        except Exception as e:
            logger.exception(f"Error in Phase 3: {e}")
            state["errors"].append(f"Phase 3 error: {str(e)}")
        return state

    async def run(self, initial_state: AgentState) -> AgentState:
        """Run the complete workflow to completion.

        Args:
            initial_state: Initial state with portfolio and query.

        Returns:
            Final state with complete analysis.
        """
        logger.info(f"Starting portfolio analysis workflow for {len(initial_state['holdings'])} holdings")
        result = await self.workflow.ainvoke(
            initial_state,
            config=RunnableConfig(configurable={}, recursion_limit=self.RECURSION_LIMIT)
        )
        logger.info("Workflow complete")
        return result

    async def stream(self, initial_state: AgentState):
        """Stream workflow execution with progress updates.

        Yields a progress event each time a phase node finishes, so the UI
        can update in real time. A final ``"complete"`` event follows.

        Args:
            initial_state: Initial state with portfolio and query.

        Yields:
            Dict with 'event' (node name or 'complete'), 'progress' (0-1),
            'message' (description) and 'state' (the node's output, or the
            final state for the 'complete' event).
        """
        logger.info(f"Streaming portfolio analysis for {len(initial_state['holdings'])} holdings")
        # Phase 1.5 and Phase 2 share a progress value because they run in parallel.
        phase_info = {
            "phase_1_data_layer": {
                "progress": 0.2,
                "message": "Phase 1: Fetching market data from Yahoo Finance, FMP, FRED..."
            },
            "phase_1_5_feature_engineering": {
                "progress": 0.4,
                "message": "Phase 1.5: Extracting technical features and indicators..."
            },
            "phase_2_computation": {
                "progress": 0.4,
                "message": "Phase 2: Running portfolio optimisation and risk analysis..."
            },
            "phase_2_5_ml_predictions": {
                "progress": 0.7,
                "message": "Phase 2.5: Generating ensemble ML predictions..."
            },
            "phase_3_synthesis": {
                "progress": 0.9,
                "message": "Phase 3: Synthesising AI insights with Claude..."
            }
        }

        final_state = None
        config = RunnableConfig(configurable={}, recursion_limit=self.RECURSION_LIMIT)
        async for event in self.workflow.astream_events(initial_state, version="v2", config=config):
            if event["event"] == "on_chain_end":
                output = event.get("data", {}).get("output")
                node_name = event.get("name", "")
                if node_name in phase_info:
                    info = phase_info[node_name]
                    yield {
                        "event": node_name,
                        "progress": info["progress"],
                        "message": info["message"],
                        "state": output
                    }
                # Overwritten on every chain end; presumably the root graph's
                # end event arrives last, leaving the graph's final output here.
                final_state = output

        yield {
            "event": "complete",
            "progress": 1.0,
            "message": "Analysis complete!",
            "state": final_state
        }