BrianIsaac's picture
fix: pass config through tool_node wrapper to ToolNode.ainvoke
54a86b4
"""ReAct agent for dynamic portfolio building.
Uses LangGraph for tool calling and dynamic routing based on user goals.
"""
from typing import Annotated, Literal, Any
from typing_extensions import TypedDict
import logging
import re
import traceback
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langchain_core.messages import ToolMessage, HumanMessage
from langchain_core.tools import tool
from langchain_core.runnables import RunnableConfig
from langchain_anthropic import ChatAnthropic
from backend.config import settings
logger = logging.getLogger(__name__)
class ReactAgentState(TypedDict):
    """State for ReAct agent."""
    # Conversation history; add_messages is attached as the reducer for this key.
    messages: Annotated[list, add_messages]
    # Investment goals supplied by the caller; rendered into the system prompt.
    user_goals: list[str]
    # Risk tolerance; the system prompt presents it as a score out of 10.
    risk_tolerance: int
    # Free-text constraints; rendered into the system prompt.
    constraints: str
    # Seeded with an empty list by run()/run_stream(); not read by the code visible here.
    selected_tools: list[str]
    # Seeded with an empty dict by run()/run_stream(); not read by the code visible here.
    portfolio_result: dict
    iteration_count: int  # Track iterations to prevent unbounded loops
class PortfolioBuilderAgent:
"""ReAct agent for building portfolios based on user goals.
Uses LangGraph with tool calling for dynamic tool selection based on
user-specified investment goals, risk tolerance, and constraints.
"""
def __init__(self, mcp_router):
    """Set up the chat model, the tool set, and the compiled workflow.

    Args:
        mcp_router: MCP router instance the tools use to call MCP servers
    """
    self.mcp_router = mcp_router

    # Completions are capped at 1024 tokens to limit token usage.
    model_options = {
        "model": settings.anthropic_model,
        "api_key": settings.anthropic_api_key,
        "max_tokens": 1024,
    }
    self.model = ChatAnthropic(**model_options)

    registered = self._register_tools()
    self.tools = registered

    # Name -> tool lookup table.
    lookup = {}
    for registered_tool in registered:
        lookup[registered_tool.name] = registered_tool
    self.tools_by_name = lookup

    # Built last: the workflow needs both the model and the tools.
    self.workflow = self._build_workflow()
