diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 2787074265..39d2d3e08f 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -60,6 +60,7 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt +from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager @@ -833,6 +834,16 @@ class DenoiseLatentsInvocation(BaseInvocation): if self.unet.freeu_config: ext_manager.add_extension(FreeUExt(self.unet.freeu_config)) + ### lora + if self.unet.loras: + ext_manager.add_extension( + LoRAPatcherExt( + node_context=context, + loras=self.unet.loras, + prefix="lora_unet_", + ) + ) + # context for loading additional models with ExitStack() as exit_stack: # later should be smth like: diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index 8ef81915f1..21b99d7f6c 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -49,6 +49,9 @@ class LoRALayerBase: def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: raise NotImplementedError() + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + def calc_size(self) -> int: model_size = 0 for val in [self.bias]: @@ -93,6 +96,9 @@ class LoRALayer(LoRALayerBase): return weight + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + return {"weight": self.get_weight(orig_module.weight)} + def calc_size(self) -> int: model_size = super().calc_size() for val in [self.up, self.mid, self.down]: @@ -149,6 +155,9 @@ class LoHALayer(LoRALayerBase): return weight + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + return {"weight": self.get_weight(orig_module.weight)} + def calc_size(self) -> int: model_size = super().calc_size() for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: @@ -241,6 +250,9 @@ class LoKRLayer(LoRALayerBase): return weight + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + return {"weight": self.get_weight(orig_module.weight)} + def calc_size(self) -> int: model_size = super().calc_size() for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: @@ -293,6 +305,9 @@ class FullLayer(LoRALayerBase): def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: return self.weight + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + return {"weight": self.get_weight(orig_module.weight)} + def calc_size(self) -> int: model_size = super().calc_size() model_size += self.weight.nelement() * self.weight.element_size() @@ -327,6 +342,9 @@ class IA3Layer(LoRALayerBase): assert orig_weight is not None return orig_weight * weight + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + return {"weight": self.get_weight(orig_module.weight)} + def calc_size(self) -> int: model_size = super().calc_size() model_size += self.weight.nelement() * self.weight.element_size() diff --git a/invokeai/backend/stable_diffusion/extensions/lora_patcher.py b/invokeai/backend/stable_diffusion/extensions/lora_patcher.py new file mode 100644 index 0000000000..452bcec1ef --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/lora_patcher.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase +from invokeai.backend.util.devices import TorchDevice + +if TYPE_CHECKING: + from invokeai.app.invocations.model import LoRAField + from invokeai.app.services.shared.invocation_context import InvocationContext + from invokeai.backend.lora import LoRAModelRaw + + +class LoRAPatcherExt(ExtensionBase): + def __init__( + self, + node_context: InvocationContext, + loras: List[LoRAField], + prefix: str, + ): + super().__init__() + self._loras = loras + self._prefix = prefix + self._node_context = node_context + + @contextmanager + def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: + for lora in self._loras: + lora_info = self._node_context.models.load(lora.lora) + lora_model = lora_info.model + yield (lora_model, lora.weight) + del lora_info + return + + yield self._patch_model( + model=unet, + prefix=self._prefix, + loras=_lora_loader(), + cached_weights=cached_weights, + ) + + @classmethod + @contextmanager + def static_patch_model( + cls, + model: torch.nn.Module, + prefix: str, + loras: Iterator[Tuple[LoRAModelRaw, float]], + cached_weights: Optional[Dict[str, torch.Tensor]] = None, + ): + modified_cached_weights, modified_weights = cls._patch_model( + model=model, + prefix=prefix, + loras=loras, + cached_weights=cached_weights, + ) + try: + yield + + finally: + with torch.no_grad(): + for param_key in modified_cached_weights: + model.get_parameter(param_key).copy_(cached_weights[param_key]) + for param_key, weight in modified_weights.items(): + model.get_parameter(param_key).copy_(weight) + + @classmethod + def _patch_model( + cls, + model: UNet2DConditionModel, + prefix: str, + loras: Iterator[Tuple[LoRAModelRaw, float]], + cached_weights: Optional[Dict[str, torch.Tensor]] = None, + ): + """ + Apply one or more LoRAs to a model. + :param model: The model to patch. + :param loras: An iterator that returns the LoRA to patch in and its patch weight. + :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. + :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. + """ + if cached_weights is None: + cached_weights = {} + + modified_weights = {} + modified_cached_weights = set() + with torch.no_grad(): + for lora, lora_weight in loras: + # assert lora.device.type == "cpu" + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue + + # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This + # should be improved in the following ways: + # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a + # LoRA model is applied. + # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the + # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA + # weights to have valid keys. + assert isinstance(model, torch.nn.Module) + module_key, module = cls._resolve_lora_key(model, layer_key, prefix) + + # All of the LoRA weight calculations will be done on the same device as the module weight. + # (Performance will be best if this is a CUDA device.) + device = module.weight.device + dtype = module.weight.dtype + + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + + # We intentionally move to the target device first, then cast. Experimentally, this was found to + # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the + # same thing in a single call to '.to(...)'. + layer.to(device=device) + layer.to(dtype=torch.float32) + + # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA + # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. + for param_name, lora_param_weight in layer.get_parameters(module).items(): + param_key = module_key + "." + param_name + module_param = module.get_parameter(param_name) + + # save original weight + if param_key not in modified_cached_weights and param_key not in modified_weights: + if param_key in cached_weights: + modified_cached_weights.add(param_key) + else: + modified_weights[param_key] = module_param.detach().to( + device=TorchDevice.CPU_DEVICE, copy=True + ) + + if module_param.shape != lora_param_weight.shape: + # TODO: debug on lycoris + lora_param_weight = lora_param_weight.reshape(module_param.shape) + + lora_param_weight *= (lora_weight * layer_scale) + module_param += lora_param_weight.to(dtype=dtype) + + layer.to(device=TorchDevice.CPU_DEVICE) + + return modified_cached_weights, modified_weights + + @staticmethod + def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: + assert "." not in lora_key + + if not lora_key.startswith(prefix): + raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") + + module = model + module_key = "" + key_parts = lora_key[len(prefix) :].split("_") + + submodule_name = key_parts.pop(0) + + while len(key_parts) > 0: + try: + module = module.get_submodule(submodule_name) + module_key += "." + submodule_name + submodule_name = key_parts.pop(0) + except Exception: + submodule_name += "_" + key_parts.pop(0) + + module = module.get_submodule(submodule_name) + module_key = (module_key + "." + submodule_name).lstrip(".") + + return (module_key, module) diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index c8d585406a..4f7e1e0874 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import ExitStack, contextmanager -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set import torch from diffusers import UNet2DConditionModel @@ -67,9 +67,31 @@ class ExtensionsManager: if self._is_canceled and self._is_canceled(): raise CanceledException - # TODO: create weight patch logic in PR with extension which uses it - with ExitStack() as exit_stack: + modified_weights: Dict[str, torch.Tensor] = {} + modified_cached_weights: Set[str] = set() + + exit_stack = ExitStack() + try: for ext in self._extensions: - exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) + res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) + if res is None: + continue + ext_modified_cached_weights, ext_modified_weights = res + + modified_cached_weights.update(ext_modified_cached_weights) + # store only first returned weight for each key, because + # next extension which changes it, will work with already modified weight + for param_key, weight in ext_modified_weights.items(): + if param_key in modified_weights: + continue + modified_weights[param_key] = weight yield None + + finally: + exit_stack.close() + with torch.no_grad(): + for param_key in modified_cached_weights: + unet.get_parameter(param_key).copy_(cached_weights[param_key]) + for param_key, weight in modified_weights.items(): + unet.get_parameter(param_key).copy_(weight)