File size: 6,955 Bytes
8ff817c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""

Router Agent: Decides how to improve an insufficient plan.



When the verifier determines the plan is insufficient, the router decides:

- "Add Step": Add a new step to the plan

- "Step N": Backtrack to step N and fix it

"""

import re

from langchain_core.messages import AIMessage

from ..utils.formatters import format_data_descriptions, format_plan, gemini_text
from ..utils.state import DSStarState


def _router_response_text(response) -> str:
    """Return the plain text of an LLM response, whatever its shape.

    If ``response.content`` is a list (multi-part content), it is flattened with
    ``gemini_text``; any other ``content`` value is converted with ``str``.
    Objects without a ``content`` attribute are stringified directly.
    """
    if not hasattr(response, "content"):
        return str(response)
    if isinstance(response.content, list):
        return gemini_text(response)
    # str() guards against non-string content so the caller can safely .strip().
    return str(response.content)


def _parse_router_decision(response_text: str, num_steps: int) -> tuple:
    """Map raw router output to a ``(decision, next_node)`` pair.

    "add step" anywhere in the reply (case-insensitive) wins. Otherwise the
    first "step N" is used, but only when N names an existing step
    (1 <= N <= num_steps). Any other reply, including an unparseable one,
    falls back to ("Add Step", "planner").
    """
    reply = response_text.strip().lower()
    if "add step" in reply:
        return "Add Step", "planner"

    match = re.search(r"step\s+(\d+)", reply)
    if match is None:
        # Default to adding new step
        return "Add Step", "planner"

    step_num = int(match.group(1))
    if 1 <= step_num <= num_steps:
        return f"Step {step_num}", "backtrack"

    # Steps are 1-indexed: "Step 0" or a number past the end of the plan does
    # not identify a step to fix, so backtracking there would be meaningless.
    print(
        f"Router named step {step_num}, but the plan has {num_steps} step(s); "
        "adding a new step instead"
    )
    return "Add Step", "planner"


def router_node(state: DSStarState) -> dict:
    """
    Router Agent Node: Decides how to improve an insufficient plan.

    Asks the LLM whether to:
    1. Add a new step to the plan ("Add Step"), or
    2. Backtrack and fix an existing step ("Step N").

    A reply naming a step outside 1..len(plan) is treated as "Add Step", so a
    "Step N" decision always refers to a step that exists in the current plan.
    Any exception raised while querying/parsing the LLM also falls back to
    "Add Step".

    Args:
        state: Current DSStarState. Reads "query", "data_descriptions", "plan",
            "execution_result", "llm" and "iteration". Only the first 500
            characters of "execution_result" are shown to the LLM; a None
            result is treated as an empty string.

    Returns:
        Dictionary with updated state fields:
        - router_decision: "Add Step" or "Step N"
        - iteration: Incremented iteration count
        - messages: Agent communication messages
        - next: "planner" (add step) or "backtrack" (fix step)
    """
    print("=" * 60)
    print("ROUTER AGENT STARTING...")
    print("=" * 60)

    data_context = format_data_descriptions(state["data_descriptions"])
    plan_text = format_plan(state["plan"])
    # Slicing None would raise before the try/except below could handle it.
    execution_result = state["execution_result"] or ""

    prompt = f"""You are an expert data analyst router.



The current plan is INSUFFICIENT to answer the question.



Original Question: {state["query"]}



Available Data:

{data_context}



Current Plan:

{plan_text}



Execution Result:

{execution_result[:500]}



Task: Decide how to improve the plan:

1. If a current step is WRONG or needs fixing: Answer "Step N" (where N is the step number, e.g., "Step 2")

2. If we need to ADD a NEW step: Answer "Add Step"



Answer with ONLY: "Step 1", "Step 2", etc. OR "Add Step"

