Spaces:
Running
Running
| """ | |
| Comprehensive test for DS-STAR workflow. | |
| This test runs the complete multi-agent system to verify: | |
| 1. All agents are properly connected | |
| 2. The graph routing works correctly | |
| 3. The workflow can complete successfully | |
| """ | |
| import os | |
| import sys | |
| from dotenv import load_dotenv | |
# Run python-dotenv first so the settings below are read from the environment
# after it has had the chance to populate it.
load_dotenv()

# Model identifier and API key handed to get_llm(); both can be overridden
# through environment variables of the same name.
LLM_MODEL = os.environ.get("LLM_MODEL", "google/gemini-2.5-flash")
LLM_API_KEY = os.environ.get("LLM_API_KEY", "")

# Put the parent of this file's directory at the front of the import path so
# the ``src`` imports that follow can be resolved.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)),
)
| from src.config import get_llm | |
| from src.graph import run_ds_star | |
def test_complete_workflow():
    """Run the DS-STAR workflow end to end and report whether it produced output.

    Builds an LLM with ``get_llm`` from the module-level ``LLM_MODEL`` and
    ``LLM_API_KEY`` settings, runs ``run_ds_star`` on a fixed query, and checks
    that the returned state holds non-empty data descriptions, plan, code and
    execution result. Progress and per-check results are printed to stdout.

    Returns:
        bool: True when the workflow returned a state and every check passed.
        False when the workflow returned None, when any check failed, or when
        an exception was raised (its traceback is printed, not propagated).

    NOTE(review): the outcome is reported through the return value rather than
    through assertions, so it only becomes a failing exit status via ``main()``.
    """
    banner = "=" * 60

    print(banner)
    print("COMPREHENSIVE DS-STAR WORKFLOW TEST")
    print(banner)

    # Configuration
    query = "What percentage of transactions use credit cards?"
    max_iterations = 10  # Reduced for testing

    print(f"\nTest Query: {query}")
    print(f"Max Iterations: {max_iterations}")
    print()

    # Broad catch is deliberate: this is the top-level boundary of the script,
    # and any failure should be reported and converted into a False result.
    try:
        # Initialize LLM. Report the model that is actually configured instead
        # of a hard-coded name that can drift from LLM_MODEL.
        print(f"Initializing LLM ({LLM_MODEL})...")
        llm = get_llm(
            provider="openai",
            model=LLM_MODEL,
            temperature=0,
            api_key=LLM_API_KEY,
        )
        print("✓ LLM initialized")
        print()

        # Run workflow
        print("Starting DS-STAR workflow...")
        print(banner)
        final_state = run_ds_star(
            query=query,
            llm=llm,
            max_iterations=max_iterations,
            thread_id="test-session",
        )

        # Verify results
        print("\n" + banner)
        print("TEST VERIFICATION")
        print(banner)

        if final_state is None:
            print("❌ FAILED: Workflow returned None")
            return False

        # Each entry pairs a human-readable label with the state key that must
        # hold a non-empty value.
        required_outputs = [
            ("Data descriptions", "data_descriptions"),
            ("Plan generated", "plan"),
            ("Code generated", "current_code"),
            ("Execution result", "execution_result"),
        ]
        # Truthiness (rather than len(...) > 0) gives the same answer for
        # str/list/dict values, and additionally treats a key that is present
        # but None as a failed check instead of raising TypeError.
        checks = [
            (label, bool(final_state.get(key)))
            for label, key in required_outputs
        ]

        for check_name, passed in checks:
            status = "✓" if passed else "✗"
            print(f"{status} {check_name}: {'PASS' if passed else 'FAIL'}")

        all_passed = all(passed for _, passed in checks)

        print("\n" + banner)
        print("✅ ALL TESTS PASSED" if all_passed else "❌ SOME TESTS FAILED")
        print(banner)
        return all_passed

    except Exception as e:
        print(f"\n❌ TEST FAILED WITH EXCEPTION: {e}")
        import traceback

        traceback.print_exc()
        return False
def main():
    """Execute the workflow test and map its outcome to a process exit code.

    Returns:
        int: 0 when ``test_complete_workflow()`` reports success, 1 otherwise.
    """
    if test_complete_workflow():
        return 0
    return 1
# When executed as a script, exit with the status code produced by main().
if __name__ == "__main__":
    raise SystemExit(main())