diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py
index b4b59b5fcf..cec76ffea2 100644
--- a/invokeai/backend/lora.py
+++ b/invokeai/backend/lora.py
@@ -490,6 +490,9 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
 
         for layer_key, values in state_dict.items():
+            # Detect layers according to LyCORIS detection logic(`weight_list_det`)
+            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
+            # lora and locon
             if "lora_up.weight" in values:
                 layer: AnyLoRALayer = LoRALayer(layer_key, values)
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 64893aa533..c0dc1dca1d 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -5,7 +5,7 @@ from __future__ import annotations
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -123,34 +123,25 @@ class ModelPatcher:
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
         :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        modified_cached_weights: Set[str] = set()
-        modified_weights: Dict[str, torch.Tensor] = {}
+        original_weights: Dict[str, torch.Tensor] = {}
+        if cached_weights:
+            original_weights.update(cached_weights)
 
         try:
             for lora_model, lora_weight in loras:
-                lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
+                LoRAExt.patch_model(
                     model=model,
                     prefix=prefix,
                     lora=lora_model,
                     lora_weight=lora_weight,
-                    cached_weights=cached_weights,
+                    original_weights=original_weights,
                 )
                 del lora_model
 
-                modified_cached_weights.update(lora_modified_cached_weights)
-                # Store only first returned weight for each key, because
-                # next extension which changes it, will work with already modified weight
-                for param_key, weight in lora_modified_weights.items():
-                    if param_key in modified_weights:
-                        continue
-                    modified_weights[param_key] = weight
-
             yield
         finally:
             with torch.no_grad():
-                for param_key in modified_cached_weights:
-                    model.get_parameter(param_key).copy_(cached_weights[param_key])
-                for param_key, weight in modified_weights.items():
+                for param_key, weight in original_weights.items():
                     model.get_parameter(param_key).copy_(weight)
 
     @classmethod
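The `ModelPatcher` change above replaces the two bookkeeping collections (`modified_cached_weights` / `modified_weights`) with a single `original_weights` dict that is pre-seeded from `cached_weights` and restored wholesale in the `finally` block. Below is a minimal, self-contained sketch of that save-once/restore-all pattern (toy code, not InvokeAI's API; the `patch_and_restore` helper and the `param += 1.0` update are stand-ins for the real LoRA patching):

```python
from typing import Dict

import torch


def patch_and_restore(model: torch.nn.Module, cached_weights: Dict[str, torch.Tensor]) -> None:
    # Seed with any cached CPU copies; everything recorded here is restored at the end.
    original_weights: Dict[str, torch.Tensor] = dict(cached_weights)
    try:
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name not in original_weights:
                    # Save only the first-seen original value for each key.
                    original_weights[name] = param.detach().cpu().clone()
                param += 1.0  # stand-in for applying a LoRA delta in place
    finally:
        with torch.no_grad():
            for name, weight in original_weights.items():
                model.get_parameter(name).copy_(weight)


model = torch.nn.Linear(4, 4)
before = model.weight.detach().clone()
patch_and_restore(model, cached_weights={})
assert torch.equal(model.weight, before)  # the model is back to its original weights
```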
diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py
index 1208c3f0ee..f9753b4344 100644
--- a/invokeai/backend/stable_diffusion/extensions/base.py
+++ b/invokeai/backend/stable_diffusion/extensions/base.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -56,17 +56,17 @@ class ExtensionBase:
         yield None
 
     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
-        """Apply patches to UNet model. This function responsible for restoring all changes except weights,
-        changed weights should only be reported in return.
-        Return contains 2 values:
-        - Set of cached weights, just keys from cached_weights dictionary
-        - Dict of not cached weights that should be copies on the cpu device
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+        """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
+        diffusion process. Weight unpatching is handled upstream: the caller restores every entry recorded in the
+        `original_weights` dict, so an extension only needs to store the original value of each weight it modifies.
+        This also avoids redundant save/restore work when several extensions touch the same weight. All other patches
+        (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this context manager.
 
         Args:
             unet (UNet2DConditionModel): The UNet model on execution device to patch.
-            cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes.
+            original_weights (Dict[str, torch.Tensor]): The model's original weights, as CPU tensors, used for
+                unpatching. An extension should save the original value of any tensor it modifies here (if it is not
+                saved yet), and may read original weight values from it.
         """
-        yield set(), {}
+        yield
diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py
index 593481f198..75370d23f4 100644
--- a/invokeai/backend/stable_diffusion/extensions/freeu.py
+++ b/invokeai/backend/stable_diffusion/extensions/freeu.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -21,9 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config
 
     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
@@ -32,6 +30,6 @@ class FreeUExt(ExtensionBase):
         )
 
         try:
-            yield set(), {}
+            yield
         finally:
             unet.disable_freeu()
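For extension authors, the new contract is: record the original value of any weight you modify in `original_weights` (the caller restores weights for you), and undo everything else yourself before your context manager exits. A hypothetical extension illustrating this split might look as follows (illustration only, not part of this PR; the class name, the parameter key, and the do-nothing forward wrapper are invented for the example):

```python
from contextlib import contextmanager
from typing import Dict, Iterator

import torch
from diffusers import UNet2DConditionModel


class ScaleTimeEmbeddingExt:  # hypothetical extension
    def __init__(self, scale: float):
        self._scale = scale

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]) -> Iterator[None]:
        param_key = "time_embedding.linear_1.weight"  # illustrative key
        param = unet.get_parameter(param_key)

        # Weight change: record the original once, then patch in place. Restoring it is the caller's job.
        if param_key not in original_weights:
            original_weights[param_key] = param.detach().cpu().clone()
        with torch.no_grad():
            param *= self._scale

        # Non-weight change: must be reverted by this context manager itself.
        original_forward = unet.forward
        unet.forward = lambda *args, **kwargs: original_forward(*args, **kwargs)  # stand-in monkey-patch
        try:
            yield
        finally:
            unet.forward = original_forward
```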
diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py
index 55f7259d96..71584247c0 100644
--- a/invokeai/backend/stable_diffusion/extensions/lora.py
+++ b/invokeai/backend/stable_diffusion/extensions/lora.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Tuple
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -28,97 +28,84 @@ class LoRAExt(ExtensionBase):
         self._weight = weight
 
     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
         lora_model = self._node_context.models.load(self._model_id).model
-        modified_cached_weights, modified_weights = self.patch_model(
+        self.patch_model(
             model=unet,
             prefix="lora_unet_",
             lora=lora_model,
             lora_weight=self._weight,
-            cached_weights=cached_weights,
+            original_weights=original_weights,
         )
         del lora_model
 
-        yield modified_cached_weights, modified_weights
+        yield
 
     @classmethod
+    @torch.no_grad()
     def patch_model(
         cls,
         model: torch.nn.Module,
         prefix: str,
         lora: LoRAModelRaw,
         lora_weight: float,
-        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+        original_weights: Dict[str, torch.Tensor],
+    ):
         """
         Apply one or more LoRAs to a model.
 
         :param model: The model to patch.
         :param lora: LoRA model to patch in.
         :param lora_weight: LoRA patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :param original_weights: Dict in which the original values of patched weights are stored (on the CPU), for unpatching.
         """
-        if cached_weights is None:
-            cached_weights = {}
-        modified_weights: Dict[str, torch.Tensor] = {}
-        modified_cached_weights: Set[str] = set()
-        with torch.no_grad():
-            # assert lora.device.type == "cpu"
-            for layer_key, layer in lora.layers.items():
-                if not layer_key.startswith(prefix):
-                    continue
+        # assert lora.device.type == "cpu"
+        for layer_key, layer in lora.layers.items():
+            if not layer_key.startswith(prefix):
+                continue
 
-                # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                # should be improved in the following ways:
-                # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                #    LoRA model is applied.
-                # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                #    weights to have valid keys.
-                assert isinstance(model, torch.nn.Module)
-                module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+            # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+            # should be improved in the following ways:
+            # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+            #    LoRA model is applied.
+            # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+            #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+            #    weights to have valid keys.
+            assert isinstance(model, torch.nn.Module)
+            module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
 
-                # All of the LoRA weight calculations will be done on the same device as the module weight.
-                # (Performance will be best if this is a CUDA device.)
-                device = module.weight.device
-                dtype = module.weight.dtype
+            # All of the LoRA weight calculations will be done on the same device as the module weight.
+            # (Performance will be best if this is a CUDA device.)
+            device = module.weight.device
+            dtype = module.weight.dtype
 
-                layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
 
-                # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                # same thing in a single call to '.to(...)'.
-                layer.to(device=device)
-                layer.to(dtype=torch.float32)
+            # We intentionally move to the target device first, then cast. Experimentally, this was found to
+            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+            # same thing in a single call to '.to(...)'.
+            layer.to(device=device)
+            layer.to(dtype=torch.float32)
 
-                # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                for param_name, lora_param_weight in layer.get_parameters(module).items():
-                    param_key = module_key + "." + param_name
-                    module_param = module.get_parameter(param_name)
+            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+            for param_name, lora_param_weight in layer.get_parameters(module).items():
+                param_key = module_key + "." + param_name
+                module_param = module.get_parameter(param_name)
 
-                    # save original weight
-                    if param_key not in modified_cached_weights and param_key not in modified_weights:
-                        if param_key in cached_weights:
-                            modified_cached_weights.add(param_key)
-                        else:
-                            modified_weights[param_key] = module_param.detach().to(
-                                device=TorchDevice.CPU_DEVICE, copy=True
-                            )
+                # save original weight
+                if param_key not in original_weights:
+                    original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
 
-                    if module_param.shape != lora_param_weight.shape:
-                        # TODO: debug on lycoris
-                        lora_param_weight = lora_param_weight.reshape(module_param.shape)
+                if module_param.shape != lora_param_weight.shape:
+                    # TODO: debug on lycoris
+                    lora_param_weight = lora_param_weight.reshape(module_param.shape)
 
-                    lora_param_weight *= lora_weight * layer_scale
-                    module_param += lora_param_weight.to(dtype=dtype)
+                lora_param_weight *= lora_weight * layer_scale
+                module_param += lora_param_weight.to(dtype=dtype)
 
-                layer.to(device=TorchDevice.CPU_DEVICE)
-
-        return modified_cached_weights, modified_weights
+            layer.to(device=TorchDevice.CPU_DEVICE)
 
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
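Per parameter, the arithmetic performed by `patch_model` above reduces, for the plain `lora_up`/`lora_down` case, to adding a scaled low-rank product to the existing weight in place while stashing the original weight on the CPU. A rough sketch under assumed tensor shapes (`apply_lora_delta` and its arguments are invented for illustration; InvokeAI's layer classes actually produce the delta via `layer.get_parameters(module)`):

```python
from typing import Dict

import torch


def apply_lora_delta(
    module: torch.nn.Linear,
    up: torch.Tensor,  # [out_features, rank]
    down: torch.Tensor,  # [rank, in_features]
    alpha: float,
    lora_weight: float,
    original_weights: Dict[str, torch.Tensor],
    param_key: str,
) -> None:
    device, dtype = module.weight.device, module.weight.dtype
    rank = down.shape[0]
    layer_scale = alpha / rank if alpha and rank else 1.0

    # Record the original weight once (on the CPU) so the caller can unpatch later.
    if param_key not in original_weights:
        original_weights[param_key] = module.weight.detach().cpu().clone()

    with torch.no_grad():
        # Compute the delta in float32 on the module's device, then cast back to the module dtype.
        delta = (up.to(device, torch.float32) @ down.to(device, torch.float32)) * (lora_weight * layer_scale)
        module.weight += delta.to(dtype=dtype)


layer = torch.nn.Linear(8, 4)
apply_lora_delta(layer, torch.randn(4, 2), torch.randn(2, 8), alpha=2.0, lora_weight=0.75,
                 original_weights={}, param_key="weight")
```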
diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py
index e838a2034b..968dffd069 100644
--- a/invokeai/backend/stable_diffusion/extensions_manager.py
+++ b/invokeai/backend/stable_diffusion/extensions_manager.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -67,29 +67,18 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        modified_weights: Dict[str, torch.Tensor] = {}
-        modified_cached_weights: Set[str] = set()
+        original_weights: Dict[str, torch.Tensor] = {}
+        if cached_weights:
+            original_weights.update(cached_weights)
 
         try:
             with ExitStack() as exit_stack:
                 for ext in self._extensions:
-                    ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
-                        ext.patch_unet(unet, cached_weights)
-                    )
-
-                    modified_cached_weights.update(ext_modified_cached_weights)
-                    # store only first returned weight for each key, because
-                    # next extension which changes it, will work with already modified weight
-                    for param_key, weight in ext_modified_weights.items():
-                        if param_key in modified_weights:
-                            continue
-                        modified_weights[param_key] = weight
+                    exit_stack.enter_context(ext.patch_unet(unet, original_weights))
 
                 yield None
 
         finally:
             with torch.no_grad():
-                for param_key in modified_cached_weights:
-                    unet.get_parameter(param_key).copy_(cached_weights[param_key])
-                for param_key, weight in modified_weights.items():
+                for param_key, weight in original_weights.items():
                     unet.get_parameter(param_key).copy_(weight)
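Taken together, the `ExtensionsManager` change means every extension's `patch_unet` context manager is entered with the same shared `original_weights` dict, and whatever ends up recorded there is copied back after the diffusion loop. A toy end-to-end illustration of the intended first-save-wins behaviour (not InvokeAI code; `add_to_weight` is an invented stand-in for an extension's `patch_unet`):

```python
from contextlib import ExitStack, contextmanager
from typing import Dict, Iterator

import torch


@contextmanager
def add_to_weight(model: torch.nn.Module, value: float, original_weights: Dict[str, torch.Tensor]) -> Iterator[None]:
    key = "weight"
    if key not in original_weights:
        # The first patcher records the true original; later patchers must not overwrite it.
        original_weights[key] = model.get_parameter(key).detach().cpu().clone()
    with torch.no_grad():
        model.get_parameter(key).add_(value)
    yield  # weight restoration is handled by the caller, not here


model = torch.nn.Linear(2, 2)
before = model.weight.detach().clone()
original_weights: Dict[str, torch.Tensor] = {}

with ExitStack() as stack:
    stack.enter_context(add_to_weight(model, 1.0, original_weights))
    stack.enter_context(add_to_weight(model, 2.0, original_weights))  # sees the already-patched weight

with torch.no_grad():
    for key, weight in original_weights.items():
        model.get_parameter(key).copy_(weight)

assert torch.equal(model.weight, before)  # restored to the true pre-patch values
```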