From 86f705bf484196b07f329c7aade50c7b987da941 Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Tue, 30 Jul 2024 03:39:01 +0300
Subject: [PATCH] Optimize weights handling

---
 invokeai/backend/model_patcher.py                  |  7 ++--
 .../stable_diffusion/extensions/base.py            | 18 +++++-----
 .../stable_diffusion/extensions/freeu.py           |  6 ++--
 .../stable_diffusion/extensions/lora.py            | 15 ++++----
 .../stable_diffusion/extensions_manager.py         |  8 ++---
 .../backend/util/original_weights_storage.py       | 39 ++++++++++++++++++++
 6 files changed, 66 insertions(+), 27 deletions(-)
 create mode 100644 invokeai/backend/util/original_weights_storage.py

diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index c0dc1dca1d..e2f22ba019 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 """
 loras = [
@@ -123,9 +124,7 @@ class ModelPatcher:
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
         :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        original_weights: Dict[str, torch.Tensor] = {}
-        if cached_weights:
-            original_weights.update(cached_weights)
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
             for lora_model, lora_weight in loras:
                 LoRAExt.patch_model(
@@ -141,7 +140,7 @@ class ModelPatcher:
 
         finally:
             with torch.no_grad():
-                for param_key, weight in original_weights.items():
+                for param_key, weight in original_weights.get_changed_weights():
                     model.get_parameter(param_key).copy_(weight)
 
     @classmethod
diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py
index 61276e0784..a3d27464a0 100644
--- a/invokeai/backend/stable_diffusion/extensions/base.py
+++ b/invokeai/backend/stable_diffusion/extensions/base.py
@@ -4,12 +4,12 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Callable, Dict, List
 
-import torch
 from diffusers import UNet2DConditionModel
 
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 @dataclass
@@ -56,17 +56,17 @@ class ExtensionBase:
         yield None
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
-        diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
-        `original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
-        All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
-        context manager.
+        diffusion process. Weight unpatching is handled upstream: before modifying a weight, an extension must save
+        its original value with `original_weights.save`. Only the saved (i.e. actually modified) weights are restored
+        afterwards, which avoids redundant restore operations. All other patches (e.g. changes to tensor shapes,
+        function monkey-patches, etc.) should be unpatched by this context manager.
 
         Args:
             unet (UNet2DConditionModel): The UNet model on execution device to patch.
-            original_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for
-                unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
-                access original weight value.
+            original_weights (OriginalWeightsStorage): A storage containing CPU copies of the model's original
+                weights, used for unpatching. An extension must save each tensor that it modifies to this storage,
+                and may also read original weight values back from it.
         """
         yield
diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py
index 75370d23f4..ff54e1a52f 100644
--- a/invokeai/backend/stable_diffusion/extensions/freeu.py
+++ b/invokeai/backend/stable_diffusion/extensions/freeu.py
@@ -1,15 +1,15 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 
-import torch
 from diffusers import UNet2DConditionModel
 
 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
 
 if TYPE_CHECKING:
     from invokeai.app.shared.models import FreeUConfig
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 class FreeUExt(ExtensionBase):
@@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py
index cfb97a2cb2..617bdcbbaf 100644
--- a/invokeai/backend/stable_diffusion/extensions/lora.py
+++ b/invokeai/backend/stable_diffusion/extensions/lora.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
     from invokeai.app.invocations.model import ModelIdentifierField
     from invokeai.app.services.shared.invocation_context import InvocationContext
     from invokeai.backend.lora import LoRAModelRaw
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 class LoRAExt(ExtensionBase):
@@ -28,7 +29,7 @@ class LoRAExt(ExtensionBase):
         self._weight = weight
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         lora_model = self._node_context.models.load(self._model_id).model
         self.patch_model(
             model=unet,
@@ -49,7 +50,7 @@ class LoRAExt(ExtensionBase):
         prefix: str,
         lora: LoRAModelRaw,
         lora_weight: float,
-        original_weights: Dict[str, torch.Tensor],
+        original_weights: OriginalWeightsStorage,
     ):
         """
         Apply one or more LoRAs to a model.
@@ -57,9 +58,12 @@ class LoRAExt(ExtensionBase):
         :param lora: LoRA model to patch in.
         :param lora_weight: LoRA patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param original_weights: Dict of original weights, filled by weights which lora patches, used for unpatching.
+        :param original_weights: Storage of original weights, populated with the weights that the LoRA patches; used for unpatching.
         """
 
+        if lora_weight == 0:
+            return
+
         # assert lora.device.type == "cpu"
         for layer_key, layer in lora.layers.items():
             if not layer_key.startswith(prefix):
@@ -95,8 +99,7 @@ class LoRAExt(ExtensionBase):
                 module_param = module.get_parameter(param_name)
 
                 # save original weight
-                if param_key not in original_weights:
-                    original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
+                original_weights.save(param_key, module_param)
 
                 if module_param.shape != lora_param_weight.shape:
                     # TODO: debug on lycoris
diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py
index 968dffd069..3783bb422e 100644
--- a/invokeai/backend/stable_diffusion/extensions_manager.py
+++ b/invokeai/backend/stable_diffusion/extensions_manager.py
@@ -7,6 +7,7 @@ import torch
 from diffusers import UNet2DConditionModel
 
 from invokeai.app.services.session_processor.session_processor_common import CanceledException
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@@ -67,10 +68,7 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        original_weights: Dict[str, torch.Tensor] = {}
-        if cached_weights:
-            original_weights.update(cached_weights)
-
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
             with ExitStack() as exit_stack:
                 for ext in self._extensions:
@@ -80,5 +78,5 @@ class ExtensionsManager:
 
         finally:
             with torch.no_grad():
-                for param_key, weight in original_weights.items():
+                for param_key, weight in original_weights.get_changed_weights():
                     unet.get_parameter(param_key).copy_(weight)
diff --git a/invokeai/backend/util/original_weights_storage.py b/invokeai/backend/util/original_weights_storage.py
new file mode 100644
index 0000000000..3632c52b09
--- /dev/null
+++ b/invokeai/backend/util/original_weights_storage.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from typing import Dict, Iterator, Optional, Set, Tuple
+
+import torch
+
+from invokeai.backend.util.devices import TorchDevice
+
+
+class OriginalWeightsStorage:
+    """Tracks the original weights of a model during patch/unpatch operations."""
+
+    def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        # CPU copies of the original weights, keyed by parameter name.
+        self._weights: Dict[str, torch.Tensor] = {}
+        # Keys of the weights that have been changed (via `save()`) during the session.
+        self._changed_weights: Set[str] = set()
+        if cached_weights:
+            self._weights.update(cached_weights)
+
+    def save(self, key: str, weight: torch.Tensor, copy: bool = True):
+        self._changed_weights.add(key)
+        if key in self._weights:
+            return
+
+        self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy)
+
+    def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]:
+        weight = self._weights.get(key, None)
+        if weight is not None and copy:
+            weight = weight.clone()
+        return weight
+
+    def contains(self, key: str) -> bool:
+        return key in self._weights
+
+    def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
+        for key in self._changed_weights:
+            yield key, self._weights[key]
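
Usage note (illustrative, not part of the patch): the sketch below walks through the save/restore flow that this patch introduces. `TinyModel` and the in-place scaling are hypothetical stand-ins for a real extension patch such as LoRA or FreeU; only the `OriginalWeightsStorage` calls mirror the code in this diff.

import torch

from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)


model = TinyModel()
storage = OriginalWeightsStorage()

try:
    with torch.no_grad():
        for key, param in model.named_parameters():
            if not key.endswith("weight"):
                continue  # this toy patch touches only weight tensors
            # Record the original value before mutating it. Repeated saves of
            # the same key are no-ops, so stacked patches stay cheap.
            storage.save(key, param)
            param.mul_(2.0)  # stand-in for a real extension's weight patch
    # ... run inference with the patched model here ...
finally:
    with torch.no_grad():
        # Restore only the weights recorded via save(); the bias was never
        # touched, so it is skipped entirely.
        for key, weight in storage.get_changed_weights():
            model.get_parameter(key).copy_(weight)

Because save() records a key only once, stacked patches that modify the same weight still restore the true original, and the restore loop copies back only the modified parameters instead of iterating over the whole state dict.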