Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python3 | |
| """ | |
| Test script to validate the multi-LoRA implementation | |
| """ | |
| import sys | |
| import os | |
# Make the modules under test (app, lora_manager) importable regardless of the
# working directory the script is launched from.  They are assumed to live next
# to this script, so its own directory is put first on the import path instead
# of a machine-specific absolute path.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def test_lora_config():
    """Validate the structure of every entry in ``app.LORA_CONFIG``.

    Each entry must provide all of the required keys (``repo_id``,
    ``filename``, ``type``, ``method``, ``prompt_template``, ``description``).
    Every incomplete entry is reported with each key it lacks, rather than
    stopping at the first missing key.

    Returns:
        bool: True when every entry is complete; False when any entry lacks a
        required key, or when the configuration cannot be imported or
        inspected.
    """
    print("Testing LoRA configuration system...")

    # Loop-invariant, so built once rather than per entry.
    required_keys = (
        'repo_id',
        'filename',
        'type',
        'method',
        'prompt_template',
        'description',
    )

    try:
        # Imported lazily so that a missing or broken app module is reported
        # as a failed test, the same way the other test functions behave.
        from app import LORA_CONFIG

        all_valid = True
        for lora_name, config in LORA_CONFIG.items():
            missing = [key for key in required_keys if key not in config]
            if missing:
                all_valid = False
                for key in missing:
                    print(f"❌ Missing key '{key}' in {lora_name}")
            else:
                print(f"✅ {lora_name}: Valid configuration")
    except Exception as e:
        print(f"❌ LoRA configuration test failed: {e}")
        return False

    if not all_valid:
        return False

    print("✅ LoRA configuration test passed!")
    return True
def test_lora_manager():
    """Smoke-test ``LoRAManager`` against a stand-in pipeline.

    Registers, configures and loads a LoRA named ``test_lora`` through a
    ``LoRAManager`` built around a minimal mock pipeline on device ``"cpu"``.
    The test passes when none of those calls raises; it does not inspect what
    the manager actually did with the pipeline.

    Returns:
        bool: True if every manager call completed, False if importing
        ``lora_manager`` or any manager call raised.
    """
    print("\nTesting LoRA manager...")
    try:
        from lora_manager import LoRAManager

        class MockPipeline:
            """Stand-in for a diffusion pipeline that only records/logs LoRA calls."""

            def __init__(self):
                self.loaded_loras = {}

            def load_lora_weights(self, path):
                self.loaded_loras['loaded'] = path
                print(f"Mock: Loaded LoRA weights from {path}")

            def fuse_lora(self):
                print("Mock: Fused LoRA")

            def unfuse_lora(self):
                print("Mock: Unfused LoRA")

        mock_pipe = MockPipeline()
        manager = LoRAManager(mock_pipe, "cpu")

        # Registration
        manager.register_lora("test_lora", "/path/to/lora", type="edit")
        print("✅ LoRA registration test passed!")

        # Configuration
        manager.configure_lora("test_lora", {"description": "Test LoRA"})
        print("✅ LoRA configuration test passed!")

        # Loading
        manager.load_lora("test_lora")
        print("✅ LoRA loading test passed!")

        return True
    except Exception as e:
        print(f"❌ LoRA manager test failed: {e}")
        return False
def test_ui_functions():
    """Smoke-test the LoRA-selection UI handler ``app.on_lora_change``.

    Calls the handler once with a style-transfer LoRA name and once with an
    in-scene editing LoRA name.  The test passes when neither call raises; the
    handler's return value is not inspected.

    Returns:
        bool: True if both handler calls completed, False if importing
        ``app`` or either call raised.
    """
    print("\nTesting UI functions...")

    # (selection passed to the handler, label used in the success message)
    # NOTE(review): the intent is that the style LoRA shows the style image
    # box and hides the input image box, and the edit LoRA does the reverse;
    # that visibility is not asserted here -- TODO verify against the handler.
    cases = (
        ("InStyle (Style Transfer)", "Style"),
        ("InScene (In-Scene Editing)", "Edit"),
    )

    try:
        from app import on_lora_change

        for selection, label in cases:
            on_lora_change(selection)
            print(f"✅ {label} LoRA UI change test passed!")
        return True
    except Exception as e:
        print(f"❌ UI function test failed: {e}")
        return False
def test_manual_fusion():
    """Smoke-test ``app.fuse_lora_manual`` on a tiny mock transformer.

    Builds a mock module exposing one ``linear1`` layer (5 inputs, 10 outputs)
    and a rank-2 LoRA state dict keyed under ``diffusion_model.linear1``, then
    calls ``fuse_lora_manual(mock_transformer, lora_state_dict)``.  The test
    passes when the call does not raise; the fused weights are not inspected.

    Returns:
        bool: True if the fusion call completed, False if importing ``torch``
        or ``app``, or the fusion call itself, raised.
    """
    print("\nTesting manual LoRA fusion...")
    try:
        import torch
        from app import fuse_lora_manual

        class MockModule(torch.nn.Module):
            """Minimal transformer stand-in exposing a single ``linear1`` layer."""

            def __init__(self):
                super().__init__()
                self.weight = torch.randn(10, 5)
                # One persistent layer: every named_modules() call must hand
                # back the same object, otherwise whatever the fusion writes
                # into it would land on a throwaway module.
                self.linear1 = torch.nn.Linear(5, 10)

            def named_modules(self, *args, **kwargs):
                # Accepts (and ignores) the arguments of the base-class
                # signature so nn.Module internals can still call it.
                return [('linear1', self.linear1)]

        mock_transformer = MockModule()
        lora_state_dict = {
            'diffusion_model.linear1.lora_A.weight': torch.randn(2, 5),
            'diffusion_model.linear1.lora_B.weight': torch.randn(10, 2),
        }

        fuse_lora_manual(mock_transformer, lora_state_dict)
        print("✅ Manual LoRA fusion test passed!")
        return True
    except Exception as e:
        print(f"❌ Manual fusion test failed: {e}")
        return False
def main():
    """Run every validation test and print a summary.

    Each test function is called in turn.  A truthy return value counts as a
    pass; a falsy return value or a raised exception counts as a failure.

    Returns:
        bool: True when no test failed, False otherwise.
    """
    banner = "=" * 50
    print(banner)
    print("Multi-LoRA Implementation Validation")
    print(banner)

    tests = (
        test_lora_config,
        test_lora_manager,
        test_ui_functions,
        test_manual_fusion,
    )

    passed = 0
    failed = 0
    for test in tests:
        try:
            succeeded = bool(test())
        except Exception as e:
            # A test that raises is counted as failed rather than aborting
            # the remaining tests.
            print(f"❌ {test.__name__} failed with exception: {e}")
            succeeded = False
        if succeeded:
            passed += 1
        else:
            failed += 1

    print("\n" + banner)
    print(f"Test Results: {passed} passed, {failed} failed")
    print(banner)

    if failed:
        print("⚠️ Some tests failed. Please check the implementation.")
        return False

    print("🎉 All tests passed! Multi-LoRA implementation is ready.")
    return True
if __name__ == "__main__":
    # The process exit status mirrors the overall result:
    # 0 when main() reports success, 1 otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)