Spaces:
Running
Running
| """ | |
| Router Agent: Decides how to improve an insufficient plan. | |
| When the verifier determines the plan is insufficient, the router decides: | |
| - "Add Step": Add a new step to the plan | |
| - "Step N": Backtrack to step N and fix it | |
| """ | |
| import re | |
| from langchain_core.messages import AIMessage | |
| from ..utils.formatters import format_data_descriptions, format_plan, gemini_text | |
| from ..utils.state import DSStarState | |
# Matches the router's two legal answers. Alternation plus re.search means the
# EARLIEST token in the reply wins, so trailing chatter cannot override it.
# Group 1 is set only for "Step N".
_ROUTER_DECISION_RE = re.compile(r"\badd\s+step\b|\bstep\s*(\d+)")


def _extract_response_text(response) -> str:
    """Return the text of an LLM response, whatever shape the client returned."""
    if hasattr(response, "content") and isinstance(response.content, list):
        # List-shaped (multi-part) content is flattened by the project helper.
        return gemini_text(response)
    if hasattr(response, "content"):
        content = response.content
        # Coerce None/non-str content so the later .strip() cannot fail.
        return "" if content is None else str(content)
    return str(response)


def _parse_router_decision(response_text: str, num_steps: int) -> tuple[str, str]:
    """Map raw router output to a ``(decision, next_node)`` pair.

    Returns ``("Step N", "backtrack")`` only when the first recognised token is
    "Step N" with ``1 <= N <= num_steps``. Anything else falls back to
    ``("Add Step", "planner")``: an explicit "Add Step", unparseable text, or a
    step number that does not exist in the current plan.
    """
    match = _ROUTER_DECISION_RE.search(response_text.strip().lower())
    if match is not None and match.group(1) is not None:
        step_num = int(match.group(1))
        if 1 <= step_num <= num_steps:
            return f"Step {step_num}", "backtrack"
    return "Add Step", "planner"


def router_node(state: DSStarState) -> dict:
    """
    Router Agent Node: Decides how to improve the plan.

    Asks the LLM whether the insufficient plan needs a new step appended or an
    existing step fixed. The reply is parsed into either:
    1. "Add Step" -> routed to the planner
    2. "Step N"   -> routed to backtrack (only if step N exists in the plan)

    Args:
        state: Current DSStarState. Reads "data_descriptions", "plan",
            "query", "execution_result", "llm" and "iteration".

    Returns:
        Dictionary with updated state fields:
        - router_decision: "Add Step" or "Step N"
        - iteration: Incremented iteration count
        - messages: Agent communication messages
        - next: "planner" (add step) or "backtrack" (fix step)

    Any exception raised while querying the LLM or parsing its reply is caught,
    and the node defaults to "Add Step".
    """
    print("=" * 60)
    print("ROUTER AGENT STARTING...")
    print("=" * 60)

    data_context = format_data_descriptions(state["data_descriptions"])
    plan = state["plan"]
    plan_text = format_plan(plan)
    # execution_result may be None/absent; slicing it raw here (outside the
    # try/except) would crash the node.
    execution_result = str(state.get("execution_result") or "")

    prompt = f"""You are an expert data analyst router.
The current plan is INSUFFICIENT to answer the question.
Original Question: {state["query"]}
Available Data:
{data_context}
Current Plan:
{plan_text}
Execution Result:
{execution_result[:500]}
Task: Decide how to improve the plan:
1. If a current step is WRONG or needs fixing: Answer "Step N" (where N is the step number, e.g., "Step 2")
2. If we need to ADD a NEW step: Answer "Add Step"
Answer with ONLY: "Step 1", "Step 2", etc. OR "Add Step"
No explanation needed."""

    try:
        response = state["llm"].invoke(prompt)
        response_text = _extract_response_text(response)

        # Out-of-range step numbers are rejected here so backtrack never
        # receives a step it cannot meaningfully truncate to.
        decision, next_node = _parse_router_decision(
            response_text, len(plan) if plan else 0
        )

        print(f"\nRouter Decision: {decision}")
        print(
            f"Next Action: {'Backtrack' if next_node == 'backtrack' else 'Add New Step'}"
        )
        print("=" * 60)

        return {
            "router_decision": decision,
            "messages": [AIMessage(content=f"Router: {decision}")],
            "iteration": state["iteration"] + 1,
            "next": next_node,
        }
    except Exception as e:
        # Best-effort: an LLM/parse failure must not stall the loop, so extend
        # the plan instead.
        print(f"\n✗ Router error: {str(e)}")
        print("Defaulting to 'Add Step'")
        return {
            "router_decision": "Add Step",
            "messages": [AIMessage(content=f"Router error, adding step: {str(e)}")],
            "iteration": state["iteration"] + 1,
            "next": "planner",
        }