def _register_tools(self) -> list:
    """Register available tools for portfolio building.

    Every tool is an async closure over the MCP router. Each one catches
    its own exceptions and reports them as ``{"error": "<message>"}`` so a
    failing MCP call is returned to the model instead of aborting the run.

    The ``@tool`` decorator uses each function's docstring as the tool
    description given to the model, so those docstrings are runtime text.

    Returns:
        List of the seven tool objects, in registration order
    """
    from decimal import Decimal  # imported once; shared by the tools below

    mcp_router = self.mcp_router

    @tool
    async def fetch_market_data(tickers: list[str]) -> dict:
        """Fetch current market data for given tickers.

        Args:
            tickers: List of stock ticker symbols to fetch data for

        Returns:
            Dictionary mapping tickers to their market data
        """
        results = {}
        # One request per ticker so a single failure does not lose the rest.
        for ticker in tickers:
            try:
                results[ticker] = await mcp_router.call_yahoo_finance_mcp(
                    "get_quote", {"tickers": [ticker]}
                )
            except Exception as e:
                logger.error(f"Error fetching market data for {ticker}: {e}")
                results[ticker] = {"error": str(e)}
        return results

    @tool
    async def get_fundamentals(ticker: str) -> dict:
        """Get fundamental analysis for a ticker.

        Args:
            ticker: Stock ticker symbol

        Returns:
            Fundamental data including financials, ratios, and metrics
        """
        try:
            return await mcp_router.call_fmp_mcp(
                "get_company_profile", {"ticker": ticker}
            )
        except Exception as e:
            logger.error(f"Error getting fundamentals for {ticker}: {e}")
            return {"error": str(e)}

    @tool
    async def get_historical_prices(ticker: str, period: str = "3mo") -> dict:
        """Get historical price data summary for a ticker.

        Args:
            ticker: Stock ticker symbol
            period: Time period (e.g., '1mo', '3mo', '6mo')

        Returns:
            Summary of historical price data with key statistics
        """
        try:
            data = await mcp_router.call_yahoo_finance_mcp(
                "get_historical_data", {"ticker": ticker, "period": period}
            )
            if not data:
                # Honour the dict contract instead of handing back None/empty.
                return {"error": "No price data available"}
            if "error" in data:
                return data

            prices = data.get("close_prices", [])
            if not prices:
                return {"error": "No price data available"}

            first_price, latest_price = prices[0], prices[-1]
            # A zero opening price makes the percentage undefined; report
            # None for that one field rather than failing the whole summary.
            if first_price:
                price_change_pct = round(
                    ((latest_price - first_price) / first_price) * 100, 2
                )
            else:
                price_change_pct = None

            return {
                "ticker": ticker,
                "period": period,
                "data_points": len(prices),
                "latest_price": latest_price,
                "first_price": first_price,
                "price_change_pct": price_change_pct,
                "high": max(prices),
                "low": min(prices),
                "avg": round(sum(prices) / len(prices), 2),
            }
        except Exception as e:
            logger.error(f"Error getting historical prices for {ticker}: {e}")
            return {"error": str(e)}

    @tool
    async def calculate_technicals(ticker: str, prices: list[float]) -> dict:
        """Calculate technical indicators for a ticker.

        Args:
            ticker: Stock ticker symbol
            prices: List of historical closing prices

        Returns:
            Technical indicators including RSI, MACD, Bollinger Bands, etc.
        """
        try:
            return await mcp_router.call_feature_extraction_mcp(
                "extract_technical_features",
                {
                    "ticker": ticker,
                    "prices": prices,
                    "include_momentum": True,
                    "include_volatility": True,
                    "include_trend": True,
                },
            )
        except Exception as e:
            logger.error(f"Error calculating technicals for {ticker}: {e}")
            return {"error": str(e)}

    @tool
    async def optimise_allocation(
        tickers: list[str],
        prices_data: dict[str, list[float]],
        dates: list[str]
    ) -> dict:
        """Optimise portfolio allocation using Hierarchical Risk Parity.

        Args:
            tickers: List of ticker symbols
            prices_data: Dictionary mapping tickers to their historical prices
            dates: List of date strings corresponding to the prices

        Returns:
            Optimised weights for each ticker
        """
        try:
            # Tickers without price data are silently dropped.
            market_data = [
                {
                    "ticker": ticker,
                    "prices": [Decimal(str(p)) for p in prices_data[ticker]],
                    "dates": dates,
                }
                for ticker in tickers
                if ticker in prices_data
            ]
            if len(market_data) < 2:
                return {"error": "Need at least 2 tickers with price data"}

            return await mcp_router.call_portfolio_optimizer_mcp(
                "optimize_hrp", {
                    "market_data": market_data,
                    "method": "hrp",
                    "risk_tolerance": "moderate",
                }
            )
        except Exception as e:
            logger.error(f"Error optimising allocation: {e}")
            return {"error": str(e)}

    @tool
    async def assess_risk(
        weights: dict[str, float],
        historical_prices: dict[str, list[float]],
        portfolio_value: float = 10000.0
    ) -> dict:
        """Assess portfolio risk metrics.

        Args:
            weights: Dictionary mapping tickers to allocation weights (0-1)
            historical_prices: Dictionary mapping tickers to their historical prices
            portfolio_value: Total portfolio value in dollars

        Returns:
            Risk metrics including VaR, CVaR, volatility, etc.
        """
        try:
            # Only holdings that have price history are included.
            portfolio = [
                {
                    "ticker": ticker,
                    "weight": Decimal(str(weight)),
                    "prices": [Decimal(str(p)) for p in historical_prices[ticker]],
                }
                for ticker, weight in weights.items()
                if ticker in historical_prices
            ]
            if not portfolio:
                return {"error": "No valid portfolio data to analyse"}

            return await mcp_router.call_risk_analyzer_mcp(
                "analyze_risk", {
                    "portfolio": portfolio,
                    "portfolio_value": Decimal(str(portfolio_value)),
                }
            )
        except Exception as e:
            logger.error(f"Error assessing risk: {e}")
            return {"error": str(e)}

    @tool
    async def get_news_sentiment(ticker: str) -> dict:
        """Get news sentiment analysis for a ticker.

        Args:
            ticker: Stock ticker symbol

        Returns:
            Sentiment scores and recent news analysis
        """
        try:
            return await mcp_router.call_news_sentiment_mcp(
                "get_news_with_sentiment", {"ticker": ticker, "days_back": 7}
            )
        except Exception as e:
            logger.error(f"Error getting news sentiment for {ticker}: {e}")
            return {"error": str(e)}

    return [
        fetch_market_data,
        get_fundamentals,
        get_historical_prices,
        calculate_technicals,
        optimise_allocation,
        assess_risk,
        get_news_sentiment,
    ]
