From ab0bfa709adfddb5d4a66ea68a027c1d507d0cac Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 24 Jul 2024 05:07:29 +0300 Subject: [PATCH 01/13] Handle loras in modular denoise --- invokeai/app/invocations/denoise_latents.py | 11 ++ invokeai/backend/lora.py | 18 ++ .../extensions/lora_patcher.py | 172 ++++++++++++++++++ .../stable_diffusion/extensions_manager.py | 30 ++- 4 files changed, 227 insertions(+), 4 deletions(-) create mode 100644 invokeai/backend/stable_diffusion/extensions/lora_patcher.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 2787074265..39d2d3e08f 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -60,6 +60,7 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt +from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager @@ -833,6 +834,16 @@ class DenoiseLatentsInvocation(BaseInvocation): if self.unet.freeu_config: ext_manager.add_extension(FreeUExt(self.unet.freeu_config)) + ### lora + if self.unet.loras: + ext_manager.add_extension( + LoRAPatcherExt( + node_context=context, + loras=self.unet.loras, + prefix="lora_unet_", + ) + ) + # context for loading additional models with ExitStack() as exit_stack: # later should be smth like: diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index 8ef81915f1..21b99d7f6c 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -49,6 +49,9 @@ class LoRALayerBase: def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: raise NotImplementedError() + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + def calc_size(self) -> int: model_size = 0 for val in [self.bias]: @@ -93,6 +96,9 @@ class LoRALayer(LoRALayerBase): return weight + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + return {"weight": self.get_weight(orig_module.weight)} + def calc_size(self) -> int: model_size = super().calc_size() for val in [self.up, self.mid, self.down]: @@ -149,6 +155,9 @@ class LoHALayer(LoRALayerBase): return weight + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + return {"weight": self.get_weight(orig_module.weight)} + def calc_size(self) -> int: model_size = super().calc_size() for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: @@ -241,6 +250,9 @@ class LoKRLayer(LoRALayerBase): return weight + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + return {"weight": self.get_weight(orig_module.weight)} + def calc_size(self) -> int: model_size = super().calc_size() for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: @@ -293,6 +305,9 @@ class FullLayer(LoRALayerBase): def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: return self.weight + def 
get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + return {"weight": self.get_weight(orig_module.weight)} + def calc_size(self) -> int: model_size = super().calc_size() model_size += self.weight.nelement() * self.weight.element_size() @@ -327,6 +342,9 @@ class IA3Layer(LoRALayerBase): assert orig_weight is not None return orig_weight * weight + def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: + return {"weight": self.get_weight(orig_module.weight)} + def calc_size(self) -> int: model_size = super().calc_size() model_size += self.weight.nelement() * self.weight.element_size() diff --git a/invokeai/backend/stable_diffusion/extensions/lora_patcher.py b/invokeai/backend/stable_diffusion/extensions/lora_patcher.py new file mode 100644 index 0000000000..452bcec1ef --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/lora_patcher.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase +from invokeai.backend.util.devices import TorchDevice + +if TYPE_CHECKING: + from invokeai.app.invocations.model import LoRAField + from invokeai.app.services.shared.invocation_context import InvocationContext + from invokeai.backend.lora import LoRAModelRaw + + +class LoRAPatcherExt(ExtensionBase): + def __init__( + self, + node_context: InvocationContext, + loras: List[LoRAField], + prefix: str, + ): + super().__init__() + self._loras = loras + self._prefix = prefix + self._node_context = node_context + + @contextmanager + def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: + for lora in self._loras: + lora_info = self._node_context.models.load(lora.lora) + lora_model = lora_info.model + yield (lora_model, lora.weight) + del lora_info + return + + yield self._patch_model( + model=unet, + prefix=self._prefix, + loras=_lora_loader(), + cached_weights=cached_weights, + ) + + @classmethod + @contextmanager + def static_patch_model( + cls, + model: torch.nn.Module, + prefix: str, + loras: Iterator[Tuple[LoRAModelRaw, float]], + cached_weights: Optional[Dict[str, torch.Tensor]] = None, + ): + modified_cached_weights, modified_weights = cls._patch_model( + model=model, + prefix=prefix, + loras=loras, + cached_weights=cached_weights, + ) + try: + yield + + finally: + with torch.no_grad(): + for param_key in modified_cached_weights: + model.get_parameter(param_key).copy_(cached_weights[param_key]) + for param_key, weight in modified_weights.items(): + model.get_parameter(param_key).copy_(weight) + + @classmethod + def _patch_model( + cls, + model: UNet2DConditionModel, + prefix: str, + loras: Iterator[Tuple[LoRAModelRaw, float]], + cached_weights: Optional[Dict[str, torch.Tensor]] = None, + ): + """ + Apply one or more LoRAs to a model. + :param model: The model to patch. + :param loras: An iterator that returns the LoRA to patch in and its patch weight. + :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. + :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. 
+ """ + if cached_weights is None: + cached_weights = {} + + modified_weights = {} + modified_cached_weights = set() + with torch.no_grad(): + for lora, lora_weight in loras: + # assert lora.device.type == "cpu" + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue + + # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This + # should be improved in the following ways: + # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a + # LoRA model is applied. + # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the + # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA + # weights to have valid keys. + assert isinstance(model, torch.nn.Module) + module_key, module = cls._resolve_lora_key(model, layer_key, prefix) + + # All of the LoRA weight calculations will be done on the same device as the module weight. + # (Performance will be best if this is a CUDA device.) + device = module.weight.device + dtype = module.weight.dtype + + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + + # We intentionally move to the target device first, then cast. Experimentally, this was found to + # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the + # same thing in a single call to '.to(...)'. + layer.to(device=device) + layer.to(dtype=torch.float32) + + # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA + # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. + for param_name, lora_param_weight in layer.get_parameters(module).items(): + param_key = module_key + "." + param_name + module_param = module.get_parameter(param_name) + + # save original weight + if param_key not in modified_cached_weights and param_key not in modified_weights: + if param_key in cached_weights: + modified_cached_weights.add(param_key) + else: + modified_weights[param_key] = module_param.detach().to( + device=TorchDevice.CPU_DEVICE, copy=True + ) + + if module_param.shape != lora_param_weight.shape: + # TODO: debug on lycoris + lora_param_weight = lora_param_weight.reshape(module_param.shape) + + lora_param_weight *= (lora_weight * layer_scale) + module_param += lora_param_weight.to(dtype=dtype) + + layer.to(device=TorchDevice.CPU_DEVICE) + + return modified_cached_weights, modified_weights + + @staticmethod + def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: + assert "." not in lora_key + + if not lora_key.startswith(prefix): + raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") + + module = model + module_key = "" + key_parts = lora_key[len(prefix) :].split("_") + + submodule_name = key_parts.pop(0) + + while len(key_parts) > 0: + try: + module = module.get_submodule(submodule_name) + module_key += "." + submodule_name + submodule_name = key_parts.pop(0) + except Exception: + submodule_name += "_" + key_parts.pop(0) + + module = module.get_submodule(submodule_name) + module_key = (module_key + "." 
+ submodule_name).lstrip(".") + + return (module_key, module) diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index c8d585406a..4f7e1e0874 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import ExitStack, contextmanager -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set import torch from diffusers import UNet2DConditionModel @@ -67,9 +67,31 @@ class ExtensionsManager: if self._is_canceled and self._is_canceled(): raise CanceledException - # TODO: create weight patch logic in PR with extension which uses it - with ExitStack() as exit_stack: + modified_weights: Dict[str, torch.Tensor] = {} + modified_cached_weights: Set[str] = set() + + exit_stack = ExitStack() + try: for ext in self._extensions: - exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) + res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) + if res is None: + continue + ext_modified_cached_weights, ext_modified_weights = res + + modified_cached_weights.update(ext_modified_cached_weights) + # store only first returned weight for each key, because + # next extension which changes it, will work with already modified weight + for param_key, weight in ext_modified_weights.items(): + if param_key in modified_weights: + continue + modified_weights[param_key] = weight yield None + + finally: + exit_stack.close() + with torch.no_grad(): + for param_key in modified_cached_weights: + unet.get_parameter(param_key).copy_(cached_weights[param_key]) + for param_key, weight in modified_weights.items(): + unet.get_parameter(param_key).copy_(weight) From 0ccb304b8b0590002786d53470e0016648750b57 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 24 Jul 2024 16:01:29 +0300 Subject: [PATCH 02/13] Ruff format --- invokeai/backend/stable_diffusion/extensions/lora_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/stable_diffusion/extensions/lora_patcher.py b/invokeai/backend/stable_diffusion/extensions/lora_patcher.py index 452bcec1ef..eb045a1ec4 100644 --- a/invokeai/backend/stable_diffusion/extensions/lora_patcher.py +++ b/invokeai/backend/stable_diffusion/extensions/lora_patcher.py @@ -138,7 +138,7 @@ class LoRAPatcherExt(ExtensionBase): # TODO: debug on lycoris lora_param_weight = lora_param_weight.reshape(module_param.shape) - lora_param_weight *= (lora_weight * layer_scale) + lora_param_weight *= lora_weight * layer_scale module_param += lora_param_weight.to(dtype=dtype) layer.to(device=TorchDevice.CPU_DEVICE) From 31949ed2f2be3446b6c505ee7f0272e7c075fcf8 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Thu, 25 Jul 2024 02:00:30 +0300 Subject: [PATCH 03/13] Refactor code a bit --- invokeai/backend/lora.py | 84 +++++++++++++--------------------------- 1 file changed, 27 insertions(+), 57 deletions(-) diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index 21b99d7f6c..714a4a8a2d 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -46,11 +46,18 @@ class LoRALayerBase: self.rank = None # set in layer implementation self.layer_key = layer_key - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: raise NotImplementedError() - def 
get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: - raise NotImplementedError() + def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]: + return self.bias + + def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]: + params = {"weight": self.get_weight(orig_module.weight)} + bias = self.get_bias(orig_module.bias) + if bias is not None: + params["bias"] = bias + return params def calc_size(self) -> int: model_size = 0 @@ -79,14 +86,11 @@ class LoRALayer(LoRALayerBase): self.up = values["lora_up.weight"] self.down = values["lora_down.weight"] - if "lora_mid.weight" in values: - self.mid: Optional[torch.Tensor] = values["lora_mid.weight"] - else: - self.mid = None + self.mid = values.get("lora_mid.weight", None) self.rank = self.down.shape[0] - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: if self.mid is not None: up = self.up.reshape(self.up.shape[0], self.up.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1]) @@ -96,9 +100,6 @@ class LoRALayer(LoRALayerBase): return weight - def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: - return {"weight": self.get_weight(orig_module.weight)} - def calc_size(self) -> int: model_size = super().calc_size() for val in [self.up, self.mid, self.down]: @@ -131,20 +132,12 @@ class LoHALayer(LoRALayerBase): self.w1_b = values["hada_w1_b"] self.w2_a = values["hada_w2_a"] self.w2_b = values["hada_w2_b"] - - if "hada_t1" in values: - self.t1: Optional[torch.Tensor] = values["hada_t1"] - else: - self.t1 = None - - if "hada_t2" in values: - self.t2: Optional[torch.Tensor] = values["hada_t2"] - else: - self.t2 = None + self.t1 = values.get("hada_t1", None) + self.t2 = values.get("hada_t2", None) self.rank = self.w1_b.shape[0] - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: if self.t1 is None: weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) @@ -155,9 +148,6 @@ class LoHALayer(LoRALayerBase): return weight - def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: - return {"weight": self.get_weight(orig_module.weight)} - def calc_size(self) -> int: model_size = super().calc_size() for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: @@ -195,37 +185,26 @@ class LoKRLayer(LoRALayerBase): ): super().__init__(layer_key, values) - if "lokr_w1" in values: - self.w1: Optional[torch.Tensor] = values["lokr_w1"] - self.w1_a = None - self.w1_b = None - else: - self.w1 = None + self.w1 = values.get("lokr_w1", None) + if self.w1 is None: self.w1_a = values["lokr_w1_a"] self.w1_b = values["lokr_w1_b"] - if "lokr_w2" in values: - self.w2: Optional[torch.Tensor] = values["lokr_w2"] - self.w2_a = None - self.w2_b = None - else: - self.w2 = None + self.w2 = values.get("lokr_w2", None) + if self.w2 is None: self.w2_a = values["lokr_w2_a"] self.w2_b = values["lokr_w2_b"] - if "lokr_t2" in values: - self.t2: Optional[torch.Tensor] = values["lokr_t2"] - else: - self.t2 = None + self.t2 = values.get("lokr_t2", None) - if "lokr_w1_b" in values: - self.rank = values["lokr_w1_b"].shape[0] - elif "lokr_w2_b" in values: - self.rank = values["lokr_w2_b"].shape[0] + if self.w1_b is not None: + self.rank = self.w1_b.shape[0] + elif self.w2_b is not None: + self.rank = 
self.w2_b.shape[0] else: self.rank = None # unscaled - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: w1: Optional[torch.Tensor] = self.w1 if w1 is None: assert self.w1_a is not None @@ -250,9 +229,6 @@ class LoKRLayer(LoRALayerBase): return weight - def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: - return {"weight": self.get_weight(orig_module.weight)} - def calc_size(self) -> int: model_size = super().calc_size() for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: @@ -302,12 +278,9 @@ class FullLayer(LoRALayerBase): self.rank = None # unscaled - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: return self.weight - def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: - return {"weight": self.get_weight(orig_module.weight)} - def calc_size(self) -> int: model_size = super().calc_size() model_size += self.weight.nelement() * self.weight.element_size() @@ -335,16 +308,13 @@ class IA3Layer(LoRALayerBase): self.rank = None # unscaled - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: weight = self.weight if not self.on_input: weight = weight.reshape(-1, 1) assert orig_weight is not None return orig_weight * weight - def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]: - return {"weight": self.get_weight(orig_module.weight)} - def calc_size(self) -> int: model_size = super().calc_size() model_size += self.weight.nelement() * self.weight.element_size() From 8a9e2f57a499f7ed62e47598c99243d53b8785e7 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Thu, 25 Jul 2024 02:02:37 +0300 Subject: [PATCH 04/13] Handle bias in full/diff lora layer --- invokeai/backend/lora.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index 714a4a8a2d..b2cba07b2c 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -260,7 +260,9 @@ class LoKRLayer(LoRALayerBase): class FullLayer(LoRALayerBase): + # bias handled in LoRALayerBase(calc_size, to) # weight: torch.Tensor + # bias: Optional[torch.Tensor] def __init__( self, @@ -270,11 +272,7 @@ class FullLayer(LoRALayerBase): super().__init__(layer_key, values) self.weight = values["diff"] - - if len(values.keys()) > 1: - _keys = list(values.keys()) - _keys.remove("diff") - raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") + self.bias = values.get("diff_b", None) self.rank = None # unscaled From 653f63ae7196f431006b1717222bab6fb76e311d Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Thu, 25 Jul 2024 02:03:08 +0300 Subject: [PATCH 05/13] Add layer keys check --- invokeai/backend/lora.py | 47 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index b2cba07b2c..8f5b1f08f5 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -3,7 +3,7 @@ import bisect from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Set, Tuple, Union import torch from safetensors.torch import load_file @@ -70,6 +70,15 @@ class LoRALayerBase: if self.bias is not None: 
self.bias = self.bias.to(device=device, dtype=dtype) + def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]): + all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"} + unknown_keys = set(values.keys()) - all_known_keys + if unknown_keys: + # TODO: how to warn log? + print( + f"[WARN] Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}" + ) + # TODO: find and debug lora/locon with bias class LoRALayer(LoRALayerBase): @@ -89,6 +98,14 @@ class LoRALayer(LoRALayerBase): self.mid = values.get("lora_mid.weight", None) self.rank = self.down.shape[0] + self.check_keys( + values, + { + "lora_up.weight", + "lora_down.weight", + "lora_mid.weight", + }, + ) def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: if self.mid is not None: @@ -136,6 +153,17 @@ class LoHALayer(LoRALayerBase): self.t2 = values.get("hada_t2", None) self.rank = self.w1_b.shape[0] + self.check_keys( + values, + { + "hada_w1_a", + "hada_w1_b", + "hada_w2_a", + "hada_w2_b", + "hada_t1", + "hada_t2", + }, + ) def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: if self.t1 is None: @@ -204,6 +232,21 @@ class LoKRLayer(LoRALayerBase): else: self.rank = None # unscaled + # Although lokr_t1 not used in algo, it still defined in LoKR weights + self.check_keys( + values, + { + "lokr_w1", + "lokr_w1_a", + "lokr_w1_b", + "lokr_w2", + "lokr_w2_a", + "lokr_w2_b", + "lokr_t1", + "lokr_t2", + }, + ) + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: w1: Optional[torch.Tensor] = self.w1 if w1 is None: @@ -275,6 +318,7 @@ class FullLayer(LoRALayerBase): self.bias = values.get("diff_b", None) self.rank = None # unscaled + self.check_keys(values, {"diff", "diff_b"}) def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: return self.weight @@ -305,6 +349,7 @@ class IA3Layer(LoRALayerBase): self.on_input = values["on_input"] self.rank = None # unscaled + self.check_keys(values, {"weight", "on_input"}) def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: weight = self.weight From 46c632e7ccb5673dd04c79a576433485f414eac2 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Thu, 25 Jul 2024 02:10:47 +0300 Subject: [PATCH 06/13] Change layer detection keys according to LyCORIS repository --- invokeai/backend/lora.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index 8f5b1f08f5..8d8ce04d66 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -490,15 +490,15 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module): for layer_key, values in state_dict.items(): # lora and locon - if "lora_down.weight" in values: + if "lora_up.weight" in values: layer: AnyLoRALayer = LoRALayer(layer_key, values) # loha - elif "hada_w1_b" in values: + elif "hada_w1_a" in values: layer = LoHALayer(layer_key, values) # lokr - elif "lokr_w1_b" in values or "lokr_w1" in values: + elif "lokr_w1" in values or "lokr_w1_a" in values: layer = LoKRLayer(layer_key, values) # diff @@ -506,7 +506,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module): layer = FullLayer(layer_key, values) # ia3 - elif "weight" in values and "on_input" in values: + elif "on_input" in values: layer = IA3Layer(layer_key, values) else: From faa88f72bf98d49349b5097568e8fde6f2c6c1f5 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 27 Jul 2024 02:39:53 +0300 Subject: [PATCH 07/13] Make lora as separate extensions --- invokeai/app/invocations/compel.py | 8 
+- invokeai/app/invocations/denoise_latents.py | 19 +- invokeai/backend/model_patcher.py | 91 ++++----- .../stable_diffusion/extensions/lora.py | 145 +++++++++++++++ .../extensions/lora_patcher.py | 172 ------------------ 5 files changed, 190 insertions(+), 245 deletions(-) create mode 100644 invokeai/backend/stable_diffusion/extensions/lora.py delete mode 100644 invokeai/backend/stable_diffusion/extensions/lora_patcher.py diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index fffb09e654..5905df8dd7 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation): with ( # apply all patches while the model is on the target device - text_encoder_info.model_on_device() as (model_state_dict, text_encoder), + text_encoder_info.model_on_device() as (cached_weights, text_encoder), tokenizer_info as tokenizer, ModelPatcher.apply_lora_text_encoder( text_encoder, loras=_lora_loader(), - model_state_dict=model_state_dict, + cached_weights=cached_weights, ), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers), @@ -175,13 +175,13 @@ class SDXLPromptInvocationBase: with ( # apply all patches while the model is on the target device - text_encoder_info.model_on_device() as (state_dict, text_encoder), + text_encoder_info.model_on_device() as (cached_weights, text_encoder), tokenizer_info as tokenizer, ModelPatcher.apply_lora( text_encoder, loras=_lora_loader(), prefix=lora_prefix, - model_state_dict=state_dict, + cached_weights=cached_weights, ), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers), diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 39d2d3e08f..8795a44714 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -60,7 +60,7 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt -from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt +from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager @@ -836,13 +836,14 @@ class DenoiseLatentsInvocation(BaseInvocation): ### lora if self.unet.loras: - ext_manager.add_extension( - LoRAPatcherExt( - node_context=context, - loras=self.unet.loras, - prefix="lora_unet_", + for lora_field in self.unet.loras: + ext_manager.add_extension( + LoRAExt( + node_context=context, + model_id=lora_field.lora, + weight=lora_field.weight, + ) ) - ) # context for loading additional models with ExitStack() as exit_stack: @@ -924,14 +925,14 @@ class DenoiseLatentsInvocation(BaseInvocation): assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, - unet_info.model_on_device() as (model_state_dict, unet), + unet_info.model_on_device() as (cached_weights, 
unet), ModelPatcher.apply_freeu(unet, self.unet.freeu_config), set_seamless(unet, self.unet.seamless_axes), # FIXME # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet( unet, loras=_lora_loader(), - model_state_dict=model_state_dict, + cached_weights=cached_weights, ), ): assert isinstance(unet, UNet2DConditionModel) diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index d30f7b3167..64893aa533 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -5,7 +5,7 @@ from __future__ import annotations import pickle from contextlib import contextmanager -from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union import numpy as np import torch @@ -17,8 +17,8 @@ from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import AnyModel from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw -from invokeai.backend.util.devices import TorchDevice """ loras = [ @@ -85,13 +85,13 @@ class ModelPatcher: cls, unet: UNet2DConditionModel, loras: Iterator[Tuple[LoRAModelRaw, float]], - model_state_dict: Optional[Dict[str, torch.Tensor]] = None, + cached_weights: Optional[Dict[str, torch.Tensor]] = None, ) -> Generator[None, None, None]: with cls.apply_lora( unet, loras=loras, prefix="lora_unet_", - model_state_dict=model_state_dict, + cached_weights=cached_weights, ): yield @@ -101,9 +101,9 @@ class ModelPatcher: cls, text_encoder: CLIPTextModel, loras: Iterator[Tuple[LoRAModelRaw, float]], - model_state_dict: Optional[Dict[str, torch.Tensor]] = None, + cached_weights: Optional[Dict[str, torch.Tensor]] = None, ) -> Generator[None, None, None]: - with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict): + with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights): yield @classmethod @@ -113,7 +113,7 @@ class ModelPatcher: model: AnyModel, loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str, - model_state_dict: Optional[Dict[str, torch.Tensor]] = None, + cached_weights: Optional[Dict[str, torch.Tensor]] = None, ) -> Generator[None, None, None]: """ Apply one or more LoRAs to a model. @@ -121,66 +121,37 @@ class ModelPatcher: :param model: The model to patch. :param loras: An iterator that returns the LoRA to patch in and its patch weight. :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. - :model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes. + :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. 
""" - original_weights = {} + modified_cached_weights: Set[str] = set() + modified_weights: Dict[str, torch.Tensor] = {} try: - with torch.no_grad(): - for lora, lora_weight in loras: - # assert lora.device.type == "cpu" - for layer_key, layer in lora.layers.items(): - if not layer_key.startswith(prefix): - continue + for lora_model, lora_weight in loras: + lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model( + model=model, + prefix=prefix, + lora=lora_model, + lora_weight=lora_weight, + cached_weights=cached_weights, + ) + del lora_model - # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This - # should be improved in the following ways: - # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a - # LoRA model is applied. - # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the - # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA - # weights to have valid keys. - assert isinstance(model, torch.nn.Module) - module_key, module = cls._resolve_lora_key(model, layer_key, prefix) + modified_cached_weights.update(lora_modified_cached_weights) + # Store only first returned weight for each key, because + # next extension which changes it, will work with already modified weight + for param_key, weight in lora_modified_weights.items(): + if param_key in modified_weights: + continue + modified_weights[param_key] = weight - # All of the LoRA weight calculations will be done on the same device as the module weight. - # (Performance will be best if this is a CUDA device.) - device = module.weight.device - dtype = module.weight.dtype - - if module_key not in original_weights: - if model_state_dict is not None: # we were provided with the CPU copy of the state dict - original_weights[module_key] = model_state_dict[module_key + ".weight"] - else: - original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) - - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - - # We intentionally move to the target device first, then cast. Experimentally, this was found to - # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the - # same thing in a single call to '.to(...)'. - layer.to(device=device) - layer.to(dtype=torch.float32) - # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA - # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. - layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) - layer.to(device=TorchDevice.CPU_DEVICE) - - assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! - if module.weight.shape != layer_weight.shape: - # TODO: debug on lycoris - assert hasattr(layer_weight, "reshape") - layer_weight = layer_weight.reshape(module.weight.shape) - - assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! 
- module.weight += layer_weight.to(dtype=dtype) - - yield # wait for context manager exit + yield finally: - assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() with torch.no_grad(): - for module_key, weight in original_weights.items(): - model.get_submodule(module_key).weight.copy_(weight) + for param_key in modified_cached_weights: + model.get_parameter(param_key).copy_(cached_weights[param_key]) + for param_key, weight in modified_weights.items(): + model.get_parameter(param_key).copy_(weight) @classmethod @contextmanager diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py new file mode 100644 index 0000000000..11cdeb6021 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/lora.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase +from invokeai.backend.util.devices import TorchDevice + +if TYPE_CHECKING: + from invokeai.app.invocations.model import ModelIdentifierField + from invokeai.app.services.shared.invocation_context import InvocationContext + from invokeai.backend.lora import LoRAModelRaw + + +class LoRAExt(ExtensionBase): + def __init__( + self, + node_context: InvocationContext, + model_id: ModelIdentifierField, + weight: float, + ): + super().__init__() + self._node_context = node_context + self._model_id = model_id + self._weight = weight + + @contextmanager + def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + lora_model = self._node_context.models.load(self._model_id).model + modified_cached_weights, modified_weights = self.patch_model( + model=unet, + prefix="lora_unet_", + lora=lora_model, + lora_weight=self._weight, + cached_weights=cached_weights, + ) + del lora_model + + yield modified_cached_weights, modified_weights + + @classmethod + def patch_model( + cls, + model: torch.nn.Module, + prefix: str, + lora: LoRAModelRaw, + lora_weight: float, + cached_weights: Optional[Dict[str, torch.Tensor]] = None, + ): + """ + Apply one or more LoRAs to a model. + :param model: The model to patch. + :param lora: LoRA model to patch in. + :param lora_weight: LoRA patch weight. + :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. + :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. + """ + if cached_weights is None: + cached_weights = {} + + modified_weights: Dict[str, torch.Tensor] = {} + modified_cached_weights: Set[str] = set() + with torch.no_grad(): + # assert lora.device.type == "cpu" + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue + + # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This + # should be improved in the following ways: + # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a + # LoRA model is applied. + # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the + # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA + # weights to have valid keys. 
+ assert isinstance(model, torch.nn.Module) + module_key, module = cls._resolve_lora_key(model, layer_key, prefix) + + # All of the LoRA weight calculations will be done on the same device as the module weight. + # (Performance will be best if this is a CUDA device.) + device = module.weight.device + dtype = module.weight.dtype + + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + + # We intentionally move to the target device first, then cast. Experimentally, this was found to + # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the + # same thing in a single call to '.to(...)'. + layer.to(device=device) + layer.to(dtype=torch.float32) + + # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA + # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. + for param_name, lora_param_weight in layer.get_parameters(module).items(): + param_key = module_key + "." + param_name + module_param = module.get_parameter(param_name) + + # save original weight + if param_key not in modified_cached_weights and param_key not in modified_weights: + if param_key in cached_weights: + modified_cached_weights.add(param_key) + else: + modified_weights[param_key] = module_param.detach().to( + device=TorchDevice.CPU_DEVICE, copy=True + ) + + if module_param.shape != lora_param_weight.shape: + # TODO: debug on lycoris + lora_param_weight = lora_param_weight.reshape(module_param.shape) + + lora_param_weight *= lora_weight * layer_scale + module_param += lora_param_weight.to(dtype=dtype) + + layer.to(device=TorchDevice.CPU_DEVICE) + + return modified_cached_weights, modified_weights + + @staticmethod + def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: + assert "." not in lora_key + + if not lora_key.startswith(prefix): + raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") + + module = model + module_key = "" + key_parts = lora_key[len(prefix) :].split("_") + + submodule_name = key_parts.pop(0) + + while len(key_parts) > 0: + try: + module = module.get_submodule(submodule_name) + module_key += "." + submodule_name + submodule_name = key_parts.pop(0) + except Exception: + submodule_name += "_" + key_parts.pop(0) + + module = module.get_submodule(submodule_name) + module_key = (module_key + "." 
+ submodule_name).lstrip(".") + + return (module_key, module) diff --git a/invokeai/backend/stable_diffusion/extensions/lora_patcher.py b/invokeai/backend/stable_diffusion/extensions/lora_patcher.py deleted file mode 100644 index eb045a1ec4..0000000000 --- a/invokeai/backend/stable_diffusion/extensions/lora_patcher.py +++ /dev/null @@ -1,172 +0,0 @@ -from __future__ import annotations - -from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple - -import torch -from diffusers import UNet2DConditionModel - -from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase -from invokeai.backend.util.devices import TorchDevice - -if TYPE_CHECKING: - from invokeai.app.invocations.model import LoRAField - from invokeai.app.services.shared.invocation_context import InvocationContext - from invokeai.backend.lora import LoRAModelRaw - - -class LoRAPatcherExt(ExtensionBase): - def __init__( - self, - node_context: InvocationContext, - loras: List[LoRAField], - prefix: str, - ): - super().__init__() - self._loras = loras - self._prefix = prefix - self._node_context = node_context - - @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): - def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: - for lora in self._loras: - lora_info = self._node_context.models.load(lora.lora) - lora_model = lora_info.model - yield (lora_model, lora.weight) - del lora_info - return - - yield self._patch_model( - model=unet, - prefix=self._prefix, - loras=_lora_loader(), - cached_weights=cached_weights, - ) - - @classmethod - @contextmanager - def static_patch_model( - cls, - model: torch.nn.Module, - prefix: str, - loras: Iterator[Tuple[LoRAModelRaw, float]], - cached_weights: Optional[Dict[str, torch.Tensor]] = None, - ): - modified_cached_weights, modified_weights = cls._patch_model( - model=model, - prefix=prefix, - loras=loras, - cached_weights=cached_weights, - ) - try: - yield - - finally: - with torch.no_grad(): - for param_key in modified_cached_weights: - model.get_parameter(param_key).copy_(cached_weights[param_key]) - for param_key, weight in modified_weights.items(): - model.get_parameter(param_key).copy_(weight) - - @classmethod - def _patch_model( - cls, - model: UNet2DConditionModel, - prefix: str, - loras: Iterator[Tuple[LoRAModelRaw, float]], - cached_weights: Optional[Dict[str, torch.Tensor]] = None, - ): - """ - Apply one or more LoRAs to a model. - :param model: The model to patch. - :param loras: An iterator that returns the LoRA to patch in and its patch weight. - :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. - :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. - """ - if cached_weights is None: - cached_weights = {} - - modified_weights = {} - modified_cached_weights = set() - with torch.no_grad(): - for lora, lora_weight in loras: - # assert lora.device.type == "cpu" - for layer_key, layer in lora.layers.items(): - if not layer_key.startswith(prefix): - continue - - # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This - # should be improved in the following ways: - # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a - # LoRA model is applied. - # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the - # intricacies of Stable Diffusion key resolution. 
It should just expect the input LoRA - # weights to have valid keys. - assert isinstance(model, torch.nn.Module) - module_key, module = cls._resolve_lora_key(model, layer_key, prefix) - - # All of the LoRA weight calculations will be done on the same device as the module weight. - # (Performance will be best if this is a CUDA device.) - device = module.weight.device - dtype = module.weight.dtype - - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - - # We intentionally move to the target device first, then cast. Experimentally, this was found to - # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the - # same thing in a single call to '.to(...)'. - layer.to(device=device) - layer.to(dtype=torch.float32) - - # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA - # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. - for param_name, lora_param_weight in layer.get_parameters(module).items(): - param_key = module_key + "." + param_name - module_param = module.get_parameter(param_name) - - # save original weight - if param_key not in modified_cached_weights and param_key not in modified_weights: - if param_key in cached_weights: - modified_cached_weights.add(param_key) - else: - modified_weights[param_key] = module_param.detach().to( - device=TorchDevice.CPU_DEVICE, copy=True - ) - - if module_param.shape != lora_param_weight.shape: - # TODO: debug on lycoris - lora_param_weight = lora_param_weight.reshape(module_param.shape) - - lora_param_weight *= lora_weight * layer_scale - module_param += lora_param_weight.to(dtype=dtype) - - layer.to(device=TorchDevice.CPU_DEVICE) - - return modified_cached_weights, modified_weights - - @staticmethod - def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: - assert "." not in lora_key - - if not lora_key.startswith(prefix): - raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") - - module = model - module_key = "" - key_parts = lora_key[len(prefix) :].split("_") - - submodule_name = key_parts.pop(0) - - while len(key_parts) > 0: - try: - module = module.get_submodule(submodule_name) - module_key += "." + submodule_name - submodule_name = key_parts.pop(0) - except Exception: - submodule_name += "_" + key_parts.pop(0) - - module = module.get_submodule(submodule_name) - module_key = (module_key + "." 
+ submodule_name).lstrip(".") - - return (module_key, module) From 9e582563eb2623386551f87794a590edbb06159e Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 27 Jul 2024 04:25:15 +0300 Subject: [PATCH 08/13] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- invokeai/backend/lora.py | 5 ++-- .../stable_diffusion/extensions/base.py | 18 ++++++++++-- .../stable_diffusion/extensions/freeu.py | 8 ++++-- .../stable_diffusion/extensions/lora.py | 8 ++++-- .../stable_diffusion/extensions_manager.py | 28 +++++++++---------- 5 files changed, 41 insertions(+), 26 deletions(-) diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index 8d8ce04d66..f5f3eedaa8 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -71,6 +71,9 @@ class LoRALayerBase: self.bias = self.bias.to(device=device, dtype=dtype) def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]): + """Log a warning if values contains unhandled keys.""" + # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by + # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled. all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"} unknown_keys = set(values.keys()) - all_known_keys if unknown_keys: @@ -232,7 +235,6 @@ class LoKRLayer(LoRALayerBase): else: self.rank = None # unscaled - # Although lokr_t1 not used in algo, it still defined in LoKR weights self.check_keys( values, { @@ -242,7 +244,6 @@ class LoKRLayer(LoRALayerBase): "lokr_w2", "lokr_w2_a", "lokr_w2_b", - "lokr_t1", "lokr_t2", }, ) diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 820d5d32a3..1208c3f0ee 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -2,7 +2,7 @@ from __future__ import annotations from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple import torch from diffusers import UNet2DConditionModel @@ -56,5 +56,17 @@ class ExtensionBase: yield None @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): - yield None + def patch_unet( + self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None + ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: + """Apply patches to UNet model. This function responsible for restoring all changes except weights, + changed weights should only be reported in return. + Return contains 2 values: + - Set of cached weights, just keys from cached_weights dictionary + - Dict of not cached weights that should be copies on the cpu device + + Args: + unet (UNet2DConditionModel): The UNet model on execution device to patch. + cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes. 
+ """ + yield set(), {} diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py index 6ec4fea3fa..593481f198 100644 --- a/invokeai/backend/stable_diffusion/extensions/freeu.py +++ b/invokeai/backend/stable_diffusion/extensions/freeu.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple import torch from diffusers import UNet2DConditionModel @@ -21,7 +21,9 @@ class FreeUExt(ExtensionBase): self._freeu_config = freeu_config @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + def patch_unet( + self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None + ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: unet.enable_freeu( b1=self._freeu_config.b1, b2=self._freeu_config.b2, @@ -30,6 +32,6 @@ class FreeUExt(ExtensionBase): ) try: - yield + yield set(), {} finally: unet.disable_freeu() diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py index 11cdeb6021..55f7259d96 100644 --- a/invokeai/backend/stable_diffusion/extensions/lora.py +++ b/invokeai/backend/stable_diffusion/extensions/lora.py @@ -28,7 +28,9 @@ class LoRAExt(ExtensionBase): self._weight = weight @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + def patch_unet( + self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None + ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: lora_model = self._node_context.models.load(self._model_id).model modified_cached_weights, modified_weights = self.patch_model( model=unet, @@ -49,14 +51,14 @@ class LoRAExt(ExtensionBase): lora: LoRAModelRaw, lora_weight: float, cached_weights: Optional[Dict[str, torch.Tensor]] = None, - ): + ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: """ Apply one or more LoRAs to a model. :param model: The model to patch. :param lora: LoRA model to patch in. :param lora_weight: LoRA patch weight. :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. - :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. + :param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. 
""" if cached_weights is None: cached_weights = {} diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 4f7e1e0874..e838a2034b 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -70,26 +70,24 @@ class ExtensionsManager: modified_weights: Dict[str, torch.Tensor] = {} modified_cached_weights: Set[str] = set() - exit_stack = ExitStack() try: - for ext in self._extensions: - res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) - if res is None: - continue - ext_modified_cached_weights, ext_modified_weights = res + with ExitStack() as exit_stack: + for ext in self._extensions: + ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context( + ext.patch_unet(unet, cached_weights) + ) - modified_cached_weights.update(ext_modified_cached_weights) - # store only first returned weight for each key, because - # next extension which changes it, will work with already modified weight - for param_key, weight in ext_modified_weights.items(): - if param_key in modified_weights: - continue - modified_weights[param_key] = weight + modified_cached_weights.update(ext_modified_cached_weights) + # store only first returned weight for each key, because + # next extension which changes it, will work with already modified weight + for param_key, weight in ext_modified_weights.items(): + if param_key in modified_weights: + continue + modified_weights[param_key] = weight - yield None + yield None finally: - exit_stack.close() with torch.no_grad(): for param_key in modified_cached_weights: unet.get_parameter(param_key).copy_(cached_weights[param_key]) From 8500bac3ca3e6ad4d90a7380264bf307fe80b0ab Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 28 Jul 2024 22:51:52 +0300 Subject: [PATCH 09/13] Use logger for warning --- invokeai/backend/lora.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index f5f3eedaa8..b4b59b5fcf 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -9,6 +9,7 @@ import torch from safetensors.torch import load_file from typing_extensions import Self +import invokeai.backend.util.logging as logger from invokeai.backend.model_manager import BaseModelType from invokeai.backend.raw_model import RawModel @@ -77,9 +78,8 @@ class LoRALayerBase: all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"} unknown_keys = set(values.keys()) - all_known_keys if unknown_keys: - # TODO: how to warn log? - print( - f"[WARN] Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}" + logger.warning( + f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! 
Keys: {unknown_keys}" ) From 2227a2357f8163e6a841851f7b6e7cf63dab96b3 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 30 Jul 2024 00:34:37 +0300 Subject: [PATCH 10/13] Suggested changes + simplify weights logic in patching Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- invokeai/backend/lora.py | 3 + invokeai/backend/model_patcher.py | 23 ++-- .../stable_diffusion/extensions/base.py | 22 ++-- .../stable_diffusion/extensions/freeu.py | 8 +- .../stable_diffusion/extensions/lora.py | 105 ++++++++---------- .../stable_diffusion/extensions_manager.py | 23 +--- 6 files changed, 76 insertions(+), 108 deletions(-) diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index b4b59b5fcf..cec76ffea2 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -490,6 +490,9 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module): state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) for layer_key, values in state_dict.items(): + # Detect layers according to LyCORIS detection logic(`weight_list_det`) + # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules + # lora and locon if "lora_up.weight" in values: layer: AnyLoRALayer = LoRALayer(layer_key, values) diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index 64893aa533..c0dc1dca1d 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -5,7 +5,7 @@ from __future__ import annotations import pickle from contextlib import contextmanager -from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union import numpy as np import torch @@ -123,34 +123,25 @@ class ModelPatcher: :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. 
""" - modified_cached_weights: Set[str] = set() - modified_weights: Dict[str, torch.Tensor] = {} + original_weights: Dict[str, torch.Tensor] = {} + if cached_weights: + original_weights.update(cached_weights) try: for lora_model, lora_weight in loras: - lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model( + LoRAExt.patch_model( model=model, prefix=prefix, lora=lora_model, lora_weight=lora_weight, - cached_weights=cached_weights, + original_weights=original_weights, ) del lora_model - modified_cached_weights.update(lora_modified_cached_weights) - # Store only first returned weight for each key, because - # next extension which changes it, will work with already modified weight - for param_key, weight in lora_modified_weights.items(): - if param_key in modified_weights: - continue - modified_weights[param_key] = weight - yield finally: with torch.no_grad(): - for param_key in modified_cached_weights: - model.get_parameter(param_key).copy_(cached_weights[param_key]) - for param_key, weight in modified_weights.items(): + for param_key, weight in original_weights.items(): model.get_parameter(param_key).copy_(weight) @classmethod diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 1208c3f0ee..f9753b4344 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -2,7 +2,7 @@ from __future__ import annotations from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, List import torch from diffusers import UNet2DConditionModel @@ -56,17 +56,17 @@ class ExtensionBase: yield None @contextmanager - def patch_unet( - self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None - ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: - """Apply patches to UNet model. This function responsible for restoring all changes except weights, - changed weights should only be reported in return. - Return contains 2 values: - - Set of cached weights, just keys from cached_weights dictionary - - Dict of not cached weights that should be copies on the cpu device + def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]): + """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire + diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in + `original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations. + All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this + context manager. Args: unet (UNet2DConditionModel): The UNet model on execution device to patch. - cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes. + cached_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for + unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can + access original weight value. 
""" - yield set(), {} + yield diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py index 593481f198..75370d23f4 100644 --- a/invokeai/backend/stable_diffusion/extensions/freeu.py +++ b/invokeai/backend/stable_diffusion/extensions/freeu.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict import torch from diffusers import UNet2DConditionModel @@ -21,9 +21,7 @@ class FreeUExt(ExtensionBase): self._freeu_config = freeu_config @contextmanager - def patch_unet( - self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None - ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: + def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]): unet.enable_freeu( b1=self._freeu_config.b1, b2=self._freeu_config.b2, @@ -32,6 +30,6 @@ class FreeUExt(ExtensionBase): ) try: - yield set(), {} + yield finally: unet.disable_freeu() diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py index 55f7259d96..71584247c0 100644 --- a/invokeai/backend/stable_diffusion/extensions/lora.py +++ b/invokeai/backend/stable_diffusion/extensions/lora.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Tuple import torch from diffusers import UNet2DConditionModel @@ -28,97 +28,84 @@ class LoRAExt(ExtensionBase): self._weight = weight @contextmanager - def patch_unet( - self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None - ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: + def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]): lora_model = self._node_context.models.load(self._model_id).model - modified_cached_weights, modified_weights = self.patch_model( + self.patch_model( model=unet, prefix="lora_unet_", lora=lora_model, lora_weight=self._weight, - cached_weights=cached_weights, + original_weights=original_weights, ) del lora_model - yield modified_cached_weights, modified_weights + yield @classmethod + @torch.no_grad() def patch_model( cls, model: torch.nn.Module, prefix: str, lora: LoRAModelRaw, lora_weight: float, - cached_weights: Optional[Dict[str, torch.Tensor]] = None, - ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: + original_weights: Dict[str, torch.Tensor], + ): """ Apply one or more LoRAs to a model. :param model: The model to patch. :param lora: LoRA model to patch in. :param lora_weight: LoRA patch weight. :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. - :param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. + :param original_weights: TODO: """ - if cached_weights is None: - cached_weights = {} - modified_weights: Dict[str, torch.Tensor] = {} - modified_cached_weights: Set[str] = set() - with torch.no_grad(): - # assert lora.device.type == "cpu" - for layer_key, layer in lora.layers.items(): - if not layer_key.startswith(prefix): - continue + # assert lora.device.type == "cpu" + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue - # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. 
This - # should be improved in the following ways: - # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a - # LoRA model is applied. - # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the - # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA - # weights to have valid keys. - assert isinstance(model, torch.nn.Module) - module_key, module = cls._resolve_lora_key(model, layer_key, prefix) + # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This + # should be improved in the following ways: + # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a + # LoRA model is applied. + # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the + # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA + # weights to have valid keys. + assert isinstance(model, torch.nn.Module) + module_key, module = cls._resolve_lora_key(model, layer_key, prefix) - # All of the LoRA weight calculations will be done on the same device as the module weight. - # (Performance will be best if this is a CUDA device.) - device = module.weight.device - dtype = module.weight.dtype + # All of the LoRA weight calculations will be done on the same device as the module weight. + # (Performance will be best if this is a CUDA device.) + device = module.weight.device + dtype = module.weight.dtype - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - # We intentionally move to the target device first, then cast. Experimentally, this was found to - # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the - # same thing in a single call to '.to(...)'. - layer.to(device=device) - layer.to(dtype=torch.float32) + # We intentionally move to the target device first, then cast. Experimentally, this was found to + # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the + # same thing in a single call to '.to(...)'. + layer.to(device=device) + layer.to(dtype=torch.float32) - # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA - # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. - for param_name, lora_param_weight in layer.get_parameters(module).items(): - param_key = module_key + "." + param_name - module_param = module.get_parameter(param_name) + # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA + # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. + for param_name, lora_param_weight in layer.get_parameters(module).items(): + param_key = module_key + "." 
+ param_name + module_param = module.get_parameter(param_name) - # save original weight - if param_key not in modified_cached_weights and param_key not in modified_weights: - if param_key in cached_weights: - modified_cached_weights.add(param_key) - else: - modified_weights[param_key] = module_param.detach().to( - device=TorchDevice.CPU_DEVICE, copy=True - ) + # save original weight + if param_key not in original_weights: + original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True) - if module_param.shape != lora_param_weight.shape: - # TODO: debug on lycoris - lora_param_weight = lora_param_weight.reshape(module_param.shape) + if module_param.shape != lora_param_weight.shape: + # TODO: debug on lycoris + lora_param_weight = lora_param_weight.reshape(module_param.shape) - lora_param_weight *= lora_weight * layer_scale - module_param += lora_param_weight.to(dtype=dtype) + lora_param_weight *= lora_weight * layer_scale + module_param += lora_param_weight.to(dtype=dtype) - layer.to(device=TorchDevice.CPU_DEVICE) - - return modified_cached_weights, modified_weights + layer.to(device=TorchDevice.CPU_DEVICE) @staticmethod def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index e838a2034b..968dffd069 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import ExitStack, contextmanager -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Callable, Dict, List, Optional import torch from diffusers import UNet2DConditionModel @@ -67,29 +67,18 @@ class ExtensionsManager: if self._is_canceled and self._is_canceled(): raise CanceledException - modified_weights: Dict[str, torch.Tensor] = {} - modified_cached_weights: Set[str] = set() + original_weights: Dict[str, torch.Tensor] = {} + if cached_weights: + original_weights.update(cached_weights) try: with ExitStack() as exit_stack: for ext in self._extensions: - ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context( - ext.patch_unet(unet, cached_weights) - ) - - modified_cached_weights.update(ext_modified_cached_weights) - # store only first returned weight for each key, because - # next extension which changes it, will work with already modified weight - for param_key, weight in ext_modified_weights.items(): - if param_key in modified_weights: - continue - modified_weights[param_key] = weight + exit_stack.enter_context(ext.patch_unet(unet, original_weights)) yield None finally: with torch.no_grad(): - for param_key in modified_cached_weights: - unet.get_parameter(param_key).copy_(cached_weights[param_key]) - for param_key, weight in modified_weights.items(): + for param_key, weight in original_weights.items(): unet.get_parameter(param_key).copy_(weight) From 1fd9631f2d2320eb474562715468539e97d25e9b Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 30 Jul 2024 00:39:50 +0300 Subject: [PATCH 11/13] Comments fix Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- invokeai/backend/stable_diffusion/extensions/base.py | 2 +- invokeai/backend/stable_diffusion/extensions/lora.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/stable_diffusion/extensions/base.py 
b/invokeai/backend/stable_diffusion/extensions/base.py index f9753b4344..61276e0784 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -65,7 +65,7 @@ class ExtensionBase: Args: unet (UNet2DConditionModel): The UNet model on execution device to patch. - cached_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for + original_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can access original weight value. """ diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py index 71584247c0..cfb97a2cb2 100644 --- a/invokeai/backend/stable_diffusion/extensions/lora.py +++ b/invokeai/backend/stable_diffusion/extensions/lora.py @@ -57,7 +57,7 @@ class LoRAExt(ExtensionBase): :param lora: LoRA model to patch in. :param lora_weight: LoRA patch weight. :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. - :param original_weights: TODO: + :param original_weights: Dict of original weights, filled by weights which lora patches, used for unpatching. """ # assert lora.device.type == "cpu" From 86f705bf484196b07f329c7aade50c7b987da941 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 30 Jul 2024 03:39:01 +0300 Subject: [PATCH 12/13] Optimize weights handling --- invokeai/backend/model_patcher.py | 7 ++-- .../stable_diffusion/extensions/base.py | 18 +++++----- .../stable_diffusion/extensions/freeu.py | 6 ++-- .../stable_diffusion/extensions/lora.py | 15 ++++---- .../stable_diffusion/extensions_manager.py | 8 ++--- .../backend/util/original_weights_storage.py | 35 +++++++++++++++++++ 6 files changed, 62 insertions(+), 27 deletions(-) create mode 100644 invokeai/backend/util/original_weights_storage.py diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index c0dc1dca1d..e2f22ba019 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_ from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw +from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage """ loras = [ @@ -123,9 +124,7 @@ class ModelPatcher: :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. 
""" - original_weights: Dict[str, torch.Tensor] = {} - if cached_weights: - original_weights.update(cached_weights) + original_weights = OriginalWeightsStorage(cached_weights) try: for lora_model, lora_weight in loras: LoRAExt.patch_model( @@ -141,7 +140,7 @@ class ModelPatcher: finally: with torch.no_grad(): - for param_key, weight in original_weights.items(): + for param_key, weight in original_weights.get_changed_weights(): model.get_parameter(param_key).copy_(weight) @classmethod diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 61276e0784..a3d27464a0 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -4,12 +4,12 @@ from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Dict, List -import torch from diffusers import UNet2DConditionModel if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType + from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage @dataclass @@ -56,17 +56,17 @@ class ExtensionBase: yield None @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]): + def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage): """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire - diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in - `original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations. - All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this - context manager. + diffusion process. Weight unpatching is handled upstream, and is achieved by saving unchanged weights by + `original_weights.save` function. Note that this enables some performance optimization by avoiding redundant + operations. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched + by this context manager. Args: unet (UNet2DConditionModel): The UNet model on execution device to patch. - original_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for - unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can - access original weight value. + original_weights (OriginalWeightsStorage): A storage with copy of the model's original weights in CPU, for + unpatching purposes. Extension should save tensor which being modified in this storage, also extensions + can access original weights values. 
""" yield diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py index 75370d23f4..ff54e1a52f 100644 --- a/invokeai/backend/stable_diffusion/extensions/freeu.py +++ b/invokeai/backend/stable_diffusion/extensions/freeu.py @@ -1,15 +1,15 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING -import torch from diffusers import UNet2DConditionModel from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase if TYPE_CHECKING: from invokeai.app.shared.models import FreeUConfig + from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage class FreeUExt(ExtensionBase): @@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase): self._freeu_config = freeu_config @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]): + def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage): unet.enable_freeu( b1=self._freeu_config.b1, b2=self._freeu_config.b2, diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py index cfb97a2cb2..617bdcbbaf 100644 --- a/invokeai/backend/stable_diffusion/extensions/lora.py +++ b/invokeai/backend/stable_diffusion/extensions/lora.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Tuple import torch from diffusers import UNet2DConditionModel @@ -13,6 +13,7 @@ if TYPE_CHECKING: from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.lora import LoRAModelRaw + from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage class LoRAExt(ExtensionBase): @@ -28,7 +29,7 @@ class LoRAExt(ExtensionBase): self._weight = weight @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]): + def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage): lora_model = self._node_context.models.load(self._model_id).model self.patch_model( model=unet, @@ -49,7 +50,7 @@ class LoRAExt(ExtensionBase): prefix: str, lora: LoRAModelRaw, lora_weight: float, - original_weights: Dict[str, torch.Tensor], + original_weights: OriginalWeightsStorage, ): """ Apply one or more LoRAs to a model. @@ -57,9 +58,12 @@ class LoRAExt(ExtensionBase): :param lora: LoRA model to patch in. :param lora_weight: LoRA patch weight. :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. - :param original_weights: Dict of original weights, filled by weights which lora patches, used for unpatching. + :param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching. 
""" + if lora_weight == 0: + return + # assert lora.device.type == "cpu" for layer_key, layer in lora.layers.items(): if not layer_key.startswith(prefix): @@ -95,8 +99,7 @@ class LoRAExt(ExtensionBase): module_param = module.get_parameter(param_name) # save original weight - if param_key not in original_weights: - original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True) + original_weights.save(param_key, module_param) if module_param.shape != lora_param_weight.shape: # TODO: debug on lycoris diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 968dffd069..3783bb422e 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -7,6 +7,7 @@ import torch from diffusers import UNet2DConditionModel from invokeai.app.services.session_processor.session_processor_common import CanceledException +from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext @@ -67,10 +68,7 @@ class ExtensionsManager: if self._is_canceled and self._is_canceled(): raise CanceledException - original_weights: Dict[str, torch.Tensor] = {} - if cached_weights: - original_weights.update(cached_weights) - + original_weights = OriginalWeightsStorage(cached_weights) try: with ExitStack() as exit_stack: for ext in self._extensions: @@ -80,5 +78,5 @@ class ExtensionsManager: finally: with torch.no_grad(): - for param_key, weight in original_weights.items(): + for param_key, weight in original_weights.get_changed_weights(): unet.get_parameter(param_key).copy_(weight) diff --git a/invokeai/backend/util/original_weights_storage.py b/invokeai/backend/util/original_weights_storage.py new file mode 100644 index 0000000000..3632c52b09 --- /dev/null +++ b/invokeai/backend/util/original_weights_storage.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import Dict, Iterator, Optional, Tuple + +import torch + +from invokeai.backend.util.devices import TorchDevice + + +class OriginalWeightsStorage: + def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + self._weights = {} + self._changed_weights = set() + if cached_weights: + self._weights.update(cached_weights) + + def save(self, key: str, weight: torch.Tensor, copy: bool = True): + self._changed_weights.add(key) + if key in self._weights: + return + + self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy) + + def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]: + weight = self._weights.get(key, None) + if weight is not None and copy: + weight = weight.clone() + return weight + + def contains(self, key: str) -> bool: + return key in self._weights + + def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: + for key in self._changed_weights: + yield key, self._weights[key] From 0bb7ed44f64534460fe5afaee802eb57fcaa0ae6 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 15:08:24 -0400 Subject: [PATCH 13/13] Add some docs to OriginalWeightsStorage and fix type hints. 
--- invokeai/backend/util/original_weights_storage.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/util/original_weights_storage.py b/invokeai/backend/util/original_weights_storage.py index 3632c52b09..af945b086f 100644 --- a/invokeai/backend/util/original_weights_storage.py +++ b/invokeai/backend/util/original_weights_storage.py @@ -8,9 +8,13 @@ from invokeai.backend.util.devices import TorchDevice class OriginalWeightsStorage: + """A class for tracking the original weights of a model for patch/unpatch operations.""" + def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None): - self._weights = {} - self._changed_weights = set() + # The original weights of the model. + self._weights: dict[str, torch.Tensor] = {} + # The keys of the weights that have been changed (via `save()`) during the lifetime of this instance. + self._changed_weights: set[str] = set() if cached_weights: self._weights.update(cached_weights)
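To make the save/restore contract that these patches converge on concrete, here is a minimal, self-contained sketch. It uses a toy stand-in for OriginalWeightsStorage and a plain torch.nn.Linear; the names ToyOriginalWeightsStorage and apply_additive_patch are illustrative only and are not part of the InvokeAI API.

from typing import Dict, Iterator, Optional, Set, Tuple

import torch


class ToyOriginalWeightsStorage:
    """Simplified stand-in for OriginalWeightsStorage: remembers the first-seen
    value of every weight that a patch touches."""

    def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        self._weights: Dict[str, torch.Tensor] = dict(cached_weights or {})
        self._changed_weights: Set[str] = set()

    def save(self, key: str, weight: torch.Tensor) -> None:
        # Mark the key as changed, but keep only the first (original) value seen.
        self._changed_weights.add(key)
        if key not in self._weights:
            self._weights[key] = weight.detach().cpu().clone()

    def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
        for key in self._changed_weights:
            yield key, self._weights[key]


def apply_additive_patch(
    model: torch.nn.Module,
    param_key: str,
    delta: torch.Tensor,
    originals: ToyOriginalWeightsStorage,
) -> None:
    # Patch a single parameter in place, saving its original value first.
    param = model.get_parameter(param_key)
    originals.save(param_key, param)
    with torch.no_grad():
        param += delta.to(dtype=param.dtype, device=param.device)


model = torch.nn.Linear(4, 4, bias=False)
originals = ToyOriginalWeightsStorage()
before = model.weight.detach().clone()

# Two patches touch the same weight; only the first call stores the original.
apply_additive_patch(model, "weight", 0.1 * torch.ones(4, 4), originals)
apply_additive_patch(model, "weight", 0.2 * torch.ones(4, 4), originals)

# Unpatching, as the extensions manager does in its `finally` block: copy back
# only the weights that were actually changed.
with torch.no_grad():
    for key, weight in originals.get_changed_weights():
        model.get_parameter(key).copy_(weight)

assert torch.allclose(model.weight, before)

Because save() only stores the first value seen for a key, stacking several patches on the same parameter still restores the true original on unpatch, which is the point of replacing the modified_weights bookkeeping with a single original-weights store.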
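The per-layer merge performed in LoRAExt.patch_model can likewise be sketched in isolation. The helper below assumes the standard up @ down LoRA factorization as a stand-in for layer.get_weight(); merge_lora_delta and its arguments are hypothetical names used only for illustration.

import torch


def merge_lora_delta(
    module_weight: torch.nn.Parameter,
    up: torch.Tensor,
    down: torch.Tensor,
    alpha: float,
    lora_weight: float,
) -> None:
    # Scale factor applied to the LoRA delta, mirroring `layer_scale` in the patch.
    rank = down.shape[0]
    layer_scale = alpha / rank if (alpha and rank) else 1.0
    # Compute the delta in float32 on the module's device, then cast back when adding.
    delta = up.to(module_weight.device, torch.float32) @ down.to(module_weight.device, torch.float32)
    if delta.shape != module_weight.shape:
        # Some layers need the delta reshaped to the module weight's shape.
        delta = delta.reshape(module_weight.shape)
    with torch.no_grad():
        module_weight += (delta * lora_weight * layer_scale).to(dtype=module_weight.dtype)


out_features, in_features, rank = 8, 16, 4
weight = torch.nn.Parameter(torch.randn(out_features, in_features))
up = torch.randn(out_features, rank)
down = torch.randn(rank, in_features)
merge_lora_delta(weight, up, down, alpha=4.0, lora_weight=0.75)

Scaling by alpha / rank matches layer_scale in the patch, lora_weight is the user-facing strength, and the reshape covers layers whose LoRA delta shape differs from the module weight (noted as a lycoris TODO in the patch).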