# File size: 6,862 Bytes
# 8ff817c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""

DS-STAR Graph: Connects all agents into a workflow.



This module builds the LangGraph StateGraph that orchestrates the multi-agent system.

"""

from typing import Literal

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph

from .agents.analyzer_agent import analyzer_node
from .agents.coder_agent import coder_node
from .agents.finalyzer_agent import finalyzer_node
from .agents.planner_agent import planner_node
from .agents.router_agent import backtrack_node, router_node
from .agents.verifier_agent import verifier_node
from .utils.state import DSStarState

# ==================== CONDITIONAL ROUTING FUNCTIONS ====================


def route_after_analyzer(state: DSStarState) -> Literal["planner", "__end__"]:
    """
    Route after the analyzer node.

    Ends the workflow when ``data_descriptions`` contains an ``"error"`` key.
    Otherwise follows ``state["next"]`` if it names a destination this edge is
    wired to (``"planner"`` or ``"__end__"``), and falls back to ``"planner"``
    for a missing or unexpected value.

    The fallback matters because ``"next"`` persists in the state between
    nodes: the initial state sets it to ``"analyzer"``, and if the analyzer
    does not overwrite it, returning that stale value would name a destination
    absent from this edge's path map.

    Args:
        state: Current workflow state.

    Returns:
        ``"planner"`` or ``"__end__"``.
    """
    # ``or {}`` also covers data_descriptions being present but None.
    if "error" in (state.get("data_descriptions") or {}):
        return "__end__"
    next_node = state.get("next")
    return next_node if next_node in ("planner", "__end__") else "planner"


def route_after_planner(state: DSStarState) -> Literal["coder", "__end__"]:
    """
    Route after the planner node.

    Follows ``state["next"]`` when it is ``"coder"`` or ``"__end__"``.
    A missing or unexpected value, such as one left over from an earlier
    node because ``"next"`` persists in the state, falls back to ``"coder"``.
    This guarantees the route always names a destination in this edge's
    path map.

    Args:
        state: Current workflow state.

    Returns:
        ``"coder"`` or ``"__end__"``.
    """
    next_node = state.get("next")
    return next_node if next_node in ("coder", "__end__") else "coder"


def route_after_coder(state: DSStarState) -> Literal["verifier", "__end__"]:
    """
    Route after the coder node.

    Follows ``state["next"]`` when it is ``"verifier"`` or ``"__end__"``.
    A missing or unexpected value, such as a stale entry from an earlier
    node, falls back to ``"verifier"``. This guarantees the route always
    names a destination in this edge's path map.

    Args:
        state: Current workflow state.

    Returns:
        ``"verifier"`` or ``"__end__"``.
    """
    next_node = state.get("next")
    return next_node if next_node in ("verifier", "__end__") else "verifier"


def route_after_verifier(
    state: DSStarState,
) -> Literal["router", "finalyzer", "__end__"]:
    """
    Route after the verifier node.

    Once ``state["iteration"]`` reaches ``state["max_iterations"]``, the
    workflow is forced to the finalyzer and a notice is printed. Otherwise
    ``state["next"]`` decides between ``"router"``, ``"finalyzer"`` and
    ``"__end__"``.

    NOTE(review): presumably the verifier sets ``"finalyzer"`` when it judges
    the solution sufficient; confirm against verifier_node.

    A missing or unexpected ``"next"`` value (e.g. a stale entry from an
    earlier node) falls back to ``"router"``, so the route always names a
    destination in this edge's path map.

    Args:
        state: Current workflow state; must contain ``iteration`` and
            ``max_iterations``.

    Returns:
        ``"router"``, ``"finalyzer"`` or ``"__end__"``.

    Raises:
        KeyError: If ``iteration`` or ``max_iterations`` is missing.
    """
    # The iteration cap takes precedence over whatever the verifier decided.
    if state["iteration"] >= state["max_iterations"]:
        print(f"\n⚠ Max iterations ({state['max_iterations']}) reached, finalizing...")
        return "finalyzer"

    next_node = state.get("next")
    if next_node in ("router", "finalyzer", "__end__"):
        return next_node
    return "router"


def route_after_router(state: DSStarState) -> Literal["planner", "backtrack"]:
    """
    Route after the router node.

    Follows ``state["next"]`` when it is ``"planner"`` (add a new plan step)
    or ``"backtrack"`` (drop wrong steps first). Presumably router_node
    translates its "Add Step" / "Step N" decision into one of these values;
    confirm against router_node.

    A missing or unexpected value, such as a stale entry from an earlier
    node, falls back to ``"planner"``. This edge's path map has no END
    target, so anything else would be unroutable.

    Args:
        state: Current workflow state.

    Returns:
        ``"planner"`` or ``"backtrack"``.
    """
    next_node = state.get("next")
    return next_node if next_node in ("planner", "backtrack") else "planner"


# ==================== GRAPH BUILDER ====================


def build_ds_star_graph(llm, max_iterations: int = 20):
    """
    Assemble and compile the DS-STAR LangGraph workflow.

    Wiring:
      * ``analyzer`` is the entry point; it continues to ``planner`` or ends
        the run.
      * ``planner`` -> ``coder`` -> ``verifier``; planner and coder may also
        end the run.
      * ``verifier`` goes to ``finalyzer`` (iteration cap reached or chosen
        via ``next``), to ``router``, or ends the run.
      * ``router`` goes to ``planner`` (add a step) or to ``backtrack``.
        ``backtrack`` always returns to ``planner``.
      * ``finalyzer`` always ends the run.

    Args:
        llm: Accepted for API compatibility but not used by this function.
            The LLM is carried in the workflow state instead (see
            create_initial_state).
        max_iterations: Also unused here; the cap is read from the state's
            ``max_iterations`` entry by route_after_verifier.

    Returns:
        The compiled graph, checkpointed with an in-memory MemorySaver.
    """
    graph = StateGraph(DSStarState)

    # Registration order is kept identical to the original wiring.
    nodes = {
        "analyzer": analyzer_node,
        "planner": planner_node,
        "coder": coder_node,
        "verifier": verifier_node,
        "router": router_node,
        "backtrack": backtrack_node,
        "finalyzer": finalyzer_node,
    }
    for node_name, node_fn in nodes.items():
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("analyzer")

    # (source node, routing function, routing-result -> destination node)
    branches = (
        ("analyzer", route_after_analyzer, {"planner": "planner", "__end__": END}),
        ("planner", route_after_planner, {"coder": "coder", "__end__": END}),
        ("coder", route_after_coder, {"verifier": "verifier", "__end__": END}),
        (
            "verifier",
            route_after_verifier,
            {"router": "router", "finalyzer": "finalyzer", "__end__": END},
        ),
        ("router", route_after_router, {"planner": "planner", "backtrack": "backtrack"}),
    )
    for source, chooser, path_map in branches:
        graph.add_conditional_edges(source, chooser, path_map)

    # Unconditional hops.
    graph.add_edge("backtrack", "planner")
    graph.add_edge("finalyzer", END)

    return graph.compile(checkpointer=MemorySaver())


def create_initial_state(query: str, llm, max_iterations: int = 20) -> DSStarState:
    """
    Build the starting state for one DS-STAR run.

    All working fields start out empty, false or zero. ``next`` points at
    the entry node ``"analyzer"``. Fresh containers are created on every
    call, so separate runs never share lists or dicts.

    Args:
        query: The question the workflow should answer.
        llm: LLM instance, stored under the ``"llm"`` key.
        max_iterations: Refinement cap, stored under ``"max_iterations"``.

    Returns:
        A new DSStarState dictionary.
    """
    fresh_state = dict(
        query=query,
        data_descriptions={},
        plan=[],
        current_code="",
        execution_result="",
        is_sufficient=False,
        router_decision="",
        iteration=0,
        max_iterations=max_iterations,
        messages=[],
        next="analyzer",
        llm=llm,
    )
    return fresh_state


def run_ds_star(
    query: str, llm, max_iterations: int = 20, thread_id: str = "ds-star-1"
):
    """
    Run the complete DS-STAR workflow and print the final solution.

    Args:
        query: User's question to answer.
        llm: LLM instance; passed to the graph builder and stored in the
            initial state.
        max_iterations: Maximum refinement iterations. It is also used to
            size LangGraph's recursion limit so the run is not aborted
            before this cap is reached.
        thread_id: Thread ID used as the checkpoint key.

    Returns:
        The final state after workflow completion. Returns None if any
        exception is raised during execution or result display; in that
        case the traceback is printed.
    """
    print("=" * 60)
    print("DS-STAR MULTI-AGENT SYSTEM")
    print("=" * 60)
    print(f"Query: {query}")
    print(f"Max Iterations: {max_iterations}")
    print("=" * 60)

    app = build_ds_star_graph(llm, max_iterations)
    initial_state = create_initial_state(query, llm, max_iterations)

    # LangGraph aborts with GraphRecursionError once a run exceeds
    # ``recursion_limit`` super-steps, and the default is 25. Every
    # refinement round visits at least planner -> coder -> verifier, and at
    # most also router and backtrack. With the default limit, a run could
    # therefore never reach e.g. 20 iterations. Budget five steps per
    # iteration plus headroom for analyzer/finalyzer, and never go below
    # LangGraph's default.
    recursion_limit = max(25, 5 * max_iterations + 10)
    config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": recursion_limit,
    }

    try:
        final_state = app.invoke(initial_state, config)

        print("\n" + "=" * 60)
        print("FINAL SOLUTION")
        print("=" * 60)
        print("\nGenerated Code:")
        print("-" * 60)
        print(final_state["current_code"])
        print("\n" + "-" * 60)
        print("Execution Result:")
        print("-" * 60)
        print(final_state["execution_result"])
        print("=" * 60)

        return final_state

    except Exception as e:
        # Top-level boundary: report the failure and keep the
        # "None on failure" contract callers rely on.
        print(f"\n✗ Error during execution: {e}")
        import traceback

        traceback.print_exc()
        return None