def _build_workflow(self):
    """Build ReAct workflow with tool calling.

    The graph has two nodes: ``agent`` calls the model with the tools
    bound, and ``tools`` runs the requested tools and shrinks their output
    to reduce token usage. After every agent turn ``should_continue``
    routes to ``tools`` or to the end.

    Returns:
        The compiled graph
    """
    import json  # local import: only the tool-output summariser needs it

    builder = StateGraph(ReactAgentState)

    async def agent_node(state: ReactAgentState) -> dict:
        """Call model with available tools."""
        system_prompt = self._build_system_prompt(state)
        model_with_tools = self.model.bind_tools(self.tools)
        messages = [
            {"role": "system", "content": system_prompt}
        ] + state["messages"]
        response = await model_with_tools.ainvoke(messages)
        # One agent turn completed; should_continue reads this counter.
        iteration_count = state.get("iteration_count", 0) + 1
        return {"messages": [response], "iteration_count": iteration_count}

    def summarise_tool_output(tool_name, content):
        """Return a compact replacement for a tool's output, or None to keep it."""
        if tool_name in ("fetch_market_data", "get_news_sentiment"):
            try:
                data = json.loads(content)
                if isinstance(data, dict):
                    if tool_name == "fetch_market_data":
                        # Keep just price and change per ticker.
                        summary = {
                            k: {
                                "price": v.get("regularMarketPrice"),
                                "change": v.get("regularMarketChangePercent"),
                            }
                            for k, v in data.items()
                            if isinstance(v, dict)
                        }
                        return str(summary)
                    # Keep just the score and the first two headlines.
                    headlines = [
                        n.get('title', '')[:50] for n in data.get('news', [])[:2]
                    ]
                    return f"Score: {data.get('sentiment_score')}, News: {headlines}"
            except Exception:
                # Best effort: unparseable or oddly shaped payloads are
                # truncated. Not a bare except, so cancellation and
                # KeyboardInterrupt still propagate.
                return content[:200] + "..."
            # Valid JSON that is not an object falls through to the generic
            # truncation below instead of passing through at full size.
        if len(content) > 300:
            return content[:300] + f"... (truncated {len(content) - 300} chars)"
        return None

    # Create tool node with output truncation to reduce token usage
    base_tool_node = ToolNode(tools=self.tools)

    async def tool_node(state: ReactAgentState, config: RunnableConfig) -> dict:
        """Execute tools and truncate outputs to reduce token usage."""
        result = await base_tool_node.ainvoke(state, config)
        if "messages" in result:
            processed = []
            for msg in result["messages"]:
                if isinstance(msg, ToolMessage):
                    compact = summarise_tool_output(msg.name, str(msg.content))
                    if compact is not None:
                        msg.content = compact
                processed.append(msg)
            result["messages"] = processed
        return result

    def should_continue(state: ReactAgentState) -> Literal["tools", "end"]:
        """Determine if agent should continue or finish."""
        # Check iteration limit (prevent unbounded token growth)
        max_iterations = 15  # Increased from 8 to allow more complex portfolio building
        if state.get("iteration_count", 0) >= max_iterations:
            logger.warning(f"Max iterations ({max_iterations}) reached, ending workflow")
            return "end"
        last_message = state["messages"][-1]
        # No tool calls requested means the model has given its final answer.
        if not getattr(last_message, 'tool_calls', None):
            return "end"
        return "tools"

    builder.add_node("agent", agent_node)
    builder.add_node("tools", tool_node)
    builder.add_edge(START, "agent")
    builder.add_conditional_edges(
        "agent",
        should_continue,
        {"tools": "tools", "end": END}
    )
    builder.add_edge("tools", "agent")
    return builder.compile()
