"""LangGraph workflow for multi-agent portfolio analysis.
This implements the multi-phase architecture:
Phase 1: Data Layer MCPs (Yahoo Finance, FMP, Trading-MCP, FRED)
Phase 2: Computation Layer MCPs (Portfolio Optimizer, Risk Analyzer)
Phase 2.5: ML Predictions (Ensemble Predictor with Chronos)
Phase 3: LLM Synthesis (Portfolio Analyst Agent)
"""
import logging
import time
from typing import Dict, Any
from decimal import Decimal
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from langchain_core.runnables import RunnableConfig
from backend.models.agent_state import AgentState, MCPCall
from backend.agents.portfolio_analyst import PortfolioAnalystAgent
from backend.agents.personas import create_persona_agent, PersonaType
logger = logging.getLogger(__name__)
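
# Typical usage (a sketch; constructing `mcp_router` and the initial AgentState
# depends on the wider application):
#
#     workflow = PortfolioAnalysisWorkflow(mcp_router, persona="warren_buffett")
#     final_state = await workflow.run(initial_state)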
def summarize_fred_data(series_data: Dict[str, Any], indicator_name: str) -> Dict[str, Any]:
"""Summarize FRED time series to key statistics for token efficiency.
Args:
series_data: FRED API response with observations
indicator_name: Name of the economic indicator
Returns:
Summarised statistics dictionary
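
    Example (FRED marks missing observations with "."):
        >>> obs = {'observations': [{'value': '3.5'}, {'value': '4.0'}, {'value': '.'}]}
        >>> summarize_fred_data(obs, 'UNRATE')['current']
        4.0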
"""
if not series_data or 'observations' not in series_data:
return {}
observations = series_data['observations']
if not observations:
return {}
    # Extract values ("." is FRED's placeholder for a missing observation)
    values = [
        float(obs['value'])
        for obs in observations
        if obs.get('value') not in (None, '', '.')
    ]
if not values:
return {}
# Calculate statistics
current = values[-1]
mean = sum(values) / len(values)
return {
'name': indicator_name,
'current': round(current, 4),
'mean': round(mean, 4),
'min': round(min(values), 4),
'max': round(max(values), 4),
'observations_count': len(values),
'trend': 'up' if current > mean else 'down'
}
class PortfolioAnalysisWorkflow:
"""LangGraph workflow for portfolio analysis."""
def __init__(
self,
mcp_router,
roast_mode: bool = False,
persona: PersonaType | str | None = None
):
"""Initialise the workflow with MCP router.
Args:
mcp_router: MCP router instance for calling MCP servers
roast_mode: If True, use brutal honesty mode for analysis
persona: Optional investor persona (e.g., 'warren_buffett', 'cathie_wood', 'ray_dalio')
"""
self.mcp_router = mcp_router
self.roast_mode = roast_mode
self.persona = persona
# Initialise the appropriate analyst agent
if persona:
# Use persona-based analysis
self.analyst_agent = create_persona_agent(persona)
logger.info(f"Using persona agent: {persona}")
else:
# Use standard analysis (with optional roast mode)
self.analyst_agent = PortfolioAnalystAgent(roast_mode=roast_mode)
logger.info(f"Using standard analyst (roast_mode={roast_mode})")
# Build the workflow graph
self.workflow = self._build_workflow()
    def _build_workflow(self) -> CompiledStateGraph:
"""Build the LangGraph workflow with parallel feature extraction."""
workflow = StateGraph(AgentState)
# Add nodes for each phase
workflow.add_node("phase_1_data_layer", self._phase_1_data_layer)
workflow.add_node("phase_1_5_feature_engineering", self._phase_1_5_feature_engineering)
workflow.add_node("phase_2_computation", self._phase_2_computation)
workflow.add_node("phase_2_5_ml_predictions", self._phase_2_5_ml_predictions)
workflow.add_node("phase_3_synthesis", self._phase_3_synthesis)
# Define the flow with parallel execution (fan-out/fan-in pattern)
# Phase 1.5 (features) and Phase 2 (computation) run in parallel after Phase 1
workflow.set_entry_point("phase_1_data_layer")
# Fan-out: Both branches start after Phase 1
workflow.add_edge("phase_1_data_layer", "phase_1_5_feature_engineering")
workflow.add_edge("phase_1_data_layer", "phase_2_computation")
# Fan-in: Both branches merge before Phase 2.5
workflow.add_edge("phase_1_5_feature_engineering", "phase_2_5_ml_predictions")
workflow.add_edge("phase_2_computation", "phase_2_5_ml_predictions")
# Continue to synthesis
workflow.add_edge("phase_2_5_ml_predictions", "phase_3_synthesis")
workflow.add_edge("phase_3_synthesis", END)
return workflow.compile()
async def _phase_1_data_layer(self, state: AgentState) -> AgentState:
"""Phase 1: Fetch all data from data layer MCPs.
        MCPs called:
        - Yahoo Finance: Real-time quotes and historical data
        - FMP: Company fundamentals
        - Trading-MCP: Technical indicators
        - FRED: Economic indicators
        - News Sentiment: Recent headlines with sentiment scores
"""
logger.info("PHASE 1: Fetching data from Data Layer MCPs")
phase_start = time.perf_counter()
tickers = [h["ticker"] for h in state["holdings"]]
try:
# Fetch market data (Yahoo Finance)
logger.debug(f"Fetching market data for {len(tickers)} tickers")
market_data_list = await self.mcp_router.call_yahoo_finance_mcp("get_quote", {"tickers": tickers})
# Transform list to dict keyed by ticker
market_data = {}
for quote in market_data_list:
ticker = quote.get("ticker") or quote.get("symbol")
if ticker:
market_data[ticker] = quote
# Fetch historical data for each ticker
historical_data = {}
for ticker in tickers:
hist = await self.mcp_router.call_yahoo_finance_mcp(
"get_historical_data",
{"ticker": ticker, "period": "1y", "interval": "1d"}
)
historical_data[ticker] = hist
# Fetch fundamentals (FMP)
logger.debug("Fetching company fundamentals")
fundamentals = {}
for ticker in tickers:
fund = await self.mcp_router.call_fmp_mcp("get_company_profile", {"ticker": ticker})
fundamentals[ticker] = fund
# Fetch technical indicators (Trading-MCP)
logger.debug("Calculating technical indicators")
technical_indicators = {}
for ticker in tickers:
tech = await self.mcp_router.call_trading_mcp(
"get_technical_indicators",
{"ticker": ticker, "period": "3mo"}
)
technical_indicators[ticker] = tech
# Fetch economic data (FRED)
logger.debug("Fetching economic indicators")
economic_data = {}
for series_id in ["GDP", "UNRATE", "DFF"]:
econ = await self.mcp_router.call_fred_mcp("get_economic_series", {"series_id": series_id})
economic_data[series_id] = summarize_fred_data(econ, series_id)
# Fetch news sentiment (Enhancement #3 - News Sentiment MCP)
logger.debug("Fetching news sentiment for all holdings")
sentiment_data = {}
for ticker in tickers:
try:
sentiment = await self.mcp_router.call_news_sentiment_mcp(
"get_news_with_sentiment",
{"ticker": ticker, "days_back": 7}
)
sentiment_data[ticker] = sentiment
logger.debug(f"{ticker} sentiment: {sentiment.get('overall_sentiment', 0):.2f}")
except Exception as e:
logger.warning(f"Failed to fetch sentiment for {ticker}: {e}")
# Continue with empty sentiment on error
sentiment_data[ticker] = {
"ticker": ticker,
"overall_sentiment": 0.0,
"confidence": 0.0,
"article_count": 0,
"articles": [],
"error": str(e)
}
# Enrich holdings with market values based on realtime data
enriched_holdings = []
for holding in state["holdings"]:
ticker = holding.get("ticker")
quantity = holding.get("quantity", 0)
dollar_amount = holding.get("dollar_amount", 0)
# Get current price from realtime_data
current_price = None
if ticker in market_data:
price_data = market_data[ticker]
current_price = price_data.get("price", 0) or price_data.get("regularMarketPrice", 0)
                # Calculate market value (Decimal keeps the price*quantity product
                # exact; converted back to float for JSON-friendly state)
if quantity > 0 and current_price:
market_value = Decimal(str(quantity)) * Decimal(str(current_price))
elif dollar_amount > 0:
market_value = Decimal(str(dollar_amount))
else:
market_value = Decimal("0")
# Create enriched holding (immutable pattern)
enriched_holding = {
**holding,
"current_price": current_price,
"market_value": float(market_value)
}
enriched_holdings.append(enriched_holding)
# Calculate total portfolio value
total_portfolio_value = sum(h["market_value"] for h in enriched_holdings)
# Calculate portfolio weights
for holding in enriched_holdings:
if total_portfolio_value > 0:
holding["weight"] = holding["market_value"] / total_portfolio_value
                else:
                    # Edge case: fall back to equal weights when total value is 0
                    # (the loop guarantees enriched_holdings is non-empty here)
                    holding["weight"] = 1.0 / len(enriched_holdings)
# Log for verification
logger.info(f"Portfolio total value: ${total_portfolio_value:,.2f}, weights sum: {sum(h['weight'] for h in enriched_holdings):.4f}")
# Update state with enriched holdings
state["holdings"] = enriched_holdings
# Update state
state["historical_prices"] = historical_data
state["fundamentals"] = fundamentals
state["realtime_data"] = market_data
state["technical_indicators"] = technical_indicators
state["economic_data"] = economic_data
state["sentiment_data"] = sentiment_data # Enhancement #3
state["current_step"] = "phase_1_complete"
            # Log MCP calls (one entry per tool type, not per individual request)
state["mcp_calls"].extend([
MCPCall.model_validate({"mcp": "yahoo_finance", "tool": "get_quote"}).model_dump(),
MCPCall.model_validate({"mcp": "yahoo_finance", "tool": "get_historical_data"}).model_dump(),
MCPCall.model_validate({"mcp": "fmp", "tool": "get_company_profile"}).model_dump(),
MCPCall.model_validate({"mcp": "trading_mcp", "tool": "get_technical_indicators"}).model_dump(),
MCPCall.model_validate({"mcp": "fred", "tool": "get_economic_series"}).model_dump(),
MCPCall.model_validate({"mcp": "news_sentiment", "tool": "get_news_with_sentiment"}).model_dump(),
])
# Track phase duration
phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
state["phase_1_duration_ms"] = phase_duration_ms
logger.info(f"PHASE 1 COMPLETE: Fetched data for {len(tickers)} assets ({phase_duration_ms}ms)")
except Exception as e:
logger.error(f"Error in Phase 1: {e}")
state["errors"].append(f"Phase 1 error: {str(e)}")
return state
async def _phase_1_5_feature_engineering(self, state: AgentState) -> AgentState:
"""Phase 1.5: Extract and select features from raw data.
MCP called:
- Feature Extraction: Technical indicators, normalisation, selection
"""
logger.info("PHASE 1.5: Feature Engineering")
phase_start = time.perf_counter()
state["reasoning_steps"].append("Phase 1.5: Extracting and selecting features")
tickers = [h["ticker"] for h in state["holdings"]]
feature_vectors = {}
for ticker in tickers:
try:
# Get historical prices from Phase 1 data
hist_data = state.get("historical_prices", {}).get(ticker, {})
prices = hist_data.get("close_prices", [])
if len(prices) < 20:
logger.warning(f"Insufficient price data for {ticker} ({len(prices)} < 20)")
continue
# Extract technical features
tech_features = await self.mcp_router.call_feature_extraction_mcp(
"extract_technical_features",
{
"ticker": ticker,
"prices": prices,
"include_momentum": True,
"include_volatility": True,
"include_trend": True,
}
)
# Get fundamental and sentiment features from state
fundamentals = state.get("fundamentals", {}).get(ticker, {})
sentiment = state.get("sentiment_data", {}).get(ticker, {})
# Compute combined feature vector
feature_vector = await self.mcp_router.call_feature_extraction_mcp(
"compute_feature_vector",
{
"ticker": ticker,
"technical_features": tech_features.get("features", {}),
"fundamental_features": fundamentals,
"sentiment_features": sentiment,
"max_features": 30,
"selection_method": "pca",
}
)
feature_vectors[ticker] = feature_vector
logger.info(
f"Extracted {feature_vector.get('feature_count', 0)} features for {ticker}"
)
except Exception as e:
logger.error(f"Error extracting features for {ticker}: {e}")
state["errors"].append(f"Feature extraction error for {ticker}: {str(e)}")
state["feature_vectors"] = feature_vectors
state["reasoning_steps"].append(
f"Extracted feature vectors for {len(feature_vectors)} tickers"
)
# Log MCP calls
state["mcp_calls"].extend([
MCPCall.model_validate({
"mcp": "feature_extraction",
"tool": "extract_technical_features"
}).model_dump(),
MCPCall.model_validate({
"mcp": "feature_extraction",
"tool": "compute_feature_vector"
}).model_dump(),
])
        # Track phase duration (mirrors the phase_X_duration_ms keys set by the
        # other phases; assumes AgentState defines phase_1_5_duration_ms)
        phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
        state["phase_1_5_duration_ms"] = phase_duration_ms
logger.info(
f"PHASE 1.5 COMPLETE: Extracted features for {len(feature_vectors)} assets ({phase_duration_ms}ms)"
)
return state
async def _phase_2_computation(self, state: AgentState) -> AgentState:
"""Phase 2: Run computational models with data from Phase 1.
MCPs called:
- Portfolio Optimizer: HRP, Black-Litterman, Mean-Variance
- Risk Analyzer: VaR, CVaR, Monte Carlo
"""
logger.info("PHASE 2: Running Computation Layer MCPs")
phase_start = time.perf_counter()
try:
# Prepare market data for optimization
market_data_list = []
for ticker, hist_data in state["historical_prices"].items():
market_data_list.append({
"ticker": ticker,
"prices": hist_data.get("close_prices", []),
"dates": hist_data.get("dates", []),
})
# Check for single-asset portfolio (requires minimum 2 assets for optimization)
            unique_tickers = set(state["historical_prices"])
is_single_asset = len(unique_tickers) < 2
# Run portfolio optimizations
logger.debug("Running portfolio optimizations")
if is_single_asset:
# Single-asset fallback: cannot optimise, show 100% allocation
single_ticker = list(unique_tickers)[0]
logger.info(f"Single-asset portfolio detected ({single_ticker}) - skipping optimization, showing 100% allocation")
# Create fallback optimization results with 100% allocation
fallback_weights = {single_ticker: 1.0}
fallback_result = {
"weights": fallback_weights,
"expected_return": 0.0,
"volatility": 0.0,
"sharpe": 0.0,
}
hrp_result = fallback_result
bl_result = fallback_result
mv_result = {
**fallback_result,
"message": "Portfolio optimization requires minimum 2 assets. Showing current 100% allocation."
}
else:
# Multiple assets: proceed with normal optimization
# HRP
hrp_result = await self.mcp_router.call_portfolio_optimizer_mcp(
"optimize_hrp",
{
"market_data": market_data_list,
"method": "hrp",
"risk_tolerance": state["risk_tolerance"],
}
)
# Black-Litterman
bl_result = await self.mcp_router.call_portfolio_optimizer_mcp(
"optimize_black_litterman",
{
"market_data": market_data_list,
"method": "black_litterman",
"risk_tolerance": state["risk_tolerance"],
}
)
# Mean-Variance
mv_result = await self.mcp_router.call_portfolio_optimizer_mcp(
"optimize_mean_variance",
{
"market_data": market_data_list,
"method": "mean_variance",
"risk_tolerance": state["risk_tolerance"],
}
)
# Run risk analysis
logger.debug("Running risk analysis")
portfolio_input = []
for holding in state["holdings"]:
ticker = holding["ticker"]
historical_prices = state["historical_prices"].get(ticker, {}).get("close_prices", [])
portfolio_input.append({
"ticker": ticker,
"weight": holding.get("weight", 0),
"prices": historical_prices,
})
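            # 95% confidence VaR/CVaR estimated from a 10,000-path Monte Carlo run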
risk_result = await self.mcp_router.call_risk_analyzer_mcp(
"analyze_risk",
{
"portfolio": portfolio_input,
"portfolio_value": sum(h.get("market_value", 0) for h in state["holdings"]),
"confidence_level": 0.95,
"method": "monte_carlo",
"num_simulations": 10000,
}
)
# Update state
state["optimisation_results"] = {
"hrp": hrp_result,
"black_litterman": bl_result,
"mean_variance": mv_result,
}
state["risk_analysis"] = risk_result
state["current_step"] = "phase_2_complete"
# Log MCP calls
state["mcp_calls"].extend([
MCPCall.model_validate({"mcp": "portfolio_optimizer_mcp", "tool": "optimize_hrp"}).model_dump(),
MCPCall.model_validate({"mcp": "portfolio_optimizer_mcp", "tool": "optimize_black_litterman"}).model_dump(),
MCPCall.model_validate({"mcp": "portfolio_optimizer_mcp", "tool": "optimize_mean_variance"}).model_dump(),
MCPCall.model_validate({"mcp": "risk_analyzer_mcp", "tool": "analyze_risk"}).model_dump(),
])
# Track phase duration
phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
state["phase_2_duration_ms"] = phase_duration_ms
logger.info(f"PHASE 2 COMPLETE: Optimizations and risk analysis done ({phase_duration_ms}ms)")
except Exception as e:
logger.error(f"Error in Phase 2: {e}")
state["errors"].append(f"Phase 2 error: {str(e)}")
return state
async def _phase_2_5_ml_predictions(self, state: AgentState) -> AgentState:
"""Phase 2.5: Generate ML-based price forecasts using Ensemble Predictor.
MCP called:
- Ensemble Predictor: Chronos + statistical models for price forecasting
"""
logger.info("PHASE 2.5: Generating ML predictions")
phase_start = time.perf_counter()
try:
# Generate forecasts for each holding
logger.debug("Running ensemble forecasts for portfolio holdings")
ensemble_forecasts = {}
for holding in state["holdings"]:
ticker = holding["ticker"]
# Get historical prices from Phase 1 data
hist_data = state["historical_prices"].get(ticker, {})
prices = hist_data.get("close_prices", [])
if not prices or len(prices) < 10:
logger.warning(f"Insufficient price data for {ticker}, skipping forecast")
continue
try:
# Call ensemble predictor
forecast_result = await self.mcp_router.call_ensemble_predictor_mcp(
"forecast_ensemble",
{
"ticker": ticker,
"prices": prices,
"forecast_horizon": 30, # 30-day forecast
"confidence_level": 0.95,
"use_returns": True, # Forecast returns for stability
"ensemble_method": "mean", # Simple averaging
}
)
ensemble_forecasts[ticker] = forecast_result
logger.debug(f"Generated forecast for {ticker} using {len(forecast_result.get('models_used', []))} models")
except Exception as e:
logger.warning(f"Forecast failed for {ticker}: {e}")
continue
# Update state
state["ensemble_forecasts"] = ensemble_forecasts
state["current_step"] = "phase_2_5_complete"
# Log MCP calls
state["mcp_calls"].extend([
MCPCall.model_validate({
"mcp": "ensemble_predictor",
"tool": "forecast_ensemble"
}).model_dump(),
])
# Track phase duration
phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
state["phase_2_5_duration_ms"] = phase_duration_ms
logger.info(
f"PHASE 2.5 COMPLETE: Generated forecasts for {len(ensemble_forecasts)} assets ({phase_duration_ms}ms)"
)
except Exception as e:
logger.error(f"Error in Phase 2.5: {e}")
state["errors"].append(f"Phase 2.5 error: {str(e)}")
# Set empty forecasts to allow workflow to continue
state["ensemble_forecasts"] = {}
return state
async def _phase_3_synthesis(self, state: AgentState) -> AgentState:
"""Phase 3: LLM synthesis of all data into actionable insights."""
logger.info("PHASE 3: LLM Synthesis")
phase_start = time.perf_counter()
try:
# Prepare data for analyst agent
portfolio_data = {
"holdings": state["holdings"],
"portfolio_id": state.get("portfolio_id", "unknown"),
"risk_tolerance": state["risk_tolerance"],
}
# Call analyst agent (returns AgentResult with usage metrics)
result = await self.analyst_agent.analyze_portfolio(
portfolio_data=portfolio_data,
market_data=state.get("realtime_data", {}),
fundamentals=state.get("fundamentals", {}),
technical_indicators=state.get("technical_indicators", {}),
economic_data=state.get("economic_data", {}),
optimization_results=state.get("optimisation_results", {}),
risk_analysis=state.get("risk_analysis", {}),
ensemble_forecasts=state.get("ensemble_forecasts", {}),
sentiment_data=state.get("sentiment_data", {}),
risk_tolerance=state["risk_tolerance"],
)
# Extract analysis output and usage metrics
analysis = result.output
# Update state with analysis results
state["ai_synthesis"] = analysis.summary
state["recommendations"] = analysis.recommendations
state["reasoning_steps"].extend(analysis.reasoning)
state["current_step"] = "complete"
# Track LLM usage metrics
state["llm_input_tokens"] = result.input_tokens
state["llm_output_tokens"] = result.output_tokens
state["llm_total_tokens"] = result.total_tokens
state["llm_request_count"] = result.request_count
# Track phase duration
phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
state["phase_3_duration_ms"] = phase_duration_ms
logger.info(
f"PHASE 3 COMPLETE: Analysis generated (health score: {analysis.health_score}, "
f"{result.total_tokens} tokens, {phase_duration_ms}ms)"
)
except Exception as e:
logger.error(f"Error in Phase 3: {e}")
state["errors"].append(f"Phase 3 error: {str(e)}")
return state
async def run(self, initial_state: AgentState) -> AgentState:
"""Run the complete workflow.
Args:
initial_state: Initial state with portfolio and query
Returns:
Final state with complete analysis
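
        Example (a sketch; assumes a populated AgentState and a running event loop):
            final_state = await workflow.run(initial_state)
            print(final_state["ai_synthesis"])
            for rec in final_state["recommendations"]:
                print(rec)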
"""
logger.info(f"Starting portfolio analysis workflow for {len(initial_state['holdings'])} holdings")
result = await self.workflow.ainvoke(
initial_state,
config=RunnableConfig(configurable={}, recursion_limit=25)
)
logger.info("Workflow complete")
return result
async def stream(self, initial_state: AgentState):
"""Stream workflow execution with progress updates.
Yields progress events as each node completes, allowing real-time
updates in the UI.
Args:
initial_state: Initial state with portfolio and query
Yields:
Dict with 'event' (node name), 'progress' (0-1), 'message' (description)
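
        Example (a sketch; a UI would typically forward these as SSE events):
            async for update in workflow.stream(initial_state):
                print(f"{update['progress']:.0%} {update['message']}")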
"""
logger.info(f"Streaming portfolio analysis for {len(initial_state['holdings'])} holdings")
phase_info = {
"phase_1_data_layer": {
"progress": 0.2,
"message": "Phase 1: Fetching market data from Yahoo Finance, FMP, FRED..."
},
"phase_1_5_feature_engineering": {
"progress": 0.4,
"message": "Phase 1.5: Extracting technical features and indicators..."
},
"phase_2_computation": {
"progress": 0.4,
"message": "Phase 2: Running portfolio optimisation and risk analysis..."
},
"phase_2_5_ml_predictions": {
"progress": 0.7,
"message": "Phase 2.5: Generating ensemble ML predictions..."
},
"phase_3_synthesis": {
"progress": 0.9,
"message": "Phase 3: Synthesising AI insights with Claude..."
}
}
final_state = None
config = RunnableConfig(configurable={}, recursion_limit=25)
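        # astream_events v2 emits an "on_chain_end" event as each graph node
        # finishes; the node's returned state is carried in event["data"]["output"].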
async for event in self.workflow.astream_events(initial_state, version="v2", config=config):
if event["event"] == "on_chain_end":
node_name = event.get("name", "")
if node_name in phase_info:
info = phase_info[node_name]
yield {
"event": node_name,
"progress": info["progress"],
"message": info["message"],
"state": event.get("data", {}).get("output")
}
final_state = event.get("data", {}).get("output")
yield {
"event": "complete",
"progress": 1.0,
"message": "Analysis complete!",
"state": final_state
}