Spaces: Running on Zero

The listing below defines a LoRAManager class that registers, configures, loads, fuses, and unloads LoRA adapters for a Diffusers pipeline.
from typing import Any, Dict, List, Optional

import torch
from diffusers import DiffusionPipeline


class LoRAManager:
    def __init__(self, pipeline: DiffusionPipeline, device: str = "cuda"):
        """
        Manages LoRA adapters for a given Diffusers pipeline.

        Args:
            pipeline (DiffusionPipeline): The Diffusers pipeline to manage LoRAs for.
            device (str, optional): The device to load LoRAs onto. Defaults to "cuda".
        """
        self.pipeline = pipeline
        self.device = device
        self.lora_registry: Dict[str, Dict[str, Any]] = {}
        self.lora_configurations: Dict[str, Dict[str, Any]] = {}
        self.current_lora: Optional[str] = None

    def register_lora(self, lora_id: str, lora_path: str, **kwargs: Any) -> None:
        """
        Registers a LoRA adapter in the registry.

        Args:
            lora_id (str): A unique identifier for the LoRA adapter.
            lora_path (str): The path to the LoRA adapter weights.
            **kwargs (Any): Additional keyword arguments to store with the LoRA metadata.
        """
        if lora_id in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' already registered.")
        self.lora_registry[lora_id] = {
            "lora_path": lora_path,
            "loaded": False,
            **kwargs,
        }

    def configure_lora(self, lora_id: str, ui_config: Dict[str, Any]) -> None:
        """
        Configures the UI elements for a specific LoRA.

        Args:
            lora_id (str): The identifier of the LoRA adapter.
            ui_config (Dict[str, Any]): A dictionary containing the UI configuration for the LoRA.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        self.lora_configurations[lora_id] = ui_config

    def load_lora(self, lora_id: str, load_in_8bit: bool = False) -> None:
        """
        Loads a LoRA adapter into the pipeline.

        Args:
            lora_id (str): The identifier of the LoRA adapter to load.
            load_in_8bit (bool, optional): Whether to load the LoRA in 8-bit mode.
                Currently unused; reserved for future quantized loading. Defaults to False.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        if self.lora_registry[lora_id]["loaded"]:
            print(f"LoRA with id '{lora_id}' already loaded.")
            return
        lora_path = self.lora_registry[lora_id]["lora_path"]
        self.pipeline.load_lora_weights(lora_path)
        self.lora_registry[lora_id]["loaded"] = True
        self.current_lora = lora_id
        print(f"LoRA with id '{lora_id}' loaded successfully.")

    def unload_lora(self, lora_id: str) -> None:
        """
        Unloads a LoRA adapter from the pipeline.

        Args:
            lora_id (str): The identifier of the LoRA adapter to unload.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        if not self.lora_registry[lora_id]["loaded"]:
            print(f"LoRA with id '{lora_id}' is not currently loaded.")
            return
        # If the LoRA is integrated via PEFT adapters, it could instead be disabled
        # with self.pipeline.disable_adapters(); here the weights are removed entirely.
        self.pipeline.unload_lora_weights()
        self.lora_registry[lora_id]["loaded"] = False
        if self.current_lora == lora_id:
            self.current_lora = None
        print(f"LoRA with id '{lora_id}' unloaded successfully.")

    def fuse_lora(self, lora_id: str) -> None:
        """
        Fuses the weights of a LoRA adapter into the pipeline.

        Args:
            lora_id (str): The identifier of the LoRA adapter to fuse.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        if not self.lora_registry[lora_id]["loaded"]:
            raise ValueError(f"LoRA with id '{lora_id}' must be loaded before fusing.")
        self.pipeline.fuse_lora()
        print(f"LoRA with id '{lora_id}' fused successfully.")

    def unfuse_lora(self) -> None:
        """
        Unfuses the weights of the currently fused LoRA adapter.
        """
        self.pipeline.unfuse_lora()
        print("LoRA unfused successfully.")

    def get_lora_metadata(self, lora_id: str) -> Dict[str, Any]:
        """
        Retrieves the metadata associated with a LoRA adapter.

        Args:
            lora_id (str): The identifier of the LoRA adapter.

        Returns:
            Dict[str, Any]: A dictionary containing the metadata for the LoRA adapter.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")
        return self.lora_registry[lora_id]

    def list_loras(self) -> List[str]:
        """
        Returns a list of all registered LoRA IDs.

        Returns:
            List[str]: A list of LoRA identifiers.
        """
        return list(self.lora_registry.keys())

    def get_current_lora(self) -> Optional[str]:
        """
        Returns the ID of the currently active LoRA.

        Returns:
            Optional[str]: The identifier of the currently active LoRA, or None if no LoRA is loaded.
        """
        return self.current_lora

    def get_lora_ui_config(self, lora_id: str) -> Dict[str, Any]:
        """
        Retrieves the UI configuration associated with a LoRA adapter.

        Args:
            lora_id (str): The identifier of the LoRA adapter.

        Returns:
            Dict[str, Any]: A dictionary containing the UI configuration for the LoRA adapter.
        """
        return self.lora_configurations.get(lora_id, {})
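For reference, a minimal usage sketch follows. It assumes an SDXL base pipeline and a hypothetical LoRA repository id ("your-username/example-lora" is a placeholder, not a real checkpoint); since the Space runs on ZeroGPU, the GPU-bound call is wrapped with the @spaces.GPU decorator from the spaces package, which is only available inside a Space.

```python
import spaces  # available on Hugging Face Spaces; provides the ZeroGPU decorator
import torch
from diffusers import DiffusionPipeline

# Assumption: SDXL base model; any Diffusers pipeline with LoRA support would work.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

manager = LoRAManager(pipe)
# Hypothetical LoRA repo id and metadata, for illustration only.
manager.register_lora("pixel-art", "your-username/example-lora", trigger_word="pixel style")
manager.configure_lora("pixel-art", {"label": "Pixel Art", "default_scale": 0.8})


@spaces.GPU  # on ZeroGPU, GPU work runs inside a decorated function
def generate(prompt: str):
    manager.load_lora("pixel-art")
    return pipe(prompt).images[0]


image = generate("a pixel style castle at sunset")
```

The manager itself only tracks registration state and delegates to the pipeline's load_lora_weights / unload_lora_weights / fuse_lora methods, so any UI layer can query list_loras() and get_lora_ui_config() to build its controls.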