"""LangGraph agent state models.

This module defines the state schema for multi-agent orchestration using
LangGraph, plus pydantic models for MCP call records, agent messages and
workflow status.
"""

from typing import TypedDict, Annotated, Sequence, Dict, Any, Optional, List
from datetime import datetime, timezone
import operator

from pydantic import BaseModel, Field


def merge_dicts(
    current: Optional[Dict[str, Any]],
    new: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    """Merge two dictionaries for LangGraph state updates.

    The merge is shallow: top-level keys from ``new`` overwrite the same keys
    in ``current``; nested values are replaced wholesale, not merged.

    Args:
        current: Current dictionary value; may be None.
        new: New dictionary to merge in; may be None.

    Returns:
        Merged dictionary with new values overwriting current. If ``current``
        is None, ``new`` is returned (or an empty dict when ``new`` is also
        None/empty); if only ``new`` is None, ``current`` is returned as-is.
    """
    if current is None:
        return new or {}
    if new is None:
        return current
    return {**current, **new}


def last_value(current: Any, new: Any) -> Any:
    """Reducer that keeps the last non-None value.

    Args:
        current: Current value.
        new: New value.

    Returns:
        ``new`` if it is not None, otherwise ``current``.
    """
    return new if new is not None else current


def sum_or_last(current: Optional[int], new: Optional[int]) -> Optional[int]:
    """Reducer for counters that sums values when both are present.

    Args:
        current: Current value, or None if not yet set.
        new: New value, or None if this update carries no value.

    Returns:
        ``current + new`` if both are set, otherwise whichever one is not
        None (None if both are None).
    """
    if current is None:
        return new
    if new is None:
        return current
    return current + new


class AgentState(TypedDict):
    """LangGraph state for multi-agent portfolio analysis workflow.

    This state is passed between agents in the workflow and accumulates
    results from each phase of the analysis. Every field is ``Annotated`` with
    a reducer that decides how concurrent/successive updates combine:

    * ``last_value``  - keep the most recent non-None value.
    * ``merge_dicts`` - shallow-merge dict updates key by key.
    * ``operator.add`` - concatenate list updates.
    * ``sum_or_last`` - add numeric updates together.
    """

    # Input (use last_value reducer for parallel safety)
    portfolio_id: Annotated[str, last_value]
    user_query: Annotated[str, last_value]
    risk_tolerance: Annotated[str, last_value]
    holdings: Annotated[List[Dict[str, Any]], last_value]

    # Phase 1: Data Layer Results (from MCPs)
    historical_prices: Annotated[Dict[str, Any], merge_dicts]
    fundamentals: Annotated[Dict[str, Any], merge_dicts]
    economic_data: Annotated[Dict[str, Any], merge_dicts]
    realtime_data: Annotated[Dict[str, Any], merge_dicts]
    technical_indicators: Annotated[Dict[str, Any], merge_dicts]
    sentiment_data: Annotated[Dict[str, Any], merge_dicts]  # Enhancement #3: News Sentiment MCP

    # Phase 1.5: Feature Engineering
    feature_vectors: Annotated[Dict[str, Any], merge_dicts]

    # Phase 2: Computation Layer Results
    optimisation_results: Annotated[Dict[str, Any], merge_dicts]
    risk_analysis: Annotated[Dict[str, Any], merge_dicts]

    # Phase 2.5: ML Predictions (P1)
    ensemble_forecasts: Annotated[Dict[str, Any], merge_dicts]

    # Phase 3: LLM Synthesis
    ai_synthesis: Annotated[str, last_value]
    recommendations: Annotated[List[str], last_value]
    reasoning_steps: Annotated[List[str], operator.add]

    # Metadata (lists are concatenated across updates)
    current_step: Annotated[str, last_value]
    errors: Annotated[List[str], operator.add]
    mcp_calls: Annotated[List[Dict[str, Any]], operator.add]

    # Performance Metrics
    # Phase durations keep the most recent non-None value (last_value); they
    # are NOT summed across parallel branches.
    phase_1_duration_ms: Annotated[Optional[int], last_value]
    phase_1_5_duration_ms: Annotated[Optional[int], last_value]
    phase_2_duration_ms: Annotated[Optional[int], last_value]
    phase_2_5_duration_ms: Annotated[Optional[int], last_value]
    phase_3_duration_ms: Annotated[Optional[int], last_value]
    # LLM usage counters are summed across updates/parallel branches.
    llm_input_tokens: Annotated[Optional[int], sum_or_last]
    llm_output_tokens: Annotated[Optional[int], sum_or_last]
    llm_total_tokens: Annotated[Optional[int], sum_or_last]
    llm_request_count: Annotated[Optional[int], sum_or_last]


class MCPCall(BaseModel):
    """Record of an MCP tool call.

    Accepts both 'mcp_server' and 'mcp' (and both 'tool_name' and 'tool')
    field names for backward compatibility: the short names are validation
    aliases, and ``populate_by_name`` also allows the attribute names.
    """

    model_config = {"populate_by_name": True}

    mcp_server: str = Field(..., validation_alias="mcp", description="MCP server name")
    tool_name: str = Field(..., validation_alias="tool", description="Tool called")
    parameters: Dict[str, Any] = Field(default_factory=dict)
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    duration_ms: Optional[int] = Field(None, ge=0)
    # Timezone-aware UTC creation time.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class AgentMessage(BaseModel):
    """Message from an agent with metadata."""

    role: str = Field(..., description="Agent role (user/assistant/system)")
    content: str = Field(..., min_length=1)
    agent_name: Optional[str] = Field(None, description="Name of agent that generated message")
    thinking: Optional[str] = Field(None, description="Agent reasoning")
    tools_used: Optional[List[MCPCall]] = Field(default_factory=list)
    # Constrained to the closed interval [0.0, 1.0] when provided.
    confidence: Optional[float] = Field(None, ge=0.0, le=1.0)
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class WorkflowStatus(BaseModel):
    """Status of the multi-agent workflow."""

    session_id: str
    current_phase: str = Field(..., description="Current execution phase")
    phase_1_complete: bool = Field(default=False)
    phase_2_complete: bool = Field(default=False)
    phase_3_complete: bool = Field(default=False)
    errors: List[str] = Field(default_factory=list)
    mcp_calls: List[MCPCall] = Field(default_factory=list)
    started_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    completed_at: Optional[datetime] = None
    execution_time_ms: Optional[int] = None