"""
Router Agent: Decides how to improve an insufficient plan.

When the verifier determines the plan is insufficient, the router decides:
- "Add Step": Add a new step to the plan
- "Step N": Backtrack to step N and fix it
"""

import re

from langchain_core.messages import AIMessage

from ..utils.formatters import format_data_descriptions, format_plan, gemini_text
from ..utils.state import DSStarState

# "Step N" / "step 2" / "Step2" -> captures N. Shared by the router (to parse
# the LLM reply) and the backtrack node (to parse router_decision).
_STEP_PATTERN = re.compile(r"step\s*(\d+)")
_ADD_STEP_PATTERN = re.compile(r"add\s+step")


def _parse_router_decision(response_text: str, num_steps: int) -> tuple[str, str]:
    """
    Map a raw router reply to ``(router_decision, next_node)``.

    "Add Step" anywhere in the reply wins. Otherwise the first "Step N"
    (case-insensitive, optional whitespace) selects a backtrack, but only
    when 1 <= N <= num_steps. Unparseable replies and out-of-range step
    numbers fall back to adding a new step.

    Args:
        response_text: Raw text returned by the LLM
        num_steps: Number of steps in the current plan

    Returns:
        ("Add Step", "planner") or ("Step N", "backtrack")
    """
    text = response_text.strip().lower()
    if _ADD_STEP_PATTERN.search(text):
        return "Add Step", "planner"

    match = _STEP_PATTERN.search(text)
    if match is None:
        # Default to adding new step
        return "Add Step", "planner"

    step_num = int(match.group(1))
    if not 1 <= step_num <= num_steps:
        # "Step 0" would wipe the whole plan, and a step past the end would
        # remove nothing; neither is a meaningful backtrack.
        return "Add Step", "planner"
    return f"Step {step_num}", "backtrack"


def router_node(state: DSStarState) -> dict:
    """
    Router Agent Node: Decides how to improve the plan.

    Analyzes the current situation and determines whether to:
    1. Add a new step to the plan
    2. Backtrack and fix an existing step

    A "Step N" answer is only honoured when 1 <= N <= len(state["plan"]).
    Anything else falls back to "Add Step", including unparseable replies
    and exceptions raised while calling or parsing the LLM.

    Args:
        state: Current DSStarState

    Returns:
        Dictionary with updated state fields:
        - router_decision: "Add Step" or "Step N"
        - iteration: Incremented iteration count
        - messages: Agent communication messages
        - next: "planner" (add step) or "backtrack" (fix step)
    """
    print("=" * 60)
    print("ROUTER AGENT STARTING...")
    print("=" * 60)

    data_context = format_data_descriptions(state["data_descriptions"])
    plan_text = format_plan(state["plan"])
    query = state["query"]
    # Prompt construction happens outside the try below, so guard against a
    # missing/None execution result instead of crashing the node.
    execution_result = str(state.get("execution_result") or "")

    prompt = f"""You are an expert data analyst router.

The current plan is INSUFFICIENT to answer the question.

Original Question: {query}

Available Data:
{data_context}

Current Plan:
{plan_text}

Execution Result:
{execution_result[:500]}

Task: Decide how to improve the plan:
1. If a current step is WRONG or needs fixing: Answer "Step N" (where N is the step number, e.g., "Step 2")
2. If we need to ADD a NEW step: Answer "Add Step"

Answer with ONLY: "Step 1", "Step 2", etc. OR "Add Step"
No explanation needed."""

    try:
        # Get LLM response
        response = state["llm"].invoke(prompt)

        # Handle different response formats: list-valued content is
        # flattened by gemini_text, plain content is used as-is, and
        # anything without .content is stringified.
        if hasattr(response, "content") and isinstance(response.content, list):
            response_text = gemini_text(response)
        elif hasattr(response, "content"):
            response_text = response.content
        else:
            response_text = str(response)

        decision, next_node = _parse_router_decision(
            str(response_text), len(state["plan"])
        )

        print(f"\nRouter Decision: {decision}")
        print(
            f"Next Action: {'Backtrack' if next_node == 'backtrack' else 'Add New Step'}"
        )
        print("=" * 60)

        return {
            "router_decision": decision,
            "messages": [AIMessage(content=f"Router: {decision}")],
            "iteration": state["iteration"] + 1,
            "next": next_node,
        }

    except Exception as e:
        # On error, default to adding new step
        print(f"\n✗ Router error: {str(e)}")
        print("Defaulting to 'Add Step'")

        return {
            "router_decision": "Add Step",
            "messages": [AIMessage(content=f"Router error, adding step: {str(e)}")],
            "iteration": state["iteration"] + 1,
            "next": "planner",
        }