def _build_system_prompt(self, state: ReactAgentState) -> str:
"""Build system prompt based on user goals.
Args:
state: Current agent state with user preferences
Returns:
Formatted system prompt
"""
# Defensive type handling
user_goals = state.get("user_goals", [])
goals = ", ".join(user_goals) if isinstance(user_goals, list) and user_goals else "General investing"
risk = state["risk_tolerance"]
state_constraints = state.get("constraints")
constraints = (", ".join(state_constraints) if isinstance(state_constraints, list)
else str(state_constraints)) if state_constraints else "None specified"
return f"""Build portfolio. Goals: {goals}. Risk: {risk}/10. Constraints: {constraints}.
Tools: fetch_market_data, get_fundamentals, get_historical_prices, calculate_technicals, optimise_allocation, assess_risk, get_news_sentiment.
Process:
1. Get data for candidates
2. Filter by fundamentals/technicals
3. Optimise weights
4. Output portfolio
Output format:
- TICKER: X% - brief reason
Expected return: X%, Risk: Y%"""
async def run(
    self,
    goals: list[str],
    risk_tolerance: int,
    constraints: str
) -> dict:
    """Run the ReAct agent to build a portfolio.

    Args:
        goals: List of investment goals (e.g., ['Growth', 'Income'])
        risk_tolerance: Risk tolerance score from 1-10
        constraints: User-specified constraints as text

    Returns:
        Dictionary with keys ``portfolio`` (parsed holdings),
        ``reasoning_trace`` (tool calls and observations) and
        ``final_response`` (content of the last message)
    """
    logger.info(f"Building portfolio with goals={goals}, risk={risk_tolerance}")

    # Normalise input types to handle any caller mistakes
    goals = goals if isinstance(goals, list) else ([goals] if goals else [])
    constraints = ", ".join(constraints) if isinstance(constraints, list) else (constraints or "")

    initial_state: ReactAgentState = {
        "messages": [
            HumanMessage(content="Build me a portfolio based on my goals and constraints.")
        ],
        "user_goals": goals,
        "risk_tolerance": risk_tolerance,
        "constraints": constraints,
        "selected_tools": [],
        "portfolio_result": {},
        "iteration_count": 0
    }

    # The workflow ends itself after 15 agent turns. Every turn costs two
    # graph steps (agent node + tools node), so the recursion limit must
    # exceed 2 * 15. The previous limit of 25 was reached first and turned
    # a long run into a recursion error instead of a graceful stop.
    agent_turn_cap = 15
    result = await self.workflow.ainvoke(
        initial_state,
        config=RunnableConfig(configurable={}, recursion_limit=2 * agent_turn_cap + 5)
    )

    messages = result["messages"]
    return {
        "portfolio": self._extract_portfolio(messages),
        "reasoning_trace": self._extract_reasoning_trace(messages),
        "final_response": messages[-1].content if messages else ""
    }
