DS-STAR / src /agents /router_agent.py
anurag-deo's picture
Upload folder using huggingface_hub
8ff817c verified
"""
Router Agent: Decides how to improve an insufficient plan.
When the verifier determines the plan is insufficient, the router decides:
- "Add Step": Add a new step to the plan
- "Step N": Backtrack to step N and fix it
"""
import re
from langchain_core.messages import AIMessage
from ..utils.formatters import format_data_descriptions, format_plan, gemini_text
from ..utils.state import DSStarState
def _parse_router_decision(response_text: str, num_steps: int) -> tuple:
    """
    Convert raw router LLM output into a (decision, next_node) pair.

    Args:
        response_text: Text returned by the LLM.
        num_steps: Number of steps in the current plan. A "Step N" answer is
            only honored when 1 <= N <= num_steps.

    Returns:
        ("Step N", "backtrack") for a valid step reference, otherwise
        ("Add Step", "planner").
    """
    response_lower = response_text.strip().lower()
    # An explicit "add step" wins: it is one of the two answers the prompt allows.
    if "add step" in response_lower:
        return "Add Step", "planner"
    # Accept both "step 2" and the tighter "step2".
    match = re.search(r"step\s*(\d+)", response_lower)
    if match:
        step_num = int(match.group(1))
        if 1 <= step_num <= num_steps:
            return f"Step {step_num}", "backtrack"
        # "Step 0" would make backtracking discard the entire plan, and a step
        # past the end would truncate nothing; neither is a real target.
        print(
            f"Router returned invalid step {step_num} "
            f"for a {num_steps}-step plan"
        )
    # Unparseable or out-of-range answer: fall back to extending the plan.
    return "Add Step", "planner"


def router_node(state: DSStarState) -> dict:
    """
    Router Agent Node: Decides how to improve the plan.

    Analyzes the current situation and determines whether to:
    1. Add a new step to the plan
    2. Backtrack and fix an existing step

    Any LLM answer that is not "Add Step" or a "Step N" with N inside the
    current plan (1..len(plan)) is treated as "Add Step". LLM or parsing
    errors are logged and also default to "Add Step".

    Args:
        state: Current DSStarState

    Returns:
        Dictionary with updated state fields:
        - router_decision: "Add Step" or "Step N"
        - iteration: Incremented iteration count
        - messages: Agent communication messages
        - next: "planner" (add step) or "backtrack" (fix step)
    """
    print("=" * 60)
    print("ROUTER AGENT STARTING...")
    print("=" * 60)
    data_context = format_data_descriptions(state["data_descriptions"])
    plan_text = format_plan(state["plan"])
    # The prompt is built outside the try below, so a None result must not be
    # sliced directly or the whole node would crash.
    execution_excerpt = (state["execution_result"] or "")[:500]
    prompt = f"""You are an expert data analyst router.
The current plan is INSUFFICIENT to answer the question.
Original Question: {state["query"]}
Available Data:
{data_context}
Current Plan:
{plan_text}
Execution Result:
{execution_excerpt}
Task: Decide how to improve the plan:
1. If a current step is WRONG or needs fixing: Answer "Step N" (where N is the step number, e.g., "Step 2")
2. If we need to ADD a NEW step: Answer "Add Step"
Answer with ONLY: "Step 1", "Step 2", etc. OR "Add Step"
No explanation needed."""
    try:
        # Get LLM response
        response = state["llm"].invoke(prompt)
        # Handle different response formats
        if hasattr(response, "content") and isinstance(response.content, list):
            # List-shaped content (e.g. Gemini parts) is flattened by the helper.
            response_text = gemini_text(response)
        elif hasattr(response, "content"):
            response_text = response.content
        else:
            response_text = str(response)
        num_steps = len(state["plan"] or [])
        decision, next_node = _parse_router_decision(response_text, num_steps)
        print(f"\nRouter Decision: {decision}")
        print(
            f"Next Action: {'Backtrack' if next_node == 'backtrack' else 'Add New Step'}"
        )
        print("=" * 60)
        return {
            "router_decision": decision,
            "messages": [AIMessage(content=f"Router: {decision}")],
            "iteration": state["iteration"] + 1,
            "next": next_node,
        }
    except Exception as e:
        # On error, default to adding new step
        print(f"\n✗ Router error: {str(e)}")
        print("Defaulting to 'Add Step'")
        return {
            "router_decision": "Add Step",
            "messages": [AIMessage(content=f"Router error, adding step: {str(e)}")],
            "iteration": state["iteration"] + 1,
            "next": "planner",
        }
def backtrack_node(state: DSStarState) -> dict:
    """
    Backtrack Node: cut the plan back to just before the faulty step.

    Reads the "Step N" decision written by the router and keeps only the
    steps that precede step N (1-based). Step N and every later step are
    dropped, so the planner can regenerate from that point. A decision
    without a parsable step number leaves the plan untouched.

    Args:
        state: Current DSStarState

    Returns:
        Dictionary with updated state fields:
        - plan: Truncated plan (only on successful parse; replaces the plan)
        - messages: Agent communication messages
        - next: always "planner"
    """
    banner = "=" * 60
    print(banner)
    print("BACKTRACK NODE ACTIVATING...")
    print(banner)
    try:
        decision_text = state["router_decision"].lower()
        step_match = re.search(r"step\s+(\d+)", decision_text)
        if step_match is None:
            # Nothing to backtrack to; let the planner extend the plan instead.
            print("Failed to parse step number, adding new step instead")
            return {
                "messages": [
                    AIMessage(content="Backtrack parsing failed, adding new step")
                ],
                "next": "planner",
            }
        step_num = int(step_match.group(1))
        current_plan = state["plan"]
        # Human step N lives at index N-1, so keep indices [0, N-1).
        if step_num > 1:
            truncated_plan = current_plan[: step_num - 1]
        else:
            truncated_plan = []
        print(
            f"Truncating plan from {len(current_plan)} to {len(truncated_plan)} steps"
        )
        print(f"Removed step {step_num} and beyond")
        print(banner)
        # The returned plan replaces the existing one rather than appending.
        return {
            "plan": truncated_plan,
            "messages": [AIMessage(content=f"Backtracked to step {step_num - 1}")],
            "next": "planner",
        }
    except Exception as e:
        print(f"✗ Backtrack error: {str(e)}")
        return {
            "messages": [AIMessage(content=f"Backtrack error: {str(e)}, continuing")],
            "next": "planner",
        }
# Standalone test function
def test_router(
    llm, query: str, data_descriptions: dict, plan: list, execution_result: str
):
    """
    Run router_node once against a hand-built state and print its verdict.

    Args:
        llm: LLM instance
        query: User query
        data_descriptions: Dict of filename -> description
        plan: Current plan steps
        execution_result: Result from code execution

    Returns:
        The dictionary produced by router_node.
    """
    # Minimal state: only the inputs come from the caller, the rest are
    # neutral defaults for a router invocation.
    test_state = dict(
        llm=llm,
        query=query,
        data_descriptions=data_descriptions,
        plan=plan,
        current_code="",
        execution_result=execution_result,
        is_sufficient=False,
        router_decision="",
        iteration=0,
        max_iterations=20,
        messages=[],
        next="router",
    )
    result = router_node(test_state)
    divider = "=" * 60
    print("\n" + divider)
    print("ROUTER TEST RESULTS")
    print(divider)
    for label, key in (("Decision", "router_decision"), ("Next Node", "next")):
        print(f"{label}: {result.get(key, 'unknown')}")
    return result