def backtrack_node(state: DSStarState) -> dict:
    """
    Backtrack Node: Truncates plan to remove incorrect steps.

    When router identifies a wrong step, this node:
    1. Parses the step number from router_decision
    2. Truncates the plan to remove that step and all subsequent steps
    3. Routes back to planner to regenerate from that point

    If the step number cannot be parsed, or is outside 1..len(plan), the
    plan is left untouched. The node still routes to the planner, so a new
    step is added instead.

    Args:
        state: Current DSStarState

    Returns:
        Dictionary with updated state fields:
        - plan: Truncated plan (only present when a valid step was parsed)
        - messages: Agent communication messages
        - next: "planner" to regenerate from truncation point
    """
    print("=" * 60)
    print("BACKTRACK NODE ACTIVATING...")
    print("=" * 60)

    try:
        # Extract step number from router decision
        match = _STEP_PATTERN.search(state["router_decision"].lower())
        if not match:
            # If parsing fails, just add new step
            print("Failed to parse step number, adding new step instead")
            return {
                "messages": [
                    AIMessage(content="Backtrack parsing failed, adding new step")
                ],
                "next": "planner",
            }
        step_num = int(match.group(1))

        plan = state["plan"]
        if not 1 <= step_num <= len(plan):
            # Step 0 would wipe the whole plan, and a step past the end would
            # remove nothing, so leave the plan alone and add a step instead.
            print(
                f"Step {step_num} is out of range for a {len(plan)}-step plan, "
                "adding new step instead"
            )
            return {
                "messages": [
                    AIMessage(
                        content=f"Backtrack step {step_num} out of range, adding new step"
                    )
                ],
                "next": "planner",
            }

        # Keep steps 1..(step_num - 1) in human counting, i.e. list indices
        # 0..(step_num - 2). For step_num == 1 this yields an empty plan.
        truncated_plan = plan[: step_num - 1]

        print(f"Truncating plan from {len(plan)} to {len(truncated_plan)} steps")
        print(f"Removed step {step_num} and beyond")
        print("=" * 60)

        # Return the truncated plan. The original author intends this to
        # replace the entire plan rather than append to it. NOTE(review):
        # that depends on the DSStarState "plan" reducer; confirm.
        return {
            "plan": truncated_plan,
            "messages": [AIMessage(content=f"Backtracked to step {step_num - 1}")],
            "next": "planner",
        }

    except Exception as e:
        print(f"✗ Backtrack error: {str(e)}")
        return {
            "messages": [AIMessage(content=f"Backtrack error: {str(e)}, continuing")],
            "next": "planner",
        }


# Standalone test function
def test_router(
    llm, query: str, data_descriptions: dict, plan: list, execution_result: str
):
    """
    Test the router agent independently.

    Args:
        llm: LLM instance
        query: User query
        data_descriptions: Dict of filename -> description
        plan: Current plan steps
        execution_result: Result from code execution

    Returns:
        Dictionary with router results
    """
    # Create minimal test state
    test_state = {
        "llm": llm,
        "query": query,
        "data_descriptions": data_descriptions,
        "plan": plan,
        "current_code": "",
        "execution_result": execution_result,
        "is_sufficient": False,
        "router_decision": "",
        "iteration": 0,
        "max_iterations": 20,
        "messages": [],
        "next": "router",
    }

    result = router_node(test_state)

    print("\n" + "=" * 60)
    print("ROUTER TEST RESULTS")
    print("=" * 60)
    print(f"Decision: {result.get('router_decision', 'unknown')}")
    print(f"Next Node: {result.get('next', 'unknown')}")

    return result