async def run_stream(
    self,
    goals: list[str],
    risk_tolerance: int,
    constraints: str
):
    """Stream the ReAct agent execution with real-time updates.

    Yields ChatMessage-compatible dictionaries: first an introductory
    message, then one message per completed tool call, and finally the
    model's answer with the parsed portfolio in its metadata. Model
    "thinking" tokens are intentionally not shown.

    The final state is taken from the event stream itself, so the
    workflow runs once. It is re-invoked only as a fallback when the
    stream did not expose a final state.

    Args:
        goals: List of investment goals
        risk_tolerance: Risk tolerance score from 1-10
        constraints: User-specified constraints as text

    Yields:
        dict: ChatMessage-compatible dictionaries with role, content, and metadata

    Raises:
        Exception: any failure is logged, reported to the consumer as an
            error message, and then re-raised
    """
    import time

    try:
        logger.info("Streaming portfolio build started")
        logger.info(f"Input values - goals: {goals!r}, risk_tolerance: {risk_tolerance!r}, constraints: {constraints!r}")

        # Normalise input types to handle any caller mistakes
        goals = goals if isinstance(goals, list) else ([goals] if goals else [])
        constraints = ", ".join(constraints) if isinstance(constraints, list) else (constraints or "")

        initial_state: ReactAgentState = {
            "messages": [
                HumanMessage(content="Build me a portfolio based on my goals and constraints.")
            ],
            "user_goals": goals,
            "risk_tolerance": risk_tolerance,
            "constraints": constraints,
            "selected_tools": [],
            "portfolio_result": {},
            "iteration_count": 0
        }

        # The workflow ends itself after 15 agent turns and every turn
        # costs two graph steps, so the recursion limit must exceed 2 * 15
        # for that graceful stop to be reachable.
        agent_turn_cap = 15
        run_config = RunnableConfig(configurable={}, recursion_limit=2 * agent_turn_cap + 5)

        yield {
            "role": "assistant",
            "content": (
                "Starting portfolio construction based on your goals:\n"
                f"• {', '.join(str(goal) for goal in goals)}\n"
                f"• Risk tolerance: {risk_tolerance}/10\n"
                f"• Constraints: {constraints or 'None'}"
            ),
            "metadata": None
        }

        # Tool calls in flight, keyed by run id. Several tools requested in
        # one model turn can overlap, so a single slot would lose or
        # mismatch their results.
        pending_tools = {}
        final_result = None
        event_count = 0

        logger.info("Starting to stream workflow events")
        async for event in self.workflow.astream_events(
            initial_state, config=run_config, version="v2"
        ):
            event_count += 1
            kind = event.get("event")
            try:
                if kind == "on_tool_start":
                    tool_name = event["name"]
                    tool_input = event["data"].get("input", {})
                    logger.info(f"Tool started: {tool_name} with input: {tool_input!r}")
                    # Not yielded yet - only once the result is known.
                    pending_tools[event.get("run_id")] = (
                        {
                            "role": "assistant",
                            "content": f"```python\n{tool_name}({tool_input})\n```",
                            "metadata": {
                                "title": f"Tool: {tool_name}",
                                "status": "pending",
                                "log": f"Calling {tool_name}..."
                            }
                        },
                        time.monotonic(),
                    )

                elif kind == "on_tool_end":
                    logger.info("Tool ended")
                    entry = pending_tools.pop(event.get("run_id"), None)
                    if entry is None:
                        continue
                    tool_message, started_at = entry

                    tool_output = event["data"].get("output")
                    if not isinstance(tool_output, str):
                        tool_output = str(tool_output)
                    # Truncate long outputs
                    if len(tool_output) > 300:
                        tool_output = tool_output[:300] + "..."

                    tool_message["content"] += f"\n\n**Result:** {tool_output}"
                    tool_message["metadata"]["status"] = "done"
                    # Duration of this tool call, not time since stream start.
                    tool_message["metadata"]["duration"] = round(
                        time.monotonic() - started_at, 1
                    )
                    tool_message["metadata"]["log"] = "Completed"
                    yield tool_message

                elif kind == "on_chain_end" and not event.get("parent_ids"):
                    # The root run has no parents; its output is the final
                    # graph state.
                    output = event.get("data", {}).get("output")
                    if isinstance(output, dict) and "messages" in output:
                        final_result = output

                # on_chat_model_* events are ignored on purpose: thinking is
                # internal reasoning and only tool calls are displayed.
            except Exception:
                logger.exception(f"Error processing event #{event_count} (kind={kind!r})")
                raise

        logger.info(f"Completed streaming {event_count} events. Getting final result...")
        if final_result is None:
            # Fallback only: costs a second full run of the workflow.
            logger.warning("Final state missing from event stream; re-running workflow")
            final_result = await self.workflow.ainvoke(initial_state, config=run_config)
        logger.info(f"Got final result with {len(final_result.get('messages', []))} messages")

        messages = final_result["messages"]
        reasoning_trace = self._extract_reasoning_trace(messages)
        portfolio = self._extract_portfolio(messages)
        final_content = messages[-1].content if messages else ""
        logger.info(f"Extracted portfolio: {len(portfolio)} holdings")

        # Convert content to string if it's a list of content blocks
        if isinstance(final_content, list):
            text_parts = []
            for block in final_content:
                if isinstance(block, dict):
                    # Skip tool_use blocks - they're not displayable text
                    if block.get("type") == "text":
                        text_parts.append(block.get("text", ""))
                elif hasattr(block, 'text'):
                    text_parts.append(block.text)
            final_content = "\n\n".join(text_parts)
        elif not isinstance(final_content, str):
            final_content = str(final_content)

        yield {
            "role": "assistant",
            "content": final_content or "Portfolio construction complete.",
            "metadata": {
                "portfolio": portfolio,
                "reasoning_trace": reasoning_trace
            }
        }
    except Exception as e:
        logger.exception(f"Exception in run_stream ({type(e).__name__}): {e}")
        # Yield error message to user
        yield {
            "role": "assistant",
            "content": f"Error during portfolio construction: {str(e)}",
            "metadata": {"error": True, "error_type": type(e).__name__}
        }
        raise
