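from typing import Any, Dict, List, Optional
import torch
from diffusers import DiffusionPipeline


class LoRAManager:
    def __init__(self, pipeline: DiffusionPipeline, device: str = "cuda"):
        """
        Manages LoRA adapters for a given Diffusers pipeline.

        Args:
            pipeline (DiffusionPipeline): The Diffusers pipeline to manage LoRAs for.
            device (str, optional): The target device name. Stored on the manager only;
                no method in this class moves weights to it. Defaults to "cuda".
        """
        self.pipeline = pipeline
        self.device = device
        # Maps lora_id -> metadata ("lora_path", "loaded", plus any extra kwargs).
        self.lora_registry: Dict[str, Dict[str, Any]] = {}
        # Maps lora_id -> UI configuration set via configure_lora().
        self.lora_configurations: Dict[str, Dict[str, Any]] = {}
        # ID of the most recently loaded LoRA, or None when nothing is loaded.
        self.current_lora: Optional[str] = None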

    def register_lora(self, lora_id: str, lora_path: str, **kwargs: Any) -> None:
        """
        Registers a LoRA adapter to the registry.

        Args:
            lora_id (str): A unique identifier for the LoRA adapter.
            lora_path (str): The path to the LoRA adapter weights.
            **kwargs (Any): Additional keyword arguments to store with the LoRA metadata.
        """
        if lora_id in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' already registered.")

        self.lora_registry[lora_id] = {
            "lora_path": lora_path,
            "loaded": False,
            **kwargs,
        }

    def configure_lora(self, lora_id: str, ui_config: Dict[str, Any]) -> None:
        """
        Configures the UI elements for a specific LoRA.

        Args:
            lora_id (str): The identifier of the LoRA adapter.
            ui_config (Dict[str, Any]): A dictionary containing the UI configuration for the LoRA.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        self.lora_configurations[lora_id] = ui_config

    def load_lora(self, lora_id: str, load_in_8bit: bool = False) -> None:
        """
        Loads a LoRA adapter into the pipeline.

        Args:
            lora_id (str): The identifier of the LoRA adapter to load.
            load_in_8bit (bool, optional): Currently ignored; accepted but has no effect on
                how the adapter is loaded. Defaults to False.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")

        if self.lora_registry[lora_id]["loaded"]:
            print(f"LoRA with id '{lora_id}' already loaded.")
            return

        lora_path = self.lora_registry[lora_id]["lora_path"]
        self.pipeline.load_lora_weights(lora_path)
        self.lora_registry[lora_id]["loaded"] = True
        self.current_lora = lora_id
        print(f"LoRA with id '{lora_id}' loaded successfully.")

    def unload_lora(self, lora_id: str) -> None:
        """
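        Unloads a LoRA adapter from the pipeline.

        Note: the underlying ``unload_lora_weights()`` call removes all LoRA weights
        loaded in the pipeline, not only ``lora_id``.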

        Args:
            lora_id (str): The identifier of the LoRA adapter to unload.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")

        if not self.lora_registry[lora_id]["loaded"]:
            print(f"LoRA with id '{lora_id}' is not currently loaded.")
            return

        # unload_lora_weights() removes every LoRA currently loaded in the
        # pipeline, not just `lora_id`, so mark all registered adapters as unloaded.
        self.pipeline.unload_lora_weights()
        for entry in self.lora_registry.values():
            entry["loaded"] = False
        self.current_lora = None
        print(f"LoRA with id '{lora_id}' unloaded successfully.")

    def fuse_lora(self, lora_id: str) -> None:
        """
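        Fuses the weights of a LoRA adapter into the pipeline.

        Note: ``fuse_lora()`` is called without arguments, so it acts on whatever LoRA
        weights the pipeline currently has loaded, not only ``lora_id``.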

        Args:
            lora_id (str): The identifier of the LoRA adapter to fuse.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")

        if not self.lora_registry[lora_id]["loaded"]:
            raise ValueError(f"LoRA with id '{lora_id}' must be loaded before fusing.")

        self.pipeline.fuse_lora()
        print(f"LoRA with id '{lora_id}' fused successfully.")

    def unfuse_lora(self) -> None:
        """
        Unfuses the weights of the currently fused LoRA adapter.
        """
        self.pipeline.unfuse_lora()
        print("LoRA unfused successfully.")

    def get_lora_metadata(self, lora_id: str) -> Dict[str, Any]:
        """
        Retrieves the metadata associated with a LoRA adapter.

        Args:
            lora_id (str): The identifier of the LoRA adapter.

        Returns:
            Dict[str, Any]: A dictionary containing the metadata for the LoRA adapter.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")

        return self.lora_registry[lora_id]

    def list_loras(self) -> List[str]:
        """
        Returns a list of all registered LoRA IDs.

        Returns:
            List[str]: A list of LoRA identifiers.
        """
        return list(self.lora_registry.keys())

    def get_current_lora(self) -> str:
        """
        Returns the ID of the currently active LoRA.

        Returns:
            str: The identifier of the currently active LoRA, or None if no LoRA is loaded.
        """
        return self.current_lora

    def get_lora_ui_config(self, lora_id: str) -> Dict[str, Any]:
        """
        Retrieves the UI configuration associated with a LoRA adapter.

        Args:
            lora_id (str): The identifier of the LoRA adapter.

        Returns:
            Dict[str, Any]: A dictionary containing the UI configuration for the LoRA adapter.
        """
        return self.lora_configurations.get(lora_id, {})