File size: 5,736 Bytes
ad7badd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#!/usr/bin/env python3
"""
Test script to validate the multi-LoRA implementation
"""
import sys
import os

# Make the project modules imported by the tests below (app, lora_manager),
# which are expected to sit next to this script, importable regardless of the
# working directory the script is launched from. The directory is derived
# from this file's location rather than a machine-specific absolute path.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

def test_lora_config(lora_config=None):
    """Validate that every LoRA entry declares all required configuration keys.

    Args:
        lora_config: Mapping of LoRA name -> per-LoRA config to validate.
            When omitted, ``LORA_CONFIG`` is imported from ``app``.

    Returns:
        True if the configuration is non-empty and every entry contains every
        required key; False otherwise. Each problem found is printed.
    """
    print("Testing LoRA configuration system...")

    if lora_config is None:
        # Import lazily and report a missing/broken app module as a test
        # failure, matching how the other tests handle their imports.
        try:
            from app import LORA_CONFIG as lora_config
        except Exception as e:
            print(f"❌ Could not import LORA_CONFIG from app: {e}")
            return False

    if not lora_config:
        # An empty configuration would otherwise "pass" without checking anything.
        print("❌ LORA_CONFIG is empty - no LoRAs to validate")
        return False

    required_keys = ('repo_id', 'filename', 'type', 'method', 'prompt_template', 'description')
    all_valid = True

    for lora_name, config in lora_config.items():
        missing = [key for key in required_keys if key not in config]
        if missing:
            # Report every missing key of every entry instead of stopping at
            # the first one, so a single run surfaces all problems.
            for key in missing:
                print(f"❌ Missing key '{key}' in {lora_name}")
            all_valid = False
        else:
            print(f"✅ {lora_name}: Valid configuration")

    if not all_valid:
        return False

    print("✅ LoRA configuration test passed!")
    return True

def test_lora_manager():
    """Smoke-test LoRAManager registration, configuration and loading.

    Drives a ``LoRAManager`` against a minimal stand-in pipeline, so no model
    weights are needed. It only checks that each call completes without
    raising; it does not inspect what the manager did.

    Returns:
        True if every step ran without raising, False otherwise (the error
        and its traceback are printed).
    """
    import traceback

    print("\nTesting LoRA manager...")

    try:
        from lora_manager import LoRAManager

        # Minimal stand-in for a DiffusionPipeline: records the last path it
        # was asked to load and prints each call it receives.
        class MockPipeline:
            def __init__(self):
                self.loaded_loras = {}

            def load_lora_weights(self, path, **kwargs):
                # Extra keyword arguments are accepted and ignored so the
                # stand-in does not reject calls that pass options.
                self.loaded_loras['loaded'] = path
                print(f"Mock: Loaded LoRA weights from {path}")

            def fuse_lora(self, **kwargs):
                print("Mock: Fused LoRA")

            def unfuse_lora(self, **kwargs):
                print("Mock: Unfused LoRA")

        mock_pipe = MockPipeline()
        manager = LoRAManager(mock_pipe, "cpu")

        manager.register_lora("test_lora", "/path/to/lora", type="edit")
        print("✅ LoRA registration test passed!")

        manager.configure_lora("test_lora", {"description": "Test LoRA"})
        print("✅ LoRA configuration test passed!")

        manager.load_lora("test_lora")
        print("✅ LoRA loading test passed!")

        return True

    except Exception as e:
        print(f"❌ LoRA manager test failed: {e}")
        # Keep the full stack so failures inside LoRAManager can be located.
        traceback.print_exc()
        return False