No explanation needed."""

    try:
        response = state["llm"].invoke(prompt)
        response_text = _router_response_text(response)
        decision, next_node = _parse_router_decision(
            response_text, len(state["plan"])
        )

        print(f"\nRouter Decision: {decision}")
        print(
            f"Next Action: {'Backtrack' if next_node == 'backtrack' else 'Add New Step'}"
        )
        print("=" * 60)

        return {
            "router_decision": decision,
            "messages": [AIMessage(content=f"Router: {decision}")],
            "iteration": state["iteration"] + 1,
            "next": next_node,
        }

    except Exception as e:
        # Node boundary: any failure degrades to the safe "add a step" path.
        print(f"\n✗ Router error: {str(e)}")
        print("Defaulting to 'Add Step'")
        return {
            "router_decision": "Add Step",
            "messages": [AIMessage(content=f"Router error, adding step: {str(e)}")],
            "iteration": state["iteration"] + 1,
            "next": "planner",
        }


def backtrack_node(state: DSStarState) -> dict:
    """
    Backtrack Node: Truncates the plan to remove an incorrect step.

    When the router identifies a wrong step, this node:
    1. Parses the step number N from router_decision ("Step N")
    2. Keeps only steps 1..N-1, dropping step N and everything after it
    3. Routes back to the planner to regenerate from that point

    If no step number can be parsed, or N does not name an existing step
    (N < 1 or N > len(plan)), the plan is left unchanged and control goes to
    the planner to add a new step instead.

    Args:
        state: Current DSStarState. Reads "router_decision" and "plan".

    Returns:
        Dictionary with updated state fields:
        - plan: Truncated plan (only present when a valid step was removed)
        - messages: Agent communication messages
        - next: "planner" in every case
    """
    print("=" * 60)
    print("BACKTRACK NODE ACTIVATING...")
    print("=" * 60)

    try:
        # Extract step number from router decision
        match = re.search(r"step\s+(\d+)", state["router_decision"].lower())
        if match is None:
            # If parsing fails, just add new step
            print("Failed to parse step number, adding new step instead")
            return {
                "messages": [
                    AIMessage(content="Backtrack parsing failed, adding new step")
                ],
                "next": "planner",
            }

        step_num = int(match.group(1))
        plan = state["plan"]

        # Steps are 1-indexed. "Step 0" would otherwise wipe the whole plan and
        # a number past the end would remove nothing, so neither is a real
        # backtrack target.
        if not 1 <= step_num <= len(plan):
            print(
                f"Step {step_num} does not exist in a {len(plan)}-step plan, "
                "adding new step instead"
            )
            return {
                "messages": [
                    AIMessage(
                        content=(
                            f"Backtrack target step {step_num} out of range, "
                            "adding new step"
                        )
                    )
                ],
                "next": "planner",
            }

        # Keep list indices 0..step_num-2, i.e. human steps 1..step_num-1.
        truncated_plan = plan[: step_num - 1]

        print(f"Truncating plan from {len(plan)} to {len(truncated_plan)} steps")
        print(f"Removed step {step_num} and beyond")
        print("=" * 60)

        # The truncated list is returned as the new plan (intended to replace
        # the whole plan rather than append to it).
        return {
            "plan": truncated_plan,
            "messages": [AIMessage(content=f"Backtracked to step {step_num - 1}")],
            "next": "planner",
        }

    except Exception as e:
        # Node boundary: on any failure, keep the plan and let the planner continue.
        print(f"✗ Backtrack error: {str(e)}")
        return {
            "messages": [AIMessage(content=f"Backtrack error: {str(e)}, continuing")],
            "next": "planner",
        }


# Standalone test function
def test_router(
    llm, query: str, data_descriptions: dict, plan: list, execution_result: str
):
    """
    Run router_node once on a hand-built state and print its decision.

    Args:
        llm: LLM instance passed through as state["llm"]
        query: User query
        data_descriptions: Dict of filename -> description
        plan: Current plan steps
        execution_result: Result from code execution

    Returns:
        The dictionary produced by router_node.
    """
    # Caller-supplied fields plus neutral defaults for everything else the
    # state carries (fresh iteration counter, empty history, not sufficient).
    test_state = dict(
        llm=llm,
        query=query,
        data_descriptions=data_descriptions,
        plan=plan,
        current_code="",
        execution_result=execution_result,
        is_sufficient=False,
        router_decision="",
        iteration=0,
        max_iterations=20,
        messages=[],
        next="router",
    )

    outcome = router_node(test_state)

    rule = "=" * 60
    print("\n" + rule)
    print("ROUTER TEST RESULTS")
    print(rule)
    for label, key in (("Decision", "router_decision"), ("Next Node", "next")):
        print(f"{label}: {outcome.get(key, 'unknown')}")

    return outcome