def _extract_reasoning_trace(self, messages: list) -> list[dict]:
"""Extract reasoning trace from message history.
Args:
messages: List of messages from the workflow
Returns:
List of reasoning steps with thoughts, actions, and observations
"""
reasoning_trace = []
for msg in messages:
if hasattr(msg, 'tool_calls') and msg.tool_calls:
for tc in msg.tool_calls:
reasoning_trace.append({
"thought": msg.content if msg.content else "Executing tool",
"action": tc["name"],
"args": tc["args"]
})
elif isinstance(msg, ToolMessage):
content = msg.content
if isinstance(content, str) and len(content) > 500:
content = content[:500] + "..."
reasoning_trace.append({
"observation": content
})
return reasoning_trace
def _extract_portfolio(self, messages: list) -> list[dict]:
    """Extract portfolio recommendations from final message.

    Only the last message is parsed. Each line shaped like
    ``- TICKER: 25% - reason`` becomes one holding. Leading indentation,
    the bullets ``-``, ``*``, ``•``, ``○`` and ``◦``, markdown bold around
    the ticker, and ``-``/``–``/``—`` before the reason are all accepted.

    Args:
        messages: List of messages from the workflow

    Returns:
        List of portfolio holdings with ticker, allocation, and reasoning;
        empty when there are no messages or nothing matches
    """
    if not messages:
        return []

    content = getattr(messages[-1], "content", "")

    # Handle content that is a list of content blocks (from Anthropic API)
    if isinstance(content, list):
        text_parts = []
        for block in content:
            if isinstance(block, dict):
                # Skip tool_use blocks
                if block.get("type") == "text":
                    text_parts.append(block.get("text", ""))
            elif hasattr(block, 'text'):
                text_parts.append(block.text)
        content = "\n".join(text_parts)
        logger.debug(f"Extracted text from {len(text_parts)} content blocks")
    elif not isinstance(content, str):
        content = str(content)

    logger.debug(f"Content to parse (first 500 chars): {content[:500]}")

    # Handles crypto tickers (BTC-USD), stock tickers (AAPL), and ETFs.
    # The bullet class previously held mis-encoded characters, so real
    # bullet characters never matched; it now lists the intended ones.
    holding_pattern = re.compile(
        r'\s*[○◦•*\-]\s*'                       # optional indent + bullet
        r'\*{0,2}([A-Z0-9\-\.]{1,10})\*{0,2}'   # ticker, optionally bold
        r':\*{0,2}\s*(\d+(?:\.\d+)?)\s*%'       # allocation percentage
        r'\s*[-–—]\s*(.+)'                      # dash, then the reasoning
    )

    portfolio = []
    lines = content.split('\n')
    logger.debug(f"Split into {len(lines)} lines")
    for line in lines:
        match = holding_pattern.match(line)
        if match:
            logger.debug(f"Matched line: {line}")
            portfolio.append({
                "ticker": match.group(1),
                "allocation": float(match.group(2)),
                "reasoning": match.group(3).strip()
            })
        elif line.strip() and any(char.isupper() for char in line):
            # Log non-empty lines that contain uppercase (might be portfolio entries)
            logger.debug(f"No match for line: {line}")
    return portfolio