def backtrack_node(state: DSStarState) -> dict:
    """
    Backtrack Node: Truncates plan to remove incorrect steps.

    When the router identifies a wrong step, this node:
    1. Parses the step number N from router_decision (e.g. "Step 2" or "Step2")
    2. If 1 <= N <= len(plan), keeps steps 1..N-1 and drops step N onward
    3. Routes back to the planner to regenerate from that point

    If the decision cannot be parsed, or N is outside the current plan, the plan
    is left unchanged and the planner simply adds a step. Truncating at N = 0
    would discard every step, and N past the end would remove nothing.

    Args:
        state: Current DSStarState. Reads "router_decision" and "plan".

    Returns:
        Dictionary with updated state fields:
        - plan: Truncated plan (only present when a valid step was removed)
        - messages: Agent communication messages
        - next: "planner" to regenerate from truncation point
    """
    print("=" * 60)
    print("BACKTRACK NODE ACTIVATING...")
    print("=" * 60)
    try:
        # Tolerate a missing/None decision; "\s*" also accepts "Step2".
        decision = str(state.get("router_decision") or "")
        match = re.search(r"step\s*(\d+)", decision.lower())
        if not match:
            # If parsing fails, just add new step
            print("Failed to parse step number, adding new step instead")
            return {
                "messages": [
                    AIMessage(content="Backtrack parsing failed, adding new step")
                ],
                "next": "planner",
            }

        step_num = int(match.group(1))
        plan = state.get("plan") or []

        if not 1 <= step_num <= len(plan):
            # No step N exists to remove: leave the plan intact and extend it.
            print(
                f"Step {step_num} is outside the current plan "
                f"(1-{len(plan)}), adding new step instead"
            )
            return {
                "messages": [
                    AIMessage(
                        content=f"Backtrack step {step_num} out of range, adding new step"
                    )
                ],
                "next": "planner",
            }

        # Keep indices 0..step_num-2, i.e. human-numbered steps 1..step_num-1.
        truncated_plan = plan[: step_num - 1]
        print(
            f"Truncating plan from {len(plan)} to {len(truncated_plan)} steps"
        )
        print(f"Removed step {step_num} and beyond")
        print("=" * 60)

        # Return the truncated plan (replaces entire plan, not appends)
        return {
            "plan": truncated_plan,
            "messages": [AIMessage(content=f"Backtracked to step {step_num - 1}")],
            "next": "planner",
        }
    except Exception as e:
        # Best-effort: on any failure, fall through to the planner unchanged.
        print(f"✗ Backtrack error: {str(e)}")
        return {
            "messages": [AIMessage(content=f"Backtrack error: {str(e)}, continuing")],
            "next": "planner",
        }
# Standalone test function
def test_router(
    llm, query: str, data_descriptions: dict, plan: list, execution_result: str
):
    """
    Run router_node once on a hand-built state and print its verdict.

    Args:
        llm: LLM instance placed in the state under "llm"
        query: User query
        data_descriptions: Dict of filename -> description
        plan: Current plan steps
        execution_result: Result from code execution

    Returns:
        The dictionary returned by router_node.
    """
    # Caller-supplied fields plus neutral placeholders for a first, unverified
    # iteration.
    minimal_state = dict(
        llm=llm,
        query=query,
        data_descriptions=data_descriptions,
        plan=plan,
        current_code="",
        execution_result=execution_result,
        is_sufficient=False,
        router_decision="",
        iteration=0,
        max_iterations=20,
        messages=[],
        next="router",
    )
    outcome = router_node(minimal_state)

    divider = "=" * 60
    print("\n" + divider)
    print("ROUTER TEST RESULTS")
    print(divider)
    for label, key in (("Decision", "router_decision"), ("Next Node", "next")):
        print(f"{label}: {outcome.get(key, 'unknown')}")
    return outcome