Spaces:
Running
Running
File size: 6,862 Bytes
8ff817c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
"""
DS-STAR Graph: Connects all agents into a workflow.
This module builds the LangGraph StateGraph that orchestrates the multi-agent system.
"""
from typing import Literal
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from .agents.analyzer_agent import analyzer_node
from .agents.coder_agent import coder_node
from .agents.finalyzer_agent import finalyzer_node
from .agents.planner_agent import planner_node
from .agents.router_agent import backtrack_node, router_node
from .agents.verifier_agent import verifier_node
from .utils.state import DSStarState
# ==================== CONDITIONAL ROUTING FUNCTIONS ====================
def route_after_analyzer(state: DSStarState) -> Literal["planner", "__end__"]:
    """
    Pick the edge to follow after the analyzer node.

    If ``data_descriptions`` contains an "error" key, the workflow ends.
    Otherwise ``state["next"]`` is honoured when it is "planner" or "__end__".
    Any other value falls back to "planner". That covers a missing key, or the
    "analyzer" value seeded into the initial state that no node overwrote.
    The function therefore never returns a key absent from the analyzer's
    path map.
    """
    # `or {}` also covers an explicit None, which `in` cannot search.
    descriptions = state.get("data_descriptions") or {}
    if "error" in descriptions:
        return "__end__"
    choice = state.get("next", "planner")
    return choice if choice in ("planner", "__end__") else "planner"
def route_after_planner(state: DSStarState) -> Literal["coder", "__end__"]:
    """
    Pick the edge to follow after the planner node.

    ``state["next"]`` is honoured when it is "coder" or "__end__". Any other
    value falls back to "coder", whether the key is missing or holds a stale
    value written by an earlier node. The function therefore never returns a
    key absent from the planner's path map.
    """
    choice = state.get("next", "coder")
    return choice if choice in ("coder", "__end__") else "coder"
def route_after_coder(state: DSStarState) -> Literal["verifier", "__end__"]:
    """
    Pick the edge to follow after the coder node.

    ``state["next"]`` is honoured when it is "verifier" or "__end__". Any other
    value falls back to "verifier", whether the key is missing or holds a stale
    value written by an earlier node. The function therefore never returns a
    key absent from the coder's path map.
    """
    choice = state.get("next", "verifier")
    return choice if choice in ("verifier", "__end__") else "verifier"
def route_after_verifier(
    state: DSStarState,
) -> Literal["router", "finalyzer", "__end__"]:
    """
    Pick the edge to follow after the verifier node.

    Order of precedence:
      1. ``iteration >= max_iterations`` -> "finalyzer" (budget exhausted;
         a notice is printed).
      2. ``is_sufficient`` is truthy -> "finalyzer".
      3. ``state["next"]`` if it is "router", "finalyzer" or "__end__".
      4. Otherwise (missing or stale ``next``) -> "router".

    Raises:
        KeyError: if "iteration" or "max_iterations" is absent from state.
    """
    # Check max iterations
    if state["iteration"] >= state["max_iterations"]:
        print(f"\n⚠ Max iterations ({state['max_iterations']}) reached, finalizing...")
        return "finalyzer"
    # Honour the sufficiency verdict directly rather than relying on the
    # verifier node to also have set next="finalyzer".
    if state.get("is_sufficient"):
        return "finalyzer"
    choice = state.get("next", "router")
    # Only return keys present in the verifier's path map.
    return choice if choice in ("router", "finalyzer", "__end__") else "router"
def route_after_router(state: DSStarState) -> Literal["planner", "backtrack"]:
    """
    Pick the edge to follow after the router node.

    ``state["next"]`` is honoured when it is "planner" (add a step) or
    "backtrack" (drop wrong steps). Any other value falls back to "planner",
    whether the key is missing or holds a stale value written by an earlier
    node. The function therefore never returns a key absent from the
    router's path map.
    """
    choice = state.get("next", "planner")
    return choice if choice in ("planner", "backtrack") else "planner"
