# Qwen-Image-Edit-2509-Turbo-Lightning / test_lora_implementation.py
# Author: LPX55
# Commit message: major: load any lora implementation
# Commit: ad7badd
#!/usr/bin/env python3
"""
Test script to validate the multi-LoRA implementation
"""
import sys
import os
# Make the modules that live next to this script (``app``, ``lora_manager``)
# importable regardless of where the checkout lives or which directory the
# script is launched from. Inserted at position 0 so the local modules win
# over any same-named installed packages.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)
def test_lora_config():
    """Validate the structure of ``app.LORA_CONFIG``.

    Every entry must define all of the keys in ``required_keys``. All entries
    are checked, so a single run reports every missing key. It does not stop
    at the first problem.

    Returns:
        True if ``LORA_CONFIG`` was imported and every entry has all required
        keys, False otherwise (including when ``app`` cannot be imported).
    """
    print("Testing LoRA configuration system...")
    try:
        # Import the configuration from our app
        from app import LORA_CONFIG
    except ImportError as e:
        # Report like the other tests do instead of propagating the error.
        print(f"❌ Could not import LORA_CONFIG from app: {e}")
        return False

    required_keys = ['repo_id', 'filename', 'type', 'method', 'prompt_template', 'description']
    all_valid = True
    for lora_name, config in LORA_CONFIG.items():
        missing = [key for key in required_keys if key not in config]
        if missing:
            for key in missing:
                print(f"❌ Missing key '{key}' in {lora_name}")
            all_valid = False
        else:
            print(f"✅ {lora_name}: Valid configuration")

    if not all_valid:
        return False
    print("✅ LoRA configuration test passed!")
    return True
def test_lora_manager():
    """Exercise ``lora_manager.LoRAManager`` against a stub pipeline.

    Registers, configures and loads a dummy LoRA through the manager. The
    stub records the path passed to ``load_lora_weights``. The test only
    checks that none of the manager calls raise.

    Returns:
        True if every manager call completes without raising, False otherwise
        (including when ``lora_manager`` cannot be imported).
    """
    print("\nTesting LoRA manager...")
    try:
        from lora_manager import LoRAManager

        # Mock DiffusionPipeline class for testing. The methods accept extra
        # positional/keyword arguments because the real diffusers methods do
        # (e.g. adapter_name, lora_scale). A stricter stub would raise
        # TypeError if the manager passes any of them.
        class MockPipeline:
            def __init__(self):
                self.loaded_loras = {}

            def load_lora_weights(self, path, *args, **kwargs):
                self.loaded_loras['loaded'] = path
                print(f"Mock: Loaded LoRA weights from {path}")

            def fuse_lora(self, *args, **kwargs):
                print("Mock: Fused LoRA")

            def unfuse_lora(self, *args, **kwargs):
                print("Mock: Unfused LoRA")

        # Create mock pipeline and manager
        mock_pipe = MockPipeline()
        manager = LoRAManager(mock_pipe, "cpu")

        # Test registration
        manager.register_lora("test_lora", "/path/to/lora", type="edit")
        print("✅ LoRA registration test passed!")

        # Test configuration
        manager.configure_lora("test_lora", {"description": "Test LoRA"})
        print("✅ LoRA configuration test passed!")

        # Test loading
        manager.load_lora("test_lora")
        print("✅ LoRA loading test passed!")
        return True
    except Exception as e:
        print(f"❌ LoRA manager test failed: {e}")
        return False
def test_ui_functions():
    """Smoke-test the UI change handler ``app.on_lora_change``.

    Calls the handler for one style-transfer LoRA and one editing LoRA. The
    test passes if neither call raises; the returned values are not
    inspected.

    Returns:
        True if both calls complete without raising, False otherwise
        (including when ``app`` cannot be imported).
    """
    print("\nTesting UI functions...")
    try:
        # Import the UI change handler. LORA_CONFIG is imported too, so a
        # missing config also fails this test.
        from app import on_lora_change, LORA_CONFIG  # noqa: F401

        # Style LoRA (expected to show style_image and hide input_image).
        # NOTE(review): the returned updates are not checked. Add assertions
        # once the return shape of on_lora_change is pinned down.
        on_lora_change("InStyle (Style Transfer)")
        print("✅ Style LoRA UI change test passed!")

        # Edit LoRA (expected to show input_image and hide style_image).
        on_lora_change("InScene (In-Scene Editing)")
        print("✅ Edit LoRA UI change test passed!")
        return True
    except Exception as e:
        print(f"❌ UI function test failed: {e}")
        return False
def test_manual_fusion():
    """Smoke-test ``app.fuse_lora_manual`` on a tiny stand-in transformer.

    Builds a module that exposes a single ``linear1`` layer,
    ``torch.nn.Linear(5, 10)``, through ``named_modules``. It also builds a
    LoRA state dict with ``lora_A`` (2x5) and ``lora_B`` (10x2) weights
    keyed under ``diffusion_model.linear1``, then calls ``fuse_lora_manual``.
    The test passes if the call does not raise; the fused weights are not
    checked.

    Returns:
        True if fusion completes without raising, False otherwise (including
        when ``torch`` or ``app`` cannot be imported).
    """
    print("\nTesting manual LoRA fusion...")
    try:
        import torch
        from app import fuse_lora_manual

        # Create a mock transformer for testing
        class MockModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(10, 5)
                # Build the layer once, so every named_modules() call returns
                # the same object. A fresh Linear per call would silently
                # discard whatever fuse_lora_manual writes into it.
                self.linear1 = torch.nn.Linear(5, 10)

            def named_modules(self, *args, **kwargs):
                # Accept nn.Module.named_modules' optional arguments (memo,
                # prefix, remove_duplicate). torch's own helpers such as
                # named_parameters()/parameters() pass them as keywords.
                # Only the target layer is exposed.
                return [('linear1', self.linear1)]

        # Create test data
        mock_transformer = MockModule()
        lora_state_dict = {
            'diffusion_model.linear1.lora_A.weight': torch.randn(2, 5),
            'diffusion_model.linear1.lora_B.weight': torch.randn(10, 2),
        }

        # Test fusion
        fuse_lora_manual(mock_transformer, lora_state_dict)
        print("✅ Manual LoRA fusion test passed!")
        return True
    except Exception as e:
        print(f"❌ Manual fusion test failed: {e}")
        return False
def main():
    """Run every validation test and print a pass/fail summary.

    Each test function returns True on success and False on failure. A test
    that raises is counted as failed; it does not abort the run.

    Returns:
        True if all tests passed, False otherwise.
    """
    print("=" * 50)
    print("Multi-LoRA Implementation Validation")
    print("=" * 50)

    tests = [
        test_lora_config,
        test_lora_manager,
        test_ui_functions,
        test_manual_fusion,
    ]

    passed = 0
    failed = 0
    for test in tests:
        try:
            if test():
                passed += 1
            else:
                failed += 1
        except Exception as e:
            print(f"❌ {test.__name__} failed with exception: {e}")
            failed += 1

    print("\n" + "=" * 50)
    print(f"Test Results: {passed} passed, {failed} failed")
    print("=" * 50)

    if failed == 0:
        print("🎉 All tests passed! Multi-LoRA implementation is ready.")
        return True
    print("⚠️ Some tests failed. Please check the implementation.")
    return False
if __name__ == "__main__":
    # Translate main()'s boolean outcome into a process exit status:
    # 0 when every test passed, 1 otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)