from typing import Any, Dict, List, Optional

import torch
from diffusers import DiffusionPipeline


class LoRAManager:
    """Registry and load/unload bookkeeping for LoRA adapters on one pipeline.

    The manager keeps two dictionaries keyed by LoRA id: ``lora_registry``
    (path, ``loaded`` flag and any caller-supplied metadata) and
    ``lora_configurations`` (opaque UI configuration). Loading, unloading,
    fusing and unfusing are delegated to the wrapped pipeline's own
    ``load_lora_weights`` / ``unload_lora_weights`` / ``fuse_lora`` /
    ``unfuse_lora`` methods.
    """

    def __init__(self, pipeline: DiffusionPipeline, device: str = "cuda"):
        """
        Manages LoRA adapters for a given Diffusers pipeline.

        Args:
            pipeline (DiffusionPipeline): The Diffusers pipeline to manage LoRAs for.
            device (str, optional): The device to load LoRAs onto. Defaults to "cuda".
                The value is only stored on the instance; no method of this
                class currently reads it.
        """
        self.pipeline = pipeline
        self.device = device
        # lora_id -> {"lora_path": ..., "loaded": bool, **user metadata}
        self.lora_registry: Dict[str, Dict[str, Any]] = {}
        # lora_id -> UI configuration dictionary supplied via configure_lora()
        self.lora_configurations: Dict[str, Dict[str, Any]] = {}
        # Id of the most recently loaded LoRA, or None when nothing is loaded.
        self.current_lora: Optional[str] = None

    def register_lora(self, lora_id: str, lora_path: str, **kwargs: Any) -> None:
        """
        Registers a LoRA adapter to the registry.

        Args:
            lora_id (str): A unique identifier for the LoRA adapter.
            lora_path (str): The path to the LoRA adapter weights.
            **kwargs (Any): Additional keyword arguments to store with the LoRA metadata.
                The keys "lora_path" and "loaded" are owned by the manager and
                cannot be overridden through ``kwargs``.

        Raises:
            ValueError: If ``lora_id`` is already registered.
        """
        if lora_id in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' already registered.")

        # Manager-owned keys are written last so that a stray ``loaded=True``
        # in kwargs cannot make load_lora() believe the weights are in place.
        self.lora_registry[lora_id] = {
            **kwargs,
            "lora_path": lora_path,
            "loaded": False,
        }

    def configure_lora(self, lora_id: str, ui_config: Dict[str, Any]) -> None:
        """
        Configures the UI elements for a specific LoRA.

        Args:
            lora_id (str): The identifier of the LoRA adapter.
            ui_config (Dict[str, Any]): A dictionary containing the UI configuration
                for the LoRA. It is stored as-is and replaces any earlier
                configuration for the same id.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")

        self.lora_configurations[lora_id] = ui_config

    def load_lora(self, lora_id: str, load_in_8bit: bool = False) -> None:
        """
        Loads a LoRA adapter into the pipeline.

        If the adapter is already marked as loaded, a message is printed and
        nothing else happens. On success the adapter becomes ``current_lora``.

        Args:
            lora_id (str): The identifier of the LoRA adapter to load.
            load_in_8bit (bool, optional): Whether to load the LoRA in 8-bit mode.
                Defaults to False. Accepted for interface compatibility; this
                method does not currently read it.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")

        entry = self.lora_registry[lora_id]
        if entry["loaded"]:
            print(f"LoRA with id '{lora_id}' already loaded.")
            return

        # State is only updated after the pipeline call returns, so a failed
        # load leaves the registry untouched.
        self.pipeline.load_lora_weights(entry["lora_path"])
        entry["loaded"] = True
        self.current_lora = lora_id
        print(f"LoRA with id '{lora_id}' loaded successfully.")

    def unload_lora(self, lora_id: str) -> None:
        """
        Unloads a LoRA adapter from the pipeline.

        The pipeline is asked to unload its LoRA weights without naming an
        adapter, so every LoRA that was loaded through this manager is removed
        together with ``lora_id``. All registry entries are therefore marked
        as not loaded and ``current_lora`` is cleared.

        Args:
            lora_id (str): The identifier of the LoRA adapter to unload.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")

        if not self.lora_registry[lora_id]["loaded"]:
            print(f"LoRA with id '{lora_id}' is not currently loaded.")
            return

        self.pipeline.unload_lora_weights()

        # unload_lora_weights() is not adapter-specific: keep the bookkeeping
        # in step with the pipeline, otherwise a later load_lora() on another
        # id would be skipped as "already loaded" although its weights are gone.
        for entry in self.lora_registry.values():
            entry["loaded"] = False
        self.current_lora = None
        print(f"LoRA with id '{lora_id}' unloaded successfully.")

    def fuse_lora(self, lora_id: str) -> None:
        """
        Fuses the weights of a LoRA adapter into the pipeline.

        The call is forwarded to ``pipeline.fuse_lora()`` without arguments;
        ``lora_id`` is used only for validation and for the printed message.

        Args:
            lora_id (str): The identifier of the LoRA adapter to fuse.

        Raises:
            ValueError: If ``lora_id`` is not registered or not loaded.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")

        if not self.lora_registry[lora_id]["loaded"]:
            raise ValueError(f"LoRA with id '{lora_id}' must be loaded before fusing.")

        self.pipeline.fuse_lora()
        print(f"LoRA with id '{lora_id}' fused successfully.")

    def unfuse_lora(self) -> None:
        """
        Unfuses the weights of the currently fused LoRA adapter.

        The manager does not track fusion state; the request is forwarded to
        ``pipeline.unfuse_lora()`` unconditionally.
        """
        self.pipeline.unfuse_lora()
        print("LoRA unfused successfully.")

    def get_lora_metadata(self, lora_id: str) -> Dict[str, Any]:
        """
        Retrieves the metadata associated with a LoRA adapter.

        Args:
            lora_id (str): The identifier of the LoRA adapter.

        Returns:
            Dict[str, Any]: A dictionary containing the metadata for the LoRA adapter.
                This is the live registry entry, not a copy, so mutating it
                changes the manager's state.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        if lora_id not in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.")

        return self.lora_registry[lora_id]

    def list_loras(self) -> List[str]:
        """
        Returns a list of all registered LoRA IDs.

        Returns:
            List[str]: A list of LoRA identifiers, in registration order.
        """
        return list(self.lora_registry)

    def get_current_lora(self) -> Optional[str]:
        """
        Returns the ID of the currently active LoRA.

        Returns:
            Optional[str]: The identifier of the most recently loaded LoRA,
                or None if no LoRA is loaded.
        """
        return self.current_lora

    def get_lora_ui_config(self, lora_id: str) -> Dict[str, Any]:
        """
        Retrieves the UI configuration associated with a LoRA adapter.

        Unlike the other accessors, this method does not raise for an unknown
        id; it returns an empty dictionary instead.

        Args:
            lora_id (str): The identifier of the LoRA adapter.

        Returns:
            Dict[str, Any]: A dictionary containing the UI configuration for the
                LoRA adapter, or an empty dictionary if none was configured.
        """
        return self.lora_configurations.get(lora_id, {})