# ==================== GRAPH BUILDER ====================
def build_ds_star_graph(llm, max_iterations: int = 20):
    """
    Assemble and compile the DS-STAR LangGraph workflow.

    Topology:
      analyzer -> planner -> coder -> verifier
      verifier -> router -> planner, or router -> backtrack -> planner
                                               (while the answer is insufficient)
      verifier -> finalyzer -> END             (when finished)
    analyzer, planner, coder and verifier may also route straight to END.

    Args:
        llm: LLM instance. Not referenced while building the graph.
            NOTE(review): nodes presumably read the model from state["llm"]
            (set by create_initial_state); confirm in the agent modules.
        max_iterations: Not referenced while building the graph. The limit is
            enforced through state["max_iterations"] in route_after_verifier.

    Returns:
        The compiled graph, checkpointed with an in-memory MemorySaver.
    """
    graph = StateGraph(DSStarState)

    # Register every agent under the name the edges below refer to.
    agents = (
        ("analyzer", analyzer_node),
        ("planner", planner_node),
        ("coder", coder_node),
        ("verifier", verifier_node),
        ("router", router_node),
        ("backtrack", backtrack_node),
        ("finalyzer", finalyzer_node),
    )
    for agent_name, agent_fn in agents:
        graph.add_node(agent_name, agent_fn)

    graph.set_entry_point("analyzer")

    # (source node, routing function, {route key: destination node})
    branches = (
        ("analyzer", route_after_analyzer, {"planner": "planner", "__end__": END}),
        ("planner", route_after_planner, {"coder": "coder", "__end__": END}),
        ("coder", route_after_coder, {"verifier": "verifier", "__end__": END}),
        (
            "verifier",
            route_after_verifier,
            {"router": "router", "finalyzer": "finalyzer", "__end__": END},
        ),
        (
            "router",
            route_after_router,
            {"planner": "planner", "backtrack": "backtrack"},
        ),
    )
    for source, chooser, path_map in branches:
        graph.add_conditional_edges(source, chooser, path_map)

    # Fixed hops: backtracking always re-plans; the finalyzer always terminates.
    graph.add_edge("backtrack", "planner")
    graph.add_edge("finalyzer", END)

    return graph.compile(checkpointer=MemorySaver())
def create_initial_state(query: str, llm, max_iterations: int = 20) -> DSStarState:
    """
    Build the starting state for one DS-STAR run.

    Apart from the query, the model and the iteration budget, every field
    starts empty, False or zero. ``next`` is seeded with "analyzer", the
    graph's entry node.

    Args:
        query: The user's question.
        llm: LLM instance, carried in the state under "llm".
        max_iterations: Refinement budget, copied into "max_iterations".

    Returns:
        A new DSStarState dict. Its list/dict fields are fresh objects on
        every call, so runs never share accumulators.
    """
    state: DSStarState = {
        # The question being answered.
        "query": query,
        # Per-run accumulators, all starting empty.
        "data_descriptions": {},
        "plan": [],
        "current_code": "",
        "execution_result": "",
        # Verifier / router bookkeeping.
        "is_sufficient": False,
        "router_decision": "",
        "iteration": 0,
        "max_iterations": max_iterations,
        "messages": [],
        "next": "analyzer",
        # Model handle made available to the nodes.
        "llm": llm,
    }
    return state
def run_ds_star(
    query: str, llm, max_iterations: int = 20, thread_id: str = "ds-star-1"
):
    """
    Build the DS-STAR graph, run it once for ``query`` and print the outcome.

    Args:
        query: User's question to answer.
        llm: LLM instance, passed to the graph builder and stored in the
            initial state.
        max_iterations: Maximum refinement iterations. Also used to size
            LangGraph's recursion limit for this run.
        thread_id: Checkpointer thread ID for this run.

    Returns:
        The final state after the workflow completes. Returns None if
        anything inside the try block raised; the traceback is printed first.
    """
    print("=" * 60)
    print("DS-STAR MULTI-AGENT SYSTEM")
    print("=" * 60)
    print(f"Query: {query}")
    print(f"Max Iterations: {max_iterations}")
    print("=" * 60)
    # Build graph
    app = build_ds_star_graph(llm, max_iterations)
    # Create initial state
    initial_state = create_initial_state(query, llm, max_iterations)
    # LangGraph stops a run after `recursion_limit` super-steps (default 25).
    # One refinement round can visit planner, coder, verifier, router and
    # backtrack (5 steps), so the default would abort a 20-iteration run
    # long before route_after_verifier's max-iteration cut-off fires.
    # The limit covers max_iterations + 1 rounds plus headroom for the
    # analyzer and finalyzer, and never drops below the default.
    recursion_limit = max(25, 5 * (max_iterations + 1) + 5)
    config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": recursion_limit,
    }
    try:
        # Execute the workflow
        final_state = app.invoke(initial_state, config)
        # Display results
        print("\n" + "=" * 60)
        print("FINAL SOLUTION")
        print("=" * 60)
        print("\nGenerated Code:")
        print("-" * 60)
        print(final_state["current_code"])
        print("\n" + "-" * 60)
        print("Execution Result:")
        print("-" * 60)
        print(final_state["execution_result"])
        print("=" * 60)
        return final_state
    except Exception as e:
        # Top-level boundary: report and signal failure with None.
        print(f"\n✗ Error during execution: {str(e)}")
        import traceback

        traceback.print_exc()
        return None
|