Optimize weights handling

Sergey Borisov 2024-07-30 03:39:01 +03:00
parent 1fd9631f2d
commit 86f705bf48
6 changed files with 62 additions and 27 deletions

invokeai/backend/model_patcher.py

@@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 """
 loras = [
@@ -123,9 +124,7 @@ class ModelPatcher:
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
         :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        original_weights: Dict[str, torch.Tensor] = {}
-        if cached_weights:
-            original_weights.update(cached_weights)
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
             for lora_model, lora_weight in loras:
                 LoRAExt.patch_model(
@@ -141,7 +140,7 @@
         finally:
             with torch.no_grad():
-                for param_key, weight in original_weights.items():
+                for param_key, weight in original_weights.get_changed_weights():
                     model.get_parameter(param_key).copy_(weight)
 
     @classmethod
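Taken together, the ModelPatcher change swaps the ad-hoc dict bookkeeping for OriginalWeightsStorage, so the finally block restores only the weights that were actually modified rather than every cached weight. Below is a minimal sketch of the resulting patch/unpatch pattern; the model and the in-place tweak are hypothetical stand-ins, not InvokeAI call sites:

import torch

from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


def patch_and_restore(model: torch.nn.Module, cached_weights=None):
    # Seed the storage with an optional read-only CPU copy of the state dict.
    original_weights = OriginalWeightsStorage(cached_weights)
    try:
        with torch.no_grad():
            for key, param in model.named_parameters():
                original_weights.save(key, param)  # CPU copy, skipped if already cached
                param.add_(0.1)  # stand-in for a real weight patch
    finally:
        # Restore only the weights that were actually changed.
        with torch.no_grad():
            for param_key, weight in original_weights.get_changed_weights():
                model.get_parameter(param_key).copy_(weight)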

invokeai/backend/stable_diffusion/extensions/base.py

@@ -4,12 +4,12 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Callable, Dict, List
 
-import torch
 from diffusers import UNet2DConditionModel
 
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 @dataclass
@@ -56,17 +56,17 @@ class ExtensionBase:
         yield None
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
-        diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
-        `original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
-        All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
-        context manager.
+        diffusion process. Weight unpatching is handled upstream, and is achieved by saving the original weights to
+        `original_weights` (via its `save` method) before modifying them. Note that this enables some performance
+        optimization by avoiding redundant operations. All other patches (e.g. changes to tensor shapes, function
+        monkey-patches, etc.) should be unpatched by this context manager.
 
         Args:
             unet (UNet2DConditionModel): The UNet model on execution device to patch.
-            original_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for
-                unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
-                access original weight value.
+            original_weights (OriginalWeightsStorage): A storage of the model's original weights on the CPU, used for
+                unpatching. An extension should save each tensor it modifies to this storage; it can also read the
+                original weight values back from it.
         """
         yield
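To make the contract above concrete, here is a hypothetical extension that modifies a single UNet weight; the class name, the targeted parameter, and the scale factor are illustrative, not part of the codebase:

from contextlib import contextmanager

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class ScaleConvInExt(ExtensionBase):
    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        param_key = "conv_in.weight"
        # Save the original value before modifying it; restoring happens upstream.
        original_weights.save(param_key, unet.get_parameter(param_key))
        with torch.no_grad():
            unet.get_parameter(param_key).mul_(0.5)
        yield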

invokeai/backend/stable_diffusion/extensions/freeu.py

@@ -1,15 +1,15 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 
-import torch
 from diffusers import UNet2DConditionModel
 
 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
 
 if TYPE_CHECKING:
     from invokeai.app.shared.models import FreeUConfig
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 class FreeUExt(ExtensionBase):
@@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,

invokeai/backend/stable_diffusion/extensions/lora.py

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
     from invokeai.app.invocations.model import ModelIdentifierField
     from invokeai.app.services.shared.invocation_context import InvocationContext
     from invokeai.backend.lora import LoRAModelRaw
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 class LoRAExt(ExtensionBase):
@@ -28,7 +29,7 @@ class LoRAExt(ExtensionBase):
         self._weight = weight
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         lora_model = self._node_context.models.load(self._model_id).model
         self.patch_model(
             model=unet,
@@ -49,7 +50,7 @@ class LoRAExt(ExtensionBase):
         prefix: str,
         lora: LoRAModelRaw,
         lora_weight: float,
-        original_weights: Dict[str, torch.Tensor],
+        original_weights: OriginalWeightsStorage,
     ):
         """
         Apply one or more LoRAs to a model.
@@ -57,9 +58,12 @@ class LoRAExt(ExtensionBase):
         :param lora: LoRA model to patch in.
         :param lora_weight: LoRA patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param original_weights: Dict of original weights, filled by weights which lora patches, used for unpatching.
+        :param original_weights: Storage of original weights, populated with the weights the LoRA patches; used for unpatching.
         """
+        if lora_weight == 0:
+            return
+
         # assert lora.device.type == "cpu"
         for layer_key, layer in lora.layers.items():
             if not layer_key.startswith(prefix):
@@ -95,8 +99,7 @@ class LoRAExt(ExtensionBase):
                 module_param = module.get_parameter(param_name)
 
                 # save original weight
-                if param_key not in original_weights:
-                    original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
+                original_weights.save(param_key, module_param)
 
                 if module_param.shape != lora_param_weight.shape:
                     # TODO: debug on lycoris
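The new early return in patch_model is a pure optimization: a LoRA patch updates each weight as W <- W + lora_weight * delta_W, so a zero weight leaves the model untouched, yet without the early return every affected parameter would still be snapshotted to CPU by original_weights.save() and copied back during unpatching. A quick sanity check of that arithmetic:

import torch

W = torch.randn(4, 4)
delta_W = torch.randn(4, 4)

# With lora_weight == 0 the patched weight equals the original exactly,
# so the whole save/patch/restore cycle can be skipped.
assert torch.equal(W + 0.0 * delta_W, W)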

invokeai/backend/stable_diffusion/extensions_manager.py

@@ -7,6 +7,7 @@ import torch
 from diffusers import UNet2DConditionModel
 
 from invokeai.app.services.session_processor.session_processor_common import CanceledException
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@@ -67,10 +68,7 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        original_weights: Dict[str, torch.Tensor] = {}
-        if cached_weights:
-            original_weights.update(cached_weights)
-
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
             with ExitStack() as exit_stack:
                 for ext in self._extensions:
@@ -80,5 +78,5 @@ class ExtensionsManager:
         finally:
             with torch.no_grad():
-                for param_key, weight in original_weights.items():
+                for param_key, weight in original_weights.get_changed_weights():
                     unet.get_parameter(param_key).copy_(weight)

invokeai/backend/util/original_weights_storage.py (new file)

@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from typing import Dict, Iterator, Optional, Tuple
+
+import torch
+
+from invokeai.backend.util.devices import TorchDevice
+
+
+class OriginalWeightsStorage:
+    def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        self._weights: Dict[str, torch.Tensor] = {}
+        self._changed_weights: set[str] = set()
+        if cached_weights:
+            self._weights.update(cached_weights)
+
+    def save(self, key: str, weight: torch.Tensor, copy: bool = True):
+        # Mark the key as changed, but keep the first value seen as the original.
+        self._changed_weights.add(key)
+        if key in self._weights:
+            return
+        self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy)
+
+    def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]:
+        weight = self._weights.get(key, None)
+        if weight is not None and copy:
+            weight = weight.clone()
+        return weight
+
+    def contains(self, key: str) -> bool:
+        return key in self._weights
+
+    def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
+        for key in self._changed_weights:
+            yield key, self._weights[key]
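A short usage example of the new storage (tensor values are illustrative). Note that save() records only the first value seen for a key, so when the storage is seeded with cached weights, the cached tensor is kept as the original regardless of what is passed later:

import torch

from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

storage = OriginalWeightsStorage(cached_weights={"a": torch.zeros(2)})

storage.save("a", torch.ones(2))  # "a" marked changed; cached zeros kept as original
storage.save("b", torch.full((2,), 3.0))  # new key: copied to CPU as its original

assert storage.contains("a") and storage.contains("b")
assert torch.equal(storage.get("a"), torch.zeros(2))  # original value, not the new one
assert sorted(key for key, _ in storage.get_changed_weights()) == ["a", "b"]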