From faa88f72bf98d49349b5097568e8fde6f2c6c1f5 Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Sat, 27 Jul 2024 02:39:53 +0300
Subject: [PATCH] Make lora as separate extensions

---
 invokeai/app/invocations/compel.py            |   8 +-
 invokeai/app/invocations/denoise_latents.py   |  19 +-
 invokeai/backend/model_patcher.py             |  91 ++++-----
 .../stable_diffusion/extensions/lora.py       | 145 +++++++++++++++
 .../extensions/lora_patcher.py                | 172 ------------------
 5 files changed, 190 insertions(+), 245 deletions(-)
 create mode 100644 invokeai/backend/stable_diffusion/extensions/lora.py
 delete mode 100644 invokeai/backend/stable_diffusion/extensions/lora_patcher.py

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index fffb09e654..5905df8dd7 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora_text_encoder(
                 text_encoder,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora(
                 text_encoder,
                 loras=_lora_loader(),
                 prefix=lora_prefix,
-                model_state_dict=state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 39d2d3e08f..8795a44714 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -60,7 +60,7 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
 from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
-from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@@ -836,13 +836,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         ### lora
         if self.unet.loras:
-            ext_manager.add_extension(
-                LoRAPatcherExt(
-                    node_context=context,
-                    loras=self.unet.loras,
-                    prefix="lora_unet_",
+            for lora_field in self.unet.loras:
+                ext_manager.add_extension(
+                    LoRAExt(
+                        node_context=context,
+                        model_id=lora_field.lora,
+                        weight=lora_field.weight,
+                    )
                 )
-            )
 
         # context for loading additional models
         with ExitStack() as exit_stack:
@@ -924,14 +925,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             set_seamless(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(
                 unet,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index d30f7b3167..64893aa533 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -5,7 +5,7 @@ from __future__ import annotations
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -17,8 +17,8 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
-from invokeai.backend.util.devices import TorchDevice
 
 """
 loras = [
@@ -85,13 +85,13 @@ class ModelPatcher:
         cls,
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         with cls.apply_lora(
             unet,
             loras=loras,
             prefix="lora_unet_",
-            model_state_dict=model_state_dict,
+            cached_weights=cached_weights,
         ):
             yield
 
@@ -101,9 +101,9 @@ class ModelPatcher:
         cls,
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
-        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
+        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
             yield
 
     @classmethod
@@ -113,7 +113,7 @@ class ModelPatcher:
         model: AnyModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         """
         Apply one or more LoRAs to a model.
@@ -121,66 +121,37 @@ class ModelPatcher:
         :param model: The model to patch.
         :param loras: An iterator that returns the LoRA to patch in and its patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        original_weights = {}
+        modified_cached_weights: Set[str] = set()
+        modified_weights: Dict[str, torch.Tensor] = {}
         try:
-            with torch.no_grad():
-                for lora, lora_weight in loras:
-                    # assert lora.device.type == "cpu"
-                    for layer_key, layer in lora.layers.items():
-                        if not layer_key.startswith(prefix):
-                            continue
+            for lora_model, lora_weight in loras:
+                lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
+                    model=model,
+                    prefix=prefix,
+                    lora=lora_model,
+                    lora_weight=lora_weight,
+                    cached_weights=cached_weights,
+                )
+                del lora_model
 
-                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                        # should be improved in the following ways:
-                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                        #    LoRA model is applied.
-                        # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                        #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                        #    weights to have valid keys.
-                        assert isinstance(model, torch.nn.Module)
-                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+                modified_cached_weights.update(lora_modified_cached_weights)
+                # Store only first returned weight for each key, because
+                # next extension which changes it, will work with already modified weight
+                for param_key, weight in lora_modified_weights.items():
+                    if param_key in modified_weights:
+                        continue
+                    modified_weights[param_key] = weight
 
-                        # All of the LoRA weight calculations will be done on the same device as the module weight.
-                        # (Performance will be best if this is a CUDA device.)
-                        device = module.weight.device
-                        dtype = module.weight.dtype
-
-                        if module_key not in original_weights:
-                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                            else:
-                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
-
-                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                        # same thing in a single call to '.to(...)'.
-                        layer.to(device=device)
-                        layer.to(dtype=torch.float32)
-                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=TorchDevice.CPU_DEVICE)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        if module.weight.shape != layer_weight.shape:
-                            # TODO: debug on lycoris
-                            assert hasattr(layer_weight, "reshape")
-                            layer_weight = layer_weight.reshape(module.weight.shape)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype)
-
-            yield  # wait for context manager exit
+            yield
 
         finally:
-            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
-                for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                for param_key in modified_cached_weights:
+                    model.get_parameter(param_key).copy_(cached_weights[param_key])
+                for param_key, weight in modified_weights.items():
+                    model.get_parameter(param_key).copy_(weight)
 
     @classmethod
     @contextmanager
diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py
new file mode 100644
index 0000000000..11cdeb6021
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/lora.py
@@ -0,0 +1,145 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+from invokeai.backend.util.devices import TorchDevice
+
+if TYPE_CHECKING:
+    from invokeai.app.invocations.model import ModelIdentifierField
+    from invokeai.app.services.shared.invocation_context import InvocationContext
+    from invokeai.backend.lora import LoRAModelRaw
+
+
+class LoRAExt(ExtensionBase):
+    def __init__(
+        self,
+        node_context: InvocationContext,
+        model_id: ModelIdentifierField,
+        weight: float,
+    ):
+        super().__init__()
+        self._node_context = node_context
+        self._model_id = model_id
+        self._weight = weight
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        lora_model = self._node_context.models.load(self._model_id).model
+        modified_cached_weights, modified_weights = self.patch_model(
+            model=unet,
+            prefix="lora_unet_",
+            lora=lora_model,
+            lora_weight=self._weight,
+            cached_weights=cached_weights,
+        )
+        del lora_model
+
+        yield modified_cached_weights, modified_weights
+
+    @classmethod
+    def patch_model(
+        cls,
+        model: torch.nn.Module,
+        prefix: str,
+        lora: LoRAModelRaw,
+        lora_weight: float,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        """
+        Apply one or more LoRAs to a model.
+        :param model: The model to patch.
+        :param lora: LoRA model to patch in.
+        :param lora_weight: LoRA patch weight.
+        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
+        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        """
+        if cached_weights is None:
+            cached_weights = {}
+
+        modified_weights: Dict[str, torch.Tensor] = {}
+        modified_cached_weights: Set[str] = set()
+        with torch.no_grad():
+            # assert lora.device.type == "cpu"
+            for layer_key, layer in lora.layers.items():
+                if not layer_key.startswith(prefix):
+                    continue
+
+                # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+                # should be improved in the following ways:
+                # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+                #    LoRA model is applied.
+                # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+                #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+                #    weights to have valid keys.
+                assert isinstance(model, torch.nn.Module)
+                module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+                # All of the LoRA weight calculations will be done on the same device as the module weight.
+                # (Performance will be best if this is a CUDA device.)
+                device = module.weight.device
+                dtype = module.weight.dtype
+
+                layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+
+                # We intentionally move to the target device first, then cast. Experimentally, this was found to
+                # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                # same thing in a single call to '.to(...)'.
+                layer.to(device=device)
+                layer.to(dtype=torch.float32)
+
+                # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+                # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+                for param_name, lora_param_weight in layer.get_parameters(module).items():
+                    param_key = module_key + "." + param_name
+                    module_param = module.get_parameter(param_name)
+
+                    # save original weight
+                    if param_key not in modified_cached_weights and param_key not in modified_weights:
+                        if param_key in cached_weights:
+                            modified_cached_weights.add(param_key)
+                        else:
+                            modified_weights[param_key] = module_param.detach().to(
+                                device=TorchDevice.CPU_DEVICE, copy=True
+                            )
+
+                    if module_param.shape != lora_param_weight.shape:
+                        # TODO: debug on lycoris
+                        lora_param_weight = lora_param_weight.reshape(module_param.shape)
+
+                    lora_param_weight *= lora_weight * layer_scale
+                    module_param += lora_param_weight.to(dtype=dtype)
+
+                layer.to(device=TorchDevice.CPU_DEVICE)
+
+        return modified_cached_weights, modified_weights
+
+    @staticmethod
+    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
+        assert "." not in lora_key
+
+        if not lora_key.startswith(prefix):
+            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
+
+        module = model
+        module_key = ""
+        key_parts = lora_key[len(prefix) :].split("_")
+
+        submodule_name = key_parts.pop(0)
+
+        while len(key_parts) > 0:
+            try:
+                module = module.get_submodule(submodule_name)
+                module_key += "." + submodule_name
+                submodule_name = key_parts.pop(0)
+            except Exception:
+                submodule_name += "_" + key_parts.pop(0)
+
+        module = module.get_submodule(submodule_name)
+        module_key = (module_key + "." + submodule_name).lstrip(".")
+
+        return (module_key, module)
diff --git a/invokeai/backend/stable_diffusion/extensions/lora_patcher.py b/invokeai/backend/stable_diffusion/extensions/lora_patcher.py
deleted file mode 100644
index eb045a1ec4..0000000000
--- a/invokeai/backend/stable_diffusion/extensions/lora_patcher.py
+++ /dev/null
@@ -1,172 +0,0 @@
-from __future__ import annotations
-
-from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple
-
-import torch
-from diffusers import UNet2DConditionModel
-
-from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
-from invokeai.backend.util.devices import TorchDevice
-
-if TYPE_CHECKING:
-    from invokeai.app.invocations.model import LoRAField
-    from invokeai.app.services.shared.invocation_context import InvocationContext
-    from invokeai.backend.lora import LoRAModelRaw
-
-
-class LoRAPatcherExt(ExtensionBase):
-    def __init__(
-        self,
-        node_context: InvocationContext,
-        loras: List[LoRAField],
-        prefix: str,
-    ):
-        super().__init__()
-        self._loras = loras
-        self._prefix = prefix
-        self._node_context = node_context
-
-    @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
-            for lora in self._loras:
-                lora_info = self._node_context.models.load(lora.lora)
-                lora_model = lora_info.model
-                yield (lora_model, lora.weight)
-                del lora_info
-            return
-
-        yield self._patch_model(
-            model=unet,
-            prefix=self._prefix,
-            loras=_lora_loader(),
-            cached_weights=cached_weights,
-        )
-
-    @classmethod
-    @contextmanager
-    def static_patch_model(
-        cls,
-        model: torch.nn.Module,
-        prefix: str,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ):
-        modified_cached_weights, modified_weights = cls._patch_model(
-            model=model,
-            prefix=prefix,
-            loras=loras,
-            cached_weights=cached_weights,
-        )
-        try:
-            yield
-
-        finally:
-            with torch.no_grad():
-                for param_key in modified_cached_weights:
-                    model.get_parameter(param_key).copy_(cached_weights[param_key])
-                for param_key, weight in modified_weights.items():
-                    model.get_parameter(param_key).copy_(weight)
-
-    @classmethod
-    def _patch_model(
-        cls,
-        model: UNet2DConditionModel,
-        prefix: str,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ):
-        """
-        Apply one or more LoRAs to a model.
-        :param model: The model to patch.
-        :param loras: An iterator that returns the LoRA to patch in and its patch weight.
-        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
-        """
-        if cached_weights is None:
-            cached_weights = {}
-
-        modified_weights = {}
-        modified_cached_weights = set()
-        with torch.no_grad():
-            for lora, lora_weight in loras:
-                # assert lora.device.type == "cpu"
-                for layer_key, layer in lora.layers.items():
-                    if not layer_key.startswith(prefix):
-                        continue
-
-                    # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                    # should be improved in the following ways:
-                    # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                    #    LoRA model is applied.
-                    # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                    #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                    #    weights to have valid keys.
-                    assert isinstance(model, torch.nn.Module)
-                    module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-                    # All of the LoRA weight calculations will be done on the same device as the module weight.
-                    # (Performance will be best if this is a CUDA device.)
-                    device = module.weight.device
-                    dtype = module.weight.dtype
-
-                    layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                    # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                    # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                    # same thing in a single call to '.to(...)'.
-                    layer.to(device=device)
-                    layer.to(dtype=torch.float32)
-
-                    # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                    # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                    for param_name, lora_param_weight in layer.get_parameters(module).items():
-                        param_key = module_key + "." + param_name
-                        module_param = module.get_parameter(param_name)
-
-                        # save original weight
-                        if param_key not in modified_cached_weights and param_key not in modified_weights:
-                            if param_key in cached_weights:
-                                modified_cached_weights.add(param_key)
-                            else:
-                                modified_weights[param_key] = module_param.detach().to(
-                                    device=TorchDevice.CPU_DEVICE, copy=True
-                                )
-
-                        if module_param.shape != lora_param_weight.shape:
-                            # TODO: debug on lycoris
-                            lora_param_weight = lora_param_weight.reshape(module_param.shape)
-
-                        lora_param_weight *= lora_weight * layer_scale
-                        module_param += lora_param_weight.to(dtype=dtype)
-
-                    layer.to(device=TorchDevice.CPU_DEVICE)
-
-        return modified_cached_weights, modified_weights
-
-    @staticmethod
-    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
-        assert "." not in lora_key
-
-        if not lora_key.startswith(prefix):
-            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
-
-        module = model
-        module_key = ""
-        key_parts = lora_key[len(prefix) :].split("_")
-
-        submodule_name = key_parts.pop(0)
-
-        while len(key_parts) > 0:
-            try:
-                module = module.get_submodule(submodule_name)
-                module_key += "." + submodule_name
-                submodule_name = key_parts.pop(0)
-            except Exception:
-                submodule_name += "_" + key_parts.pop(0)
-
-        module = module.get_submodule(submodule_name)
-        module_key = (module_key + "." + submodule_name).lstrip(".")
-
-        return (module_key, module)
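
Note on the new patching contract (an illustrative sketch, not part of the patch): LoRAExt.patch_model() mutates the model's parameters in place and reports what it touched through two collections — modified_cached_weights (parameter keys whose originals are already present in the read-only CPU cached_weights dict) and modified_weights (CPU copies of originals that were not cached). The caller that applied the patch restores from those two sources on exit, exactly as the rewritten ModelPatcher.apply_lora() finally-block above does. A minimal sketch of that lifecycle follows; the wrapper name apply_lora_temporarily and its inputs (model, lora_model, lora_weight, optional cached_weights) are assumed for illustration and do not exist in the codebase.

    import torch

    from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt


    def apply_lora_temporarily(model, lora_model, lora_weight, cached_weights=None):
        # Patch in place and remember which parameters were touched and where their originals live.
        modified_cached_weights, modified_weights = LoRAExt.patch_model(
            model=model,
            prefix="lora_unet_",  # UNet prefix, as used by LoRAExt.patch_unet above
            lora=lora_model,
            lora_weight=lora_weight,
            cached_weights=cached_weights,
        )
        try:
            pass  # ... run inference with the patched model here ...
        finally:
            # Restore originals: cached keys come from the CPU state dict,
            # everything else from the copies saved by patch_model().
            with torch.no_grad():
                for param_key in modified_cached_weights:
                    model.get_parameter(param_key).copy_(cached_weights[param_key])
                for param_key, weight in modified_weights.items():
                    model.get_parameter(param_key).copy_(weight)

When cached_weights is omitted, patch_model() falls back to saving CPU copies of every touched parameter, so the first restore loop simply has nothing to do.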