Optimize weights handling

Sergey Borisov 2024-07-30 03:39:01 +03:00
parent 1fd9631f2d
commit 86f705bf48
6 changed files with 62 additions and 27 deletions

invokeai/backend/model_patcher.py

@@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

 """
 loras = [
@@ -123,9 +124,7 @@ class ModelPatcher:
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
         :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        original_weights: Dict[str, torch.Tensor] = {}
-        if cached_weights:
-            original_weights.update(cached_weights)
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
             for lora_model, lora_weight in loras:
                 LoRAExt.patch_model(
@@ -141,7 +140,7 @@
         finally:
             with torch.no_grad():
-                for param_key, weight in original_weights.items():
+                for param_key, weight in original_weights.get_changed_weights():
                     model.get_parameter(param_key).copy_(weight)

     @classmethod
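
The effect of this change is that the finally-block now restores only the parameters that were actually patched, rather than copying back every cached weight. A minimal sketch of the difference, using an invented two-parameter state dict (none of these names come from the commit):

import torch

# Invented two-parameter "model" state dict and its CPU cache.
state = {"a.weight": torch.zeros(2), "b.weight": torch.zeros(2)}
cached = {k: v.clone() for k, v in state.items()}

# Old behavior: every cached weight is copied back on unpatch.
restored_old = list(cached.items())  # would issue 2 copy_() calls

# New behavior: only keys that were save()'d are restored.
changed = {"a.weight"}  # suppose a LoRA touched just this parameter
restored_new = [(k, cached[k]) for k in changed]  # 1 copy_() call

assert len(restored_old) == 2 and len(restored_new) == 1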

invokeai/backend/stable_diffusion/extensions/base.py

@@ -4,12 +4,12 @@ from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List
+from typing import TYPE_CHECKING, Callable, List

-import torch
 from diffusers import UNet2DConditionModel

 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 @dataclass
@@ -56,17 +56,17 @@ class ExtensionBase:
         yield None

     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
-        diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
-        `original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
-        All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
-        context manager.
+        diffusion process. Weight unpatching is handled upstream, and is achieved by saving the original weights via
+        the `original_weights.save` function. Note that this enables some performance optimization by avoiding
+        redundant operations. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should
+        be unpatched by this context manager.

         Args:
             unet (UNet2DConditionModel): The UNet model on execution device to patch.
-            original_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for
-                unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
-                access original weight value.
+            original_weights (OriginalWeightsStorage): Storage with a CPU copy of the model's original weights, for
+                unpatching purposes. An extension should save each tensor that it modifies to this storage, and may
+                also read the original weight values from it.
         """
         yield
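
For illustration, a hypothetical extension showing how the new signature is meant to be used: it snapshots a weight through `original_weights.save` before mutating it, and relies on the upstream manager for restoration. The `ScaleConvInExt` class and the chosen parameter key are invented for this sketch; only `ExtensionBase`, `patch_unet`, and `OriginalWeightsStorage` come from this commit.

from contextlib import contextmanager

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class ScaleConvInExt(ExtensionBase):
    """Hypothetical extension that halves the UNet's conv_in weight."""

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        key = "conv_in.weight"
        # Snapshot first: a CPU copy is stored only on the first save() of this key.
        original_weights.save(key, unet.get_parameter(key))
        with torch.no_grad():
            unet.get_parameter(key).mul_(0.5)
        yield
        # No manual restore here: the manager copies back every changed weight.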

invokeai/backend/stable_diffusion/extensions/freeu.py

@@ -1,15 +1,15 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING

-import torch
 from diffusers import UNet2DConditionModel

 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

 if TYPE_CHECKING:
     from invokeai.app.shared.models import FreeUConfig
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 class FreeUExt(ExtensionBase):
@@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config

     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,

invokeai/backend/stable_diffusion/extensions/lora.py

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Tuple

 import torch
 from diffusers import UNet2DConditionModel
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
     from invokeai.app.invocations.model import ModelIdentifierField
     from invokeai.app.services.shared.invocation_context import InvocationContext
     from invokeai.backend.lora import LoRAModelRaw
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 class LoRAExt(ExtensionBase):
@@ -28,7 +29,7 @@ class LoRAExt(ExtensionBase):
         self._weight = weight

     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         lora_model = self._node_context.models.load(self._model_id).model
         self.patch_model(
             model=unet,
@@ -49,7 +50,7 @@
         prefix: str,
         lora: LoRAModelRaw,
         lora_weight: float,
-        original_weights: Dict[str, torch.Tensor],
+        original_weights: OriginalWeightsStorage,
     ):
         """
         Apply one or more LoRAs to a model.
@@ -57,9 +58,12 @@
         :param lora: LoRA model to patch in.
         :param lora_weight: LoRA patch weight.
        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param original_weights: Dict of original weights, filled by weights which lora patches, used for unpatching.
+        :param original_weights: Storage of original weights, filled with the weights that the LoRA patches, used for unpatching.
         """
+        if lora_weight == 0:
+            return
+
        # assert lora.device.type == "cpu"
         for layer_key, layer in lora.layers.items():
             if not layer_key.startswith(prefix):
@@ -95,8 +99,7 @@
                 module_param = module.get_parameter(param_name)

                 # save original weight
-                if param_key not in original_weights:
-                    original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
+                original_weights.save(param_key, module_param)

                 if module_param.shape != lora_param_weight.shape:
                     # TODO: debug on lycoris
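
The patching loop above reduces to "snapshot, then modify in place". A hedged sketch of that idiom as a standalone helper; the `patch_param` name and signature are invented, while `OriginalWeightsStorage.save` is from this commit:

import torch

from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


def patch_param(
    module: torch.nn.Module,
    param_key: str,
    delta: torch.Tensor,
    weight: float,
    storage: OriginalWeightsStorage,
) -> None:
    """Invented helper: snapshot a parameter, then add a scaled delta in place."""
    param = module.get_parameter(param_key)
    storage.save(param_key, param)  # stores a CPU copy only on the first save of this key
    with torch.no_grad():
        param.add_(delta.to(device=param.device, dtype=param.dtype), alpha=weight)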

invokeai/backend/stable_diffusion/extensions_manager.py

@@ -7,6 +7,7 @@ import torch
 from diffusers import UNet2DConditionModel

 from invokeai.app.services.session_processor.session_processor_common import CanceledException
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@@ -67,10 +68,7 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException

-        original_weights: Dict[str, torch.Tensor] = {}
-        if cached_weights:
-            original_weights.update(cached_weights)
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
             with ExitStack() as exit_stack:
                 for ext in self._extensions:
@@ -80,5 +78,5 @@
         finally:
             with torch.no_grad():
-                for param_key, weight in original_weights.items():
+                for param_key, weight in original_weights.get_changed_weights():
                     unet.get_parameter(param_key).copy_(weight)

invokeai/backend/util/original_weights_storage.py

@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from typing import Dict, Iterator, Optional, Tuple
+
+import torch
+
+from invokeai.backend.util.devices import TorchDevice
+
+
+class OriginalWeightsStorage:
+    def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        self._weights = {}
+        self._changed_weights = set()
+        if cached_weights:
+            self._weights.update(cached_weights)
+
+    def save(self, key: str, weight: torch.Tensor, copy: bool = True):
+        self._changed_weights.add(key)
+        if key in self._weights:
+            return
+
+        self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy)
+
+    def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]:
+        weight = self._weights.get(key, None)
+        if weight is not None and copy:
+            weight = weight.clone()
+        return weight
+
+    def contains(self, key: str) -> bool:
+        return key in self._weights
+
+    def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
+        for key in self._changed_weights:
+            yield key, self._weights[key]
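
A small usage sketch of the storage's semantics, assuming the module path introduced above: save() marks a key as changed but keeps only the first value it sees, so later mutations of the live tensor do not leak into the snapshot, and get_changed_weights() yields only the keys that were saved:

import torch

from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

storage = OriginalWeightsStorage(cached_weights={"a.weight": torch.ones(2)})

w = torch.zeros(2)
storage.save("b.weight", w)   # first save: CPU copy stored, key marked changed
w += 1.0                      # later mutation does not affect the stored copy
storage.save("b.weight", w)   # second save: key already present, first value kept

assert storage.contains("a.weight") and storage.contains("b.weight")
assert torch.equal(storage.get("b.weight"), torch.zeros(2))
# Only "b.weight" was saved, so only it is yielded for restore on unpatch:
assert dict(storage.get_changed_weights()).keys() == {"b.weight"}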