from typing import Any, Dict, List, Optional

import torch
from diffusers import DiffusionPipeline


class LoRAManager:
    def __init__(self, pipeline: DiffusionPipeline, device: str = "cuda"):
        """
        Manages LoRA adapters for a given Diffusers pipeline.

        Args:
            pipeline (DiffusionPipeline): The Diffusers pipeline to manage LoRAs for.
            device (str, optional): The device to load LoRAs onto. Defaults to "cuda".
        """
        self.pipeline = pipeline
        self.device = device
        self.lora_registry: Dict[str, Dict[str, Any]] = {}
        self.lora_configurations: Dict[str, Dict[str, Any]] = {}
        self.current_lora: Optional[str] = None
    def register_lora(self, lora_id: str, lora_path: str, **kwargs: Any) -> None:
        """
        Adds a LoRA adapter to the registry.

        Args:
            lora_id (str): A unique identifier for the LoRA adapter.
            lora_path (str): The path to the LoRA adapter weights.
            **kwargs (Any): Additional keyword arguments to store with the LoRA metadata.
        """
        if lora_id in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' already registered.")
        self.lora_registry[lora_id] = {
            "lora_path": lora_path,
            "loaded": False,
            **kwargs,
        }
    def configure_lora(self, lora_id: str, ui_config: Dict[str, Any]) -> None:
        """
        Configures the UI elements for a specific LoRA.

        Args:
            lora_id (str): The identifier of the LoRA adapter.
            ui_config (Dict[str, Any]): A dictionary containing the UI configuration for the LoRA.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        self.lora_configurations[lora_id] = ui_config
    def load_lora(self, lora_id: str, load_in_8bit: bool = False) -> None:
        """
        Loads a LoRA adapter into the pipeline.

        Args:
            lora_id (str): The identifier of the LoRA adapter to load.
            load_in_8bit (bool, optional): Whether to load the LoRA in 8-bit mode.
                Currently unused by this method. Defaults to False.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        if self.lora_registry[lora_id]["loaded"]:
            print(f"LoRA with id '{lora_id}' already loaded.")
            return
        lora_path = self.lora_registry[lora_id]["lora_path"]
        self.pipeline.load_lora_weights(lora_path)
        self.lora_registry[lora_id]["loaded"] = True
        self.current_lora = lora_id
        print(f"LoRA with id '{lora_id}' loaded successfully.")
    def unload_lora(self, lora_id: str) -> None:
        """
        Unloads a LoRA adapter from the pipeline.

        Args:
            lora_id (str): The identifier of the LoRA adapter to unload.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        if not self.lora_registry[lora_id]["loaded"]:
            print(f"LoRA with id '{lora_id}' is not currently loaded.")
            return
        # Unloading depends on how LoRA is integrated into the pipeline.
        # Here the pipeline's built-in LoRA support is used; with a PEFT-based
        # integration, an alternative would be self.pipeline.disable_adapters().
        self.pipeline.unload_lora_weights()
        self.lora_registry[lora_id]["loaded"] = False
        if self.current_lora == lora_id:
            self.current_lora = None
        print(f"LoRA with id '{lora_id}' unloaded successfully.")
    def fuse_lora(self, lora_id: str) -> None:
        """
        Fuses the weights of a LoRA adapter into the pipeline.

        Args:
            lora_id (str): The identifier of the LoRA adapter to fuse.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        if not self.lora_registry[lora_id]["loaded"]:
            raise ValueError(f"LoRA with id '{lora_id}' must be loaded before fusing.")
        self.pipeline.fuse_lora()
        print(f"LoRA with id '{lora_id}' fused successfully.")
    def unfuse_lora(self) -> None:
        """
        Unfuses the weights of the currently fused LoRA adapter.
        """
        self.pipeline.unfuse_lora()
        print("LoRA unfused successfully.")
    def get_lora_metadata(self, lora_id: str) -> Dict[str, Any]:
        """
        Retrieves the metadata associated with a LoRA adapter.

        Args:
            lora_id (str): The identifier of the LoRA adapter.

        Returns:
            Dict[str, Any]: A dictionary containing the metadata for the LoRA adapter.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        return self.lora_registry[lora_id]
    def list_loras(self) -> List[str]:
        """
        Returns a list of all registered LoRA IDs.

        Returns:
            List[str]: A list of LoRA identifiers.
        """
        return list(self.lora_registry.keys())
    def get_current_lora(self) -> Optional[str]:
        """
        Returns the ID of the currently active LoRA.

        Returns:
            Optional[str]: The identifier of the currently active LoRA, or None if no LoRA is loaded.
        """
        return self.current_lora
    def get_lora_ui_config(self, lora_id: str) -> Dict[str, Any]:
        """
        Retrieves the UI configuration associated with a LoRA adapter.

        Args:
            lora_id (str): The identifier of the LoRA adapter.

        Returns:
            Dict[str, Any]: A dictionary containing the UI configuration for the LoRA adapter.
        """
        return self.lora_configurations.get(lora_id, {})
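

# Minimal usage sketch (illustrative only): wires the manager to a Diffusers
# pipeline and walks through register -> configure -> load -> fuse -> unfuse
# -> unload. The model ID and the local LoRA path
# ("loras/pixel-art.safetensors") below are placeholder assumptions, not part
# of this Space.
if __name__ == "__main__":
    # Load a base pipeline; any LoRA-capable DiffusionPipeline works here.
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    manager = LoRAManager(pipe, device="cuda")

    # Register a LoRA and attach arbitrary metadata via **kwargs.
    manager.register_lora(
        "pixel-art",
        "loras/pixel-art.safetensors",
        trigger_word="pixel art",
    )

    # Store a UI configuration (e.g., for a Gradio slider) alongside the LoRA.
    manager.configure_lora(
        "pixel-art", {"weight_slider": {"min": 0.0, "max": 1.0, "default": 0.8}}
    )

    manager.load_lora("pixel-art")
    print(manager.list_loras())            # ['pixel-art']
    print(manager.get_current_lora())      # 'pixel-art'
    print(manager.get_lora_metadata("pixel-art"))

    # Optionally fuse for faster inference, then unfuse and unload.
    manager.fuse_lora("pixel-art")
    manager.unfuse_lora()
    manager.unload_lora("pixel-art")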