diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py
index 8d8ce04d66..f5f3eedaa8 100644
--- a/invokeai/backend/lora.py
+++ b/invokeai/backend/lora.py
@@ -71,6 +71,9 @@ class LoRALayerBase:
         self.bias = self.bias.to(device=device, dtype=dtype)
 
     def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+        """Log a warning if values contains unhandled keys."""
+        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
+        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handle.
         all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
         unknown_keys = set(values.keys()) - all_known_keys
         if unknown_keys:
@@ -232,7 +235,6 @@ class LoKRLayer(LoRALayerBase):
         else:
             self.rank = None  # unscaled
 
-        # Although lokr_t1 not used in algo, it still defined in LoKR weights
         self.check_keys(
             values,
             {
@@ -242,7 +244,6 @@ class LoKRLayer(LoRALayerBase):
                 "lokr_w2",
                 "lokr_w2_a",
                 "lokr_w2_b",
-                "lokr_t1",
                 "lokr_t2",
             },
         )
diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py
index 820d5d32a3..1208c3f0ee 100644
--- a/invokeai/backend/stable_diffusion/extensions/base.py
+++ b/invokeai/backend/stable_diffusion/extensions/base.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -56,5 +56,17 @@ class ExtensionBase:
         yield None
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
-        yield None
+    def patch_unet(
+        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
+    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+        """Apply patches to the UNet model. This function is responsible for restoring all changes except
+        weights; changed weights should only be reported in the return value.
+        The return value contains 2 items:
+        - Set of cached weights: the keys from the cached_weights dictionary that were patched
+        - Dict of weights not present in cached_weights: their original values, copied to the CPU device
+
+        Args:
+            unet (UNet2DConditionModel): The UNet model to patch, on the execution device.
+            cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict on the CPU, for caching purposes.
+ """ + yield set(), {} diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py index 6ec4fea3fa..593481f198 100644 --- a/invokeai/backend/stable_diffusion/extensions/freeu.py +++ b/invokeai/backend/stable_diffusion/extensions/freeu.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple import torch from diffusers import UNet2DConditionModel @@ -21,7 +21,9 @@ class FreeUExt(ExtensionBase): self._freeu_config = freeu_config @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + def patch_unet( + self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None + ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: unet.enable_freeu( b1=self._freeu_config.b1, b2=self._freeu_config.b2, @@ -30,6 +32,6 @@ class FreeUExt(ExtensionBase): ) try: - yield + yield set(), {} finally: unet.disable_freeu() diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py index 11cdeb6021..55f7259d96 100644 --- a/invokeai/backend/stable_diffusion/extensions/lora.py +++ b/invokeai/backend/stable_diffusion/extensions/lora.py @@ -28,7 +28,9 @@ class LoRAExt(ExtensionBase): self._weight = weight @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + def patch_unet( + self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None + ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: lora_model = self._node_context.models.load(self._model_id).model modified_cached_weights, modified_weights = self.patch_model( model=unet, @@ -49,14 +51,14 @@ class LoRAExt(ExtensionBase): lora: LoRAModelRaw, lora_weight: float, cached_weights: Optional[Dict[str, torch.Tensor]] = None, - ): + ) -> Tuple[Set[str], Dict[str, torch.Tensor]]: """ Apply one or more LoRAs to a model. :param model: The model to patch. :param lora: LoRA model to patch in. :param lora_weight: LoRA patch weight. :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. - :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. + :param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. 
""" if cached_weights is None: cached_weights = {} diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 4f7e1e0874..e838a2034b 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -70,26 +70,24 @@ class ExtensionsManager: modified_weights: Dict[str, torch.Tensor] = {} modified_cached_weights: Set[str] = set() - exit_stack = ExitStack() try: - for ext in self._extensions: - res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) - if res is None: - continue - ext_modified_cached_weights, ext_modified_weights = res + with ExitStack() as exit_stack: + for ext in self._extensions: + ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context( + ext.patch_unet(unet, cached_weights) + ) - modified_cached_weights.update(ext_modified_cached_weights) - # store only first returned weight for each key, because - # next extension which changes it, will work with already modified weight - for param_key, weight in ext_modified_weights.items(): - if param_key in modified_weights: - continue - modified_weights[param_key] = weight + modified_cached_weights.update(ext_modified_cached_weights) + # store only first returned weight for each key, because + # next extension which changes it, will work with already modified weight + for param_key, weight in ext_modified_weights.items(): + if param_key in modified_weights: + continue + modified_weights[param_key] = weight - yield None + yield None finally: - exit_stack.close() with torch.no_grad(): for param_key in modified_cached_weights: unet.get_parameter(param_key).copy_(cached_weights[param_key])