LPX55's picture
major: load any lora implementation
ad7badd
from typing import Any, Dict, List, Optional

import torch
from diffusers import DiffusionPipeline
class LoRAManager:
    """Registry and lifecycle helper for LoRA adapters on a Diffusers pipeline.

    Keeps track of which LoRA ids are registered, which are currently loaded,
    and any UI configuration attached to them, and forwards load / unload /
    fuse / unfuse requests to the wrapped pipeline.
    """

    def __init__(self, pipeline: "DiffusionPipeline", device: str = "cuda"):
        """
        Manages LoRA adapters for a given Diffusers pipeline.

        Args:
            pipeline (DiffusionPipeline): The Diffusers pipeline to manage LoRAs for.
                It must provide ``load_lora_weights``, ``unload_lora_weights``,
                ``fuse_lora`` and ``unfuse_lora``.
            device (str, optional): The device to load LoRAs onto. Defaults to "cuda".
                NOTE(review): stored on the instance but not currently used by
                any method of this class.
        """
        self.pipeline = pipeline
        self.device = device
        # lora_id -> {**extra metadata, "lora_path": str, "loaded": bool}
        self.lora_registry: Dict[str, Dict[str, Any]] = {}
        # lora_id -> UI configuration dict supplied through configure_lora().
        self.lora_configurations: Dict[str, Dict[str, Any]] = {}
        # Id of the most recently loaded LoRA, or None when nothing is loaded.
        self.current_lora: Optional[str] = None

    def _get_entry(self, lora_id: str) -> Dict[str, Any]:
        """Return the live registry entry for ``lora_id``.

        Raises:
            ValueError: If ``lora_id`` has not been registered.
        """
        try:
            return self.lora_registry[lora_id]
        except KeyError:
            raise ValueError(f"LoRA with id '{lora_id}' not registered.") from None

    def register_lora(self, lora_id: str, lora_path: str, **kwargs: Any) -> None:
        """
        Registers a LoRA adapter to the registry.

        Args:
            lora_id (str): A unique identifier for the LoRA adapter.
            lora_path (str): The path to the LoRA adapter weights.
            **kwargs (Any): Additional keyword arguments to store with the LoRA
                metadata. The reserved keys ``"lora_path"`` and ``"loaded"`` are
                managed by this class and cannot be overridden here.

        Raises:
            ValueError: If ``lora_id`` is already registered.
        """
        if lora_id in self.lora_registry:
            raise ValueError(f"LoRA with id '{lora_id}' already registered.")
        # Spread caller metadata first so it can never clobber the bookkeeping
        # fields: a stray loaded=True would make load_lora() skip the real load.
        self.lora_registry[lora_id] = {
            **kwargs,
            "lora_path": lora_path,
            "loaded": False,
        }

    def configure_lora(self, lora_id: str, ui_config: Dict[str, Any]) -> None:
        """
        Configures the UI elements for a specific LoRA.

        Args:
            lora_id (str): The identifier of the LoRA adapter.
            ui_config (Dict[str, Any]): A dictionary containing the UI configuration
                for the LoRA. Stored by reference, replacing any previous config.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        self._get_entry(lora_id)  # validation only
        self.lora_configurations[lora_id] = ui_config

    def load_lora(self, lora_id: str, load_in_8bit: bool = False) -> None:
        """
        Loads a LoRA adapter into the pipeline.

        Loading an id that is already marked as loaded is a no-op (a message is
        printed and ``current_lora`` is left unchanged).

        Args:
            lora_id (str): The identifier of the LoRA adapter to load.
            load_in_8bit (bool, optional): Whether to load the LoRA in 8-bit mode.
                Defaults to False. NOTE(review): currently ignored; it is not
                forwarded to the pipeline.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        entry = self._get_entry(lora_id)
        if entry["loaded"]:
            print(f"LoRA with id '{lora_id}' already loaded.")
            return
        self.pipeline.load_lora_weights(entry["lora_path"])
        # State is only updated after the pipeline call succeeds.
        entry["loaded"] = True
        self.current_lora = lora_id
        print(f"LoRA with id '{lora_id}' loaded successfully.")

    def unload_lora(self, lora_id: str) -> None:
        """
        Unloads LoRA weights from the pipeline.

        The underlying ``unload_lora_weights()`` call takes no adapter selector
        and removes every LoRA from the pipeline, so after this call *all*
        registered LoRAs are marked as not loaded and ``current_lora`` is reset.

        Args:
            lora_id (str): The identifier of the LoRA adapter to unload. It must
                currently be marked as loaded; otherwise a message is printed and
                nothing happens.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        entry = self._get_entry(lora_id)
        if not entry["loaded"]:
            print(f"LoRA with id '{lora_id}' is not currently loaded.")
            return
        self.pipeline.unload_lora_weights()
        # Keep the registry in sync with the pipeline: every LoRA is gone now,
        # so other ids must also be reloadable via load_lora().
        for registered in self.lora_registry.values():
            registered["loaded"] = False
        self.current_lora = None
        print(f"LoRA with id '{lora_id}' unloaded successfully.")

    def fuse_lora(self, lora_id: str) -> None:
        """
        Fuses LoRA weights into the pipeline.

        No adapter selection is forwarded to ``pipeline.fuse_lora()``, so the
        pipeline's default fusing behaviour applies (presumably every loaded
        adapter, not only ``lora_id`` -- verify against the Diffusers version).

        Args:
            lora_id (str): The identifier of the LoRA adapter to fuse.

        Raises:
            ValueError: If ``lora_id`` is not registered or not loaded.
        """
        entry = self._get_entry(lora_id)
        if not entry["loaded"]:
            raise ValueError(f"LoRA with id '{lora_id}' must be loaded before fusing.")
        self.pipeline.fuse_lora()
        print(f"LoRA with id '{lora_id}' fused successfully.")

    def unfuse_lora(self) -> None:
        """
        Unfuses the weights of the currently fused LoRA adapter.

        Delegates unconditionally to ``pipeline.unfuse_lora()``; this class does
        not track whether anything is fused.
        """
        self.pipeline.unfuse_lora()
        print("LoRA unfused successfully.")

    def get_lora_metadata(self, lora_id: str) -> Dict[str, Any]:
        """
        Retrieves the metadata associated with a LoRA adapter.

        Args:
            lora_id (str): The identifier of the LoRA adapter.

        Returns:
            Dict[str, Any]: The live registry entry (not a copy) for the adapter,
            including ``"lora_path"`` and ``"loaded"``. Mutating it mutates the
            manager's state.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        return self._get_entry(lora_id)

    def list_loras(self) -> List[str]:
        """
        Returns a list of all registered LoRA IDs.

        Returns:
            List[str]: LoRA identifiers in registration order.
        """
        return list(self.lora_registry)

    def get_current_lora(self) -> Optional[str]:
        """
        Returns the ID of the most recently loaded LoRA.

        Returns:
            Optional[str]: The identifier of the current LoRA, or None if no LoRA
            is loaded.
        """
        return self.current_lora

    def get_lora_ui_config(self, lora_id: str) -> Dict[str, Any]:
        """
        Retrieves the UI configuration associated with a LoRA adapter.

        Unlike the other accessors this does not raise for unknown ids.

        Args:
            lora_id (str): The identifier of the LoRA adapter.

        Returns:
            Dict[str, Any]: The UI configuration for the adapter, or an empty dict
            if none was configured (or the id is unknown).
        """
        return self.lora_configurations.get(lora_id, {})