DS-STAR / src /graph.py
anurag-deo's picture
Upload folder using huggingface_hub
8ff817c verified
raw
history blame
6.86 kB
"""
DS-STAR Graph: Connects all agents into a workflow.
This module builds the LangGraph StateGraph that orchestrates the multi-agent system.
"""
from typing import Literal
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from .agents.analyzer_agent import analyzer_node
from .agents.coder_agent import coder_node
from .agents.finalyzer_agent import finalyzer_node
from .agents.planner_agent import planner_node
from .agents.router_agent import backtrack_node, router_node
from .agents.verifier_agent import verifier_node
from .utils.state import DSStarState
# ==================== CONDITIONAL ROUTING FUNCTIONS ====================
def route_after_analyzer(state: DSStarState) -> Literal["planner", "__end__"]:
    """
    Choose the node that runs after the analyzer.

    The workflow ends when the state's "data_descriptions" entry contains
    "error". Otherwise the state's "next" entry decides, with "planner"
    used when that entry is absent.
    """
    descriptions = state.get("data_descriptions", {})
    analyzer_failed = "error" in descriptions
    if not analyzer_failed:
        return state.get("next", "planner")
    return "__end__"
def route_after_planner(state: DSStarState) -> Literal["coder", "__end__"]:
    """Return the state's "next" entry after the planner, or "coder" if unset."""
    fallback = "coder"
    target = state.get("next", fallback)
    return target
def route_after_coder(state: DSStarState) -> Literal["verifier", "__end__"]:
    """Return the state's "next" entry after the coder, or "verifier" if unset."""
    fallback = "verifier"
    target = state.get("next", fallback)
    return target
def route_after_verifier(
    state: DSStarState,
) -> Literal["router", "finalyzer", "__end__"]:
    """
    Choose the node that runs after the verifier.

    Once the state's "iteration" counter has reached "max_iterations", a
    notice is printed and the workflow is sent to the finalyzer regardless
    of the state's "next" entry. Below that limit the "next" entry decides,
    with "router" used when it is absent.
    """
    budget_exhausted = state["iteration"] >= state["max_iterations"]
    if not budget_exhausted:
        return state.get("next", "router")

    # The iteration cap takes priority over whatever the verifier requested.
    print(f"\n⚠ Max iterations ({state['max_iterations']}) reached, finalizing...")
    return "finalyzer"
def route_after_router(state: DSStarState) -> Literal["planner", "backtrack"]:
    """
    Choose the node that runs after the router.

    Returns the state's "next" entry, or "planner" when that entry is
    absent. Presumably the router node stores "planner" for an added step
    and "backtrack" for a step rollback; confirm against the router agent.
    """
    fallback = "planner"
    target = state.get("next", fallback)
    return target
# ==================== GRAPH BUILDER ====================
def build_ds_star_graph(llm, max_iterations: int = 20):
    """
    Assemble and compile the DS-STAR LangGraph workflow.

    Wiring:
        analyzer  -> planner | END
        planner   -> coder | END
        coder     -> verifier | END
        verifier  -> router | finalyzer | END
        router    -> planner | backtrack
        backtrack -> planner
        finalyzer -> END

    Every hop except the last two is chosen at run time by the matching
    ``route_after_*`` function.

    Args:
        llm: LLM instance (e.g., ChatOpenAI, ChatGoogleGenerativeAI).
            Accepted for interface compatibility; this function does not
            read it.
        max_iterations: Maximum refinement iterations (default: 20). Not
            read here; the verifier route takes the limit from the state.

    Returns:
        The compiled graph with an in-memory checkpointer attached.
    """
    graph = StateGraph(DSStarState)

    # Register the agent nodes in a fixed order.
    agent_nodes = (
        ("analyzer", analyzer_node),
        ("planner", planner_node),
        ("coder", coder_node),
        ("verifier", verifier_node),
        ("router", router_node),
        ("backtrack", backtrack_node),
        ("finalyzer", finalyzer_node),
    )
    for node_name, node_fn in agent_nodes:
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("analyzer")

    # (source node, routing function, route label -> destination node)
    conditional_routes = (
        ("analyzer", route_after_analyzer, {"planner": "planner", "__end__": END}),
        ("planner", route_after_planner, {"coder": "coder", "__end__": END}),
        ("coder", route_after_coder, {"verifier": "verifier", "__end__": END}),
        (
            "verifier",
            route_after_verifier,
            {"router": "router", "finalyzer": "finalyzer", "__end__": END},
        ),
        ("router", route_after_router, {"planner": "planner", "backtrack": "backtrack"}),
    )
    for source, chooser, destinations in conditional_routes:
        graph.add_conditional_edges(source, chooser, destinations)

    # Unconditional hops.
    graph.add_edge("backtrack", "planner")
    graph.add_edge("finalyzer", END)

    # Compile with an in-memory checkpointer.
    return graph.compile(checkpointer=MemorySaver())
def create_initial_state(query: str, llm, max_iterations: int = 20) -> DSStarState:
    """
    Build the starting state for a DS-STAR run.

    Args:
        query: User's question to answer
        llm: LLM instance, stored in the state under "llm"
        max_iterations: Maximum refinement iterations

    Returns:
        A new state dictionary: empty descriptions, plan, code, result and
        messages; iteration counter at 0; "next" set to "analyzer".
    """
    state = dict(
        query=query,
        data_descriptions={},
        plan=[],
        current_code="",
        execution_result="",
        is_sufficient=False,
        router_decision="",
        iteration=0,
        max_iterations=max_iterations,
        messages=[],
        next="analyzer",
        llm=llm,
    )
    return state
def run_ds_star(
    query: str, llm, max_iterations: int = 20, thread_id: str = "ds-star-1"
):
    """
    Run the complete DS-STAR workflow and print the outcome.

    Args:
        query: User's question to answer
        llm: LLM instance
        max_iterations: Maximum refinement iterations
        thread_id: Thread ID placed in the run config for checkpointing

    Returns:
        Final state after workflow completion, or None when the workflow
        itself raised (the error and traceback are printed in that case).
    """
    banner = "=" * 60
    rule = "-" * 60

    print(banner)
    print("DS-STAR MULTI-AGENT SYSTEM")
    print(banner)
    print(f"Query: {query}")
    print(f"Max Iterations: {max_iterations}")
    print(banner)

    # Build graph and starting state
    app = build_ds_star_graph(llm, max_iterations)
    initial_state = create_initial_state(query, llm, max_iterations)

    # Run with checkpointing
    config = {"configurable": {"thread_id": thread_id}}

    # Guard only the workflow run. Display problems must not be reported as
    # execution errors, nor cause a finished state to be thrown away.
    try:
        final_state = app.invoke(initial_state, config)
    except Exception as e:
        print(f"\n✗ Error during execution: {str(e)}")
        import traceback

        traceback.print_exc()
        return None

    # Display results; .get() tolerates a final state missing either field.
    print("\n" + banner)
    print("FINAL SOLUTION")
    print(banner)
    print("\nGenerated Code:")
    print(rule)
    print(final_state.get("current_code", ""))
    print("\n" + rule)
    print("Execution Result:")
    print(rule)
    print(final_state.get("execution_result", ""))
    print(banner)
    return final_state