File size: 3,088 Bytes
8ff817c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""

Comprehensive test for DS-STAR workflow.



This test runs the complete multi-agent system to verify:

1. All agents are properly connected

2. The graph routing works correctly

3. The workflow can complete successfully

"""

import os
import sys

from dotenv import load_dotenv

# Load variables from a .env file (python-dotenv) before reading the settings.
load_dotenv()
# Model identifier and API key passed to get_llm(); the defaults below apply
# when the environment variables are unset.
LLM_MODEL = os.environ.get("LLM_MODEL", "google/gemini-2.5-flash")
LLM_API_KEY = os.environ.get("LLM_API_KEY", "")

# Put the parent of this file's directory first on sys.path so the `src`
# package imports below resolve when the file is executed directly.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)),
)

from src.config import get_llm
from src.graph import run_ds_star


def test_complete_workflow():
    """Run the complete DS-STAR workflow once and report whether it succeeded.

    Creates an LLM with ``get_llm`` (using the module-level ``LLM_MODEL`` and
    ``LLM_API_KEY``), runs ``run_ds_star`` on a fixed query with a small
    iteration budget, then checks that the returned state contains non-empty
    data descriptions, a plan, generated code and an execution result.

    Returns:
        bool: True if the workflow returned a state and every check passed;
        False if it returned None, any check failed, or an exception was
        raised (the traceback is printed in that case).

    NOTE(review): the outcome is reported via the return value (consumed by
    ``main``) rather than ``assert``. If pytest collects this function, a
    False return is only surfaced as a warning, not a test failure.
    """
    print("=" * 60)
    print("COMPREHENSIVE DS-STAR WORKFLOW TEST")
    print("=" * 60)

    # Configuration
    query = "What percentage of transactions use credit cards?"
    max_iterations = 10  # Reduced for testing

    print(f"\nTest Query: {query}")
    print(f"Max Iterations: {max_iterations}")
    print()

    try:
        # Initialize LLM. Report the model that is actually configured; a
        # hard-coded name here drifts from LLM_MODEL and misleads the reader.
        print(f"Initializing LLM ({LLM_MODEL})...")
        llm = get_llm(
            provider="openai",
            model=LLM_MODEL,
            temperature=0,
            api_key=LLM_API_KEY,
        )

        print("✓ LLM initialized")
        print()

        # Run workflow
        print("Starting DS-STAR workflow...")
        print("=" * 60)

        final_state = run_ds_star(
            query=query,
            llm=llm,
            max_iterations=max_iterations,
            thread_id="test-session",
        )

        # Verify results
        print("\n" + "=" * 60)
        print("TEST VERIFICATION")
        print("=" * 60)

        if final_state is None:
            print("❌ FAILED: Workflow returned None")
            return False

        # Each check requires the state field to be present and non-empty.
        # Truthiness (instead of len()) also treats a field explicitly set to
        # None as a failed check rather than raising TypeError.
        checks = [
            ("Data descriptions", bool(final_state.get("data_descriptions"))),
            ("Plan generated", bool(final_state.get("plan"))),
            ("Code generated", bool(final_state.get("current_code"))),
            ("Execution result", bool(final_state.get("execution_result"))),
        ]

        for check_name, passed in checks:
            status = "✓" if passed else "✗"
            print(f"{status} {check_name}: {'PASS' if passed else 'FAIL'}")
        all_passed = all(passed for _, passed in checks)

        print("\n" + "=" * 60)
        print("✅ ALL TESTS PASSED" if all_passed else "❌ SOME TESTS FAILED")
        print("=" * 60)
        return all_passed

    except Exception as e:
        # Top-level boundary of a script-style test: any failure becomes a
        # False result (with traceback) so main() can map it to an exit code.
        print(f"\n❌ TEST FAILED WITH EXCEPTION: {e}")
        import traceback

        traceback.print_exc()
        return False


def main():
    """Execute the workflow test and convert its outcome into an exit code.

    Returns:
        int: 0 when ``test_complete_workflow`` reports success, 1 otherwise.
    """
    if test_complete_workflow():
        return 0
    return 1


if __name__ == "__main__":
    sys.exit(main())