def test_ui_functions():
    """Smoke-test the Gradio LoRA-selection change handler.

    Calls ``app.on_lora_change`` for one style-type and one edit-type LoRA and
    checks only that each call completes without raising.

    Returns:
        True if both calls succeed, False otherwise (the error and its
        traceback are printed).
    """
    import traceback

    print("\nTesting UI functions...")

    # (LoRA name passed to the handler, label used in the success message)
    cases = (
        # expected: show style image, hide input image
        ("InStyle (Style Transfer)", "Style"),
        # expected: show input image, hide style image
        ("InScene (In-Scene Editing)", "Edit"),
    )

    try:
        from app import on_lora_change

        for lora_name, label in cases:
            # NOTE(review): the handler's return value is not inspected, so the
            # visibility expectations listed above are not actually asserted.
            on_lora_change(lora_name)
            print(f"✅ {label} LoRA UI change test passed!")

        return True

    except Exception as e:
        print(f"❌ UI function test failed: {e}")
        traceback.print_exc()
        return False

def test_manual_fusion():
    """Check that ``app.fuse_lora_manual`` merges a LoRA into a Linear layer.

    Builds a tiny module with a single ``linear1`` submodule and a rank-2 LoRA
    state dict keyed ``diffusion_model.linear1.lora_{A,B}.weight``. After
    fusion it verifies that ``linear1.weight`` changed.

    Returns:
        True if fusion ran and modified the layer weight, False otherwise
        (the reason, and the traceback for exceptions, are printed).
    """
    import traceback

    print("\nTesting manual LoRA fusion...")

    try:
        import torch
        from app import fuse_lora_manual

        # A genuine module hierarchy: torch's own named_modules() then yields
        # ('linear1', <Linear>) consistently. Overriding named_modules() to
        # return a new Linear on each call (ignoring torch's arguments) would
        # make any fused weights unobservable.
        class MockModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 10)  # weight shape (10, 5)

        mock_transformer = MockModule()
        original_weight = mock_transformer.linear1.weight.detach().clone()

        # Rank-2 LoRA: B (10x2) @ A (2x5) matches linear1.weight's (10x5) shape.
        lora_state_dict = {
            'diffusion_model.linear1.lora_A.weight': torch.randn(2, 5),
            'diffusion_model.linear1.lora_B.weight': torch.randn(10, 2)
        }

        result = fuse_lora_manual(mock_transformer, lora_state_dict)

        # NOTE(review): the return contract of fuse_lora_manual is not visible
        # here; accept either an in-place update or a returned module.
        fused = result if isinstance(result, torch.nn.Module) else mock_transformer
        fused_weight = fused.linear1.weight.detach().to(original_weight)
        if torch.equal(fused_weight, original_weight):
            print("❌ Manual fusion test failed: linear1 weights were not modified")
            return False

        print("✅ Manual LoRA fusion test passed!")
        return True

    except Exception as e:
        print(f"❌ Manual fusion test failed: {e}")
        traceback.print_exc()
        return False

def main(tests=None):
    """Run the validation tests and print a pass/fail summary.

    Args:
        tests: Optional iterable of zero-argument callables that return a
            truthy value on success. Defaults to the four validation tests
            defined in this script.

    Returns:
        True if no test failed, False otherwise.
    """
    import traceback

    print("=" * 50)
    print("Multi-LoRA Implementation Validation")
    print("=" * 50)

    if tests is None:
        tests = [
            test_lora_config,
            test_lora_manager,
            test_ui_functions,
            test_manual_fusion,
        ]

    passed = 0
    failed = 0

    for test in tests:
        try:
            if test():
                passed += 1
            else:
                failed += 1
        except Exception as e:
            # Safety net for tests that raise instead of returning False;
            # getattr keeps this working for callables without __name__.
            name = getattr(test, "__name__", repr(test))
            print(f"❌ {name} failed with exception: {e}")
            traceback.print_exc()
            failed += 1

    print("\n" + "=" * 50)
    print(f"Test Results: {passed} passed, {failed} failed")
    print("=" * 50)

    if failed == 0:
        print("🎉 All tests passed! Multi-LoRA implementation is ready.")
        return True
    else:
        print("⚠️ Some tests failed. Please check the implementation.")
        return False

if __name__ == "__main__":
    # Map the overall result onto a process exit status (0 = all passed,
    # 1 = at least one failure) so shells and CI can detect failures.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)