Suggested changes + simplify weights logic in patching

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
Sergey Borisov 2024-07-30 00:34:37 +03:00
parent 8500bac3ca
commit 2227a2357f
6 changed files with 76 additions and 108 deletions

@@ -490,6 +490,9 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

         for layer_key, values in state_dict.items():
+            # Detect layers according to LyCORIS detection logic(`weight_list_det`)
+            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
+
             # lora and locon
             if "lora_up.weight" in values:
                 layer: AnyLoRALayer = LoRALayer(layer_key, values)
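For readers unfamiliar with the detection scheme the new comment points to: the layer type is inferred purely from which keys appear in each layer's slice of the state dict. Below is a minimal, self-contained sketch of that idea; the non-LoRA key names are assumptions based on common LyCORIS formats, not taken from this commit.

# Illustrative sketch only -- not part of this commit. Mirrors the
# LyCORIS-style "weight_list_det" detection referenced in the comment above:
# the layer type is chosen by which weight keys are present in the slice.
def detect_layer_type(values: dict) -> str:
    if "lora_up.weight" in values:
        return "lora/locon"
    if "hada_w1_a" in values:                          # assumed LoHA key
        return "loha"
    if "lokr_w1" in values or "lokr_w1_a" in values:   # assumed LoKR keys
        return "lokr"
    if "diff" in values:                               # assumed full-diff key
        return "full"
    return "unknown"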

@@ -5,7 +5,7 @@ from __future__ import annotations

 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch
@@ -123,34 +123,25 @@ class ModelPatcher:
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
         :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        modified_cached_weights: Set[str] = set()
-        modified_weights: Dict[str, torch.Tensor] = {}
+        original_weights: Dict[str, torch.Tensor] = {}
+        if cached_weights:
+            original_weights.update(cached_weights)

         try:
             for lora_model, lora_weight in loras:
-                lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
+                LoRAExt.patch_model(
                     model=model,
                     prefix=prefix,
                     lora=lora_model,
                     lora_weight=lora_weight,
-                    cached_weights=cached_weights,
+                    original_weights=original_weights,
                 )
                 del lora_model
-                modified_cached_weights.update(lora_modified_cached_weights)
-                # Store only first returned weight for each key, because
-                # next extension which changes it, will work with already modified weight
-                for param_key, weight in lora_modified_weights.items():
-                    if param_key in modified_weights:
-                        continue
-                    modified_weights[param_key] = weight

             yield
         finally:
             with torch.no_grad():
-                for param_key in modified_cached_weights:
-                    model.get_parameter(param_key).copy_(cached_weights[param_key])
-                for param_key, weight in modified_weights.items():
+                for param_key, weight in original_weights.items():
                     model.get_parameter(param_key).copy_(weight)

     @classmethod
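The pattern that replaces the two-container bookkeeping above is worth spelling out: one dict doubles as the CPU cache and the record of which parameters were touched, so unpatching is a single restore loop. A self-contained sketch of just that pattern (generic torch.nn.Module; the function name is illustrative):

import torch
from contextlib import contextmanager
from typing import Dict, Optional

# Illustrative sketch only -- not part of this commit.
@contextmanager
def apply_patches(model: torch.nn.Module, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
    # Seed the dict with any pre-cached CPU weights; patchers add the original
    # value of every other parameter the first time they modify it.
    original_weights: Dict[str, torch.Tensor] = dict(cached_weights or {})
    try:
        yield original_weights  # patching happens inside the `with` block
    finally:
        with torch.no_grad():
            for param_key, weight in original_weights.items():
                model.get_parameter(param_key).copy_(weight)

Note the trade-off this implies: when cached_weights is supplied, every cached key is restored on exit whether or not it was modified, exchanging some extra copying at teardown for much less bookkeeping during patching.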

@@ -2,7 +2,7 @@ from __future__ import annotations

 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List

 import torch
 from diffusers import UNet2DConditionModel
@@ -56,17 +56,17 @@ class ExtensionBase:
         yield None

     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
-        """Apply patches to UNet model. This function responsible for restoring all changes except weights,
-        changed weights should only be reported in return.
-        Return contains 2 values:
-        - Set of cached weights, just keys from cached_weights dictionary
-        - Dict of not cached weights that should be copies on the cpu device
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+        """A context manager for applying patches to the UNet model. The context manager's lifetime spans the
+        entire diffusion process. Weight unpatching is handled upstream by saving the original value of each
+        modified weight in the `original_weights` dict, which avoids redundant save/restore work. All other
+        patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
+        context manager itself.

         Args:
             unet (UNet2DConditionModel): The UNet model on the execution device to patch.
-            cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes.
+            original_weights (Dict[str, torch.Tensor]): A CPU copy of the model's original weights, kept for
+                unpatching purposes. An extension must save a tensor here before modifying it (unless it is
+                already saved) and may read the original weight value from it.
         """
-        yield set(), {}
+        yield
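To make the new contract concrete, here is a hedged sketch of a weight-modifying extension under this interface. The extension name and the particular weight tweak are invented for illustration; only the `original_weights` handshake mirrors the docstring above, and the imports are the ones already present in this module.

# Illustrative sketch only -- not part of this commit.
class ScaleConvInExt(ExtensionBase):
    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
        param_key = "conv_in.weight"
        param = unet.get_parameter(param_key)
        # Record the original weight once; the manager restores it after diffusion.
        if param_key not in original_weights:
            original_weights[param_key] = param.detach().to("cpu", copy=True)
        with torch.no_grad():
            param.mul_(0.5)  # hypothetical weight modification
        # No non-weight changes to undo here; weight restore is handled upstream.
        yield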

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict

 import torch
 from diffusers import UNet2DConditionModel
@@ -21,9 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config

     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
@@ -32,6 +30,6 @@ class FreeUExt(ExtensionBase):
         )

         try:
-            yield set(), {}
+            yield
         finally:
             unet.disable_freeu()

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Tuple

 import torch
 from diffusers import UNet2DConditionModel
@@ -28,44 +28,38 @@ class LoRAExt(ExtensionBase):
         self._weight = weight

     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
         lora_model = self._node_context.models.load(self._model_id).model
-        modified_cached_weights, modified_weights = self.patch_model(
+        self.patch_model(
             model=unet,
             prefix="lora_unet_",
             lora=lora_model,
             lora_weight=self._weight,
-            cached_weights=cached_weights,
+            original_weights=original_weights,
         )
         del lora_model

-        yield modified_cached_weights, modified_weights
+        yield

     @classmethod
+    @torch.no_grad()
     def patch_model(
         cls,
         model: torch.nn.Module,
         prefix: str,
         lora: LoRAModelRaw,
         lora_weight: float,
-        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+        original_weights: Dict[str, torch.Tensor],
+    ):
         """
         Apply one or more LoRAs to a model.

         :param model: The model to patch.
         :param lora: LoRA model to patch in.
         :param lora_weight: LoRA patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :param original_weights: Dict in which the original value of each patched weight is saved (on the CPU)
+            before it is modified, for later unpatching.
         """
-        if cached_weights is None:
-            cached_weights = {}
-
-        modified_weights: Dict[str, torch.Tensor] = {}
-        modified_cached_weights: Set[str] = set()
-        with torch.no_grad():
-            # assert lora.device.type == "cpu"
-            for layer_key, layer in lora.layers.items():
-                if not layer_key.startswith(prefix):
+        # assert lora.device.type == "cpu"
+        for layer_key, layer in lora.layers.items():
+            if not layer_key.startswith(prefix):
@@ -101,13 +95,8 @@ class LoRAExt(ExtensionBase):
                 module_param = module.get_parameter(param_name)

                 # save original weight
-                if param_key not in modified_cached_weights and param_key not in modified_weights:
-                    if param_key in cached_weights:
-                        modified_cached_weights.add(param_key)
-                    else:
-                        modified_weights[param_key] = module_param.detach().to(
-                            device=TorchDevice.CPU_DEVICE, copy=True
-                        )
+                if param_key not in original_weights:
+                    original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)

                 if module_param.shape != lora_param_weight.shape:
                     # TODO: debug on lycoris
@@ -118,8 +107,6 @@ class LoRAExt(ExtensionBase):

                 layer.to(device=TorchDevice.CPU_DEVICE)

-        return modified_cached_weights, modified_weights
-
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
         assert "." not in lora_key

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional

 import torch
 from diffusers import UNet2DConditionModel
@@ -67,29 +67,18 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException

-        modified_weights: Dict[str, torch.Tensor] = {}
-        modified_cached_weights: Set[str] = set()
+        original_weights: Dict[str, torch.Tensor] = {}
+        if cached_weights:
+            original_weights.update(cached_weights)

         try:
             with ExitStack() as exit_stack:
                 for ext in self._extensions:
-                    ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
-                        ext.patch_unet(unet, cached_weights)
-                    )
-                    modified_cached_weights.update(ext_modified_cached_weights)
-                    # store only first returned weight for each key, because
-                    # next extension which changes it, will work with already modified weight
-                    for param_key, weight in ext_modified_weights.items():
-                        if param_key in modified_weights:
-                            continue
-                        modified_weights[param_key] = weight
+                    exit_stack.enter_context(ext.patch_unet(unet, original_weights))

                 yield None

         finally:
             with torch.no_grad():
-                for param_key in modified_cached_weights:
-                    unet.get_parameter(param_key).copy_(cached_weights[param_key])
-                for param_key, weight in modified_weights.items():
+                for param_key, weight in original_weights.items():
                     unet.get_parameter(param_key).copy_(weight)
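The net effect of the refactor: because every patcher checks `param_key not in original_weights` before saving, the first writer for a parameter records the true original, later writers leave it alone, and the duplicated "store only first returned weight" bookkeeping disappears from both call sites. A small self-contained sketch of that first-writer-wins behaviour, using a toy nn.Linear in place of the UNet and throwaway patchers (all names here are illustrative):

import torch
from contextlib import ExitStack, contextmanager
from typing import Dict

# Illustrative sketch only -- not part of this commit.
@contextmanager
def scale_weight(model: torch.nn.Module, key: str, factor: float, original_weights: Dict[str, torch.Tensor]):
    param = model.get_parameter(key)
    if key not in original_weights:  # first writer records the true original
        original_weights[key] = param.detach().to("cpu", copy=True)
    with torch.no_grad():
        param.mul_(factor)
    yield

model = torch.nn.Linear(4, 4)
original_weights: Dict[str, torch.Tensor] = {}
with ExitStack() as stack:
    stack.enter_context(scale_weight(model, "weight", 2.0, original_weights))
    stack.enter_context(scale_weight(model, "weight", 3.0, original_weights))  # saved original is NOT overwritten
    # ... run inference with the patched model here ...
with torch.no_grad():  # single restore loop, as in ExtensionsManager above
    for key, weight in original_weights.items():
        model.get_parameter(key).copy_(weight)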