Suggested changes + simplify weights logic in patching

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov 2024-07-30 00:34:37 +03:00
parent 8500bac3ca
commit 2227a2357f
6 changed files with 76 additions and 108 deletions

View File

@ -490,6 +490,9 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items(): for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon # lora and locon
if "lora_up.weight" in values: if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values) layer: AnyLoRALayer = LoRALayer(layer_key, values)

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import pickle import pickle
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
import numpy as np import numpy as np
import torch import torch
@ -123,34 +123,25 @@ class ModelPatcher:
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers. :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
""" """
modified_cached_weights: Set[str] = set() original_weights: Dict[str, torch.Tensor] = {}
modified_weights: Dict[str, torch.Tensor] = {} if cached_weights:
original_weights.update(cached_weights)
try: try:
for lora_model, lora_weight in loras: for lora_model, lora_weight in loras:
lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model( LoRAExt.patch_model(
model=model, model=model,
prefix=prefix, prefix=prefix,
lora=lora_model, lora=lora_model,
lora_weight=lora_weight, lora_weight=lora_weight,
cached_weights=cached_weights, original_weights=original_weights,
) )
del lora_model del lora_model
modified_cached_weights.update(lora_modified_cached_weights)
# Store only first returned weight for each key, because
# next extension which changes it, will work with already modified weight
for param_key, weight in lora_modified_weights.items():
if param_key in modified_weights:
continue
modified_weights[param_key] = weight
yield yield
finally: finally:
with torch.no_grad(): with torch.no_grad():
for param_key in modified_cached_weights: for param_key, weight in original_weights.items():
model.get_parameter(param_key).copy_(cached_weights[param_key])
for param_key, weight in modified_weights.items():
model.get_parameter(param_key).copy_(weight) model.get_parameter(param_key).copy_(weight)
@classmethod @classmethod

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Callable, Dict, List
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
@ -56,17 +56,17 @@ class ExtensionBase:
yield None yield None
@contextmanager @contextmanager
def patch_unet( def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
) -> Tuple[Set[str], Dict[str, torch.Tensor]]: diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
"""Apply patches to UNet model. This function responsible for restoring all changes except weights, `original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
changed weights should only be reported in return. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
Return contains 2 values: context manager.
- Set of cached weights, just keys from cached_weights dictionary
- Dict of not cached weights that should be copies on the cpu device
Args: Args:
unet (UNet2DConditionModel): The UNet model on execution device to patch. unet (UNet2DConditionModel): The UNet model on execution device to patch.
cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes. cached_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for
unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
access original weight value.
""" """
yield set(), {} yield

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
@ -21,9 +21,7 @@ class FreeUExt(ExtensionBase):
self._freeu_config = freeu_config self._freeu_config = freeu_config
@contextmanager @contextmanager
def patch_unet( def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
unet.enable_freeu( unet.enable_freeu(
b1=self._freeu_config.b1, b1=self._freeu_config.b1,
b2=self._freeu_config.b2, b2=self._freeu_config.b2,
@ -32,6 +30,6 @@ class FreeUExt(ExtensionBase):
) )
try: try:
yield set(), {} yield
finally: finally:
unet.disable_freeu() unet.disable_freeu()

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict, Tuple
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
@ -28,97 +28,84 @@ class LoRAExt(ExtensionBase):
self._weight = weight self._weight = weight
@contextmanager @contextmanager
def patch_unet( def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
lora_model = self._node_context.models.load(self._model_id).model lora_model = self._node_context.models.load(self._model_id).model
modified_cached_weights, modified_weights = self.patch_model( self.patch_model(
model=unet, model=unet,
prefix="lora_unet_", prefix="lora_unet_",
lora=lora_model, lora=lora_model,
lora_weight=self._weight, lora_weight=self._weight,
cached_weights=cached_weights, original_weights=original_weights,
) )
del lora_model del lora_model
yield modified_cached_weights, modified_weights yield
@classmethod @classmethod
@torch.no_grad()
def patch_model( def patch_model(
cls, cls,
model: torch.nn.Module, model: torch.nn.Module,
prefix: str, prefix: str,
lora: LoRAModelRaw, lora: LoRAModelRaw,
lora_weight: float, lora_weight: float,
cached_weights: Optional[Dict[str, torch.Tensor]] = None, original_weights: Dict[str, torch.Tensor],
) -> Tuple[Set[str], Dict[str, torch.Tensor]]: ):
""" """
Apply one or more LoRAs to a model. Apply one or more LoRAs to a model.
:param model: The model to patch. :param model: The model to patch.
:param lora: LoRA model to patch in. :param lora: LoRA model to patch in.
:param lora_weight: LoRA patch weight. :param lora_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers. :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. :param original_weights: TODO:
""" """
if cached_weights is None:
cached_weights = {}
modified_weights: Dict[str, torch.Tensor] = {} # assert lora.device.type == "cpu"
modified_cached_weights: Set[str] = set() for layer_key, layer in lora.layers.items():
with torch.no_grad(): if not layer_key.startswith(prefix):
# assert lora.device.type == "cpu" continue
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways: # should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied. # LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys. # weights to have valid keys.
assert isinstance(model, torch.nn.Module) assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix) module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight. # All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.) # (Performance will be best if this is a CUDA device.)
device = module.weight.device device = module.weight.device
dtype = module.weight.dtype dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to # We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'. # same thing in a single call to '.to(...)'.
layer.to(device=device) layer.to(device=device)
layer.to(dtype=torch.float32) layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items(): for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name) module_param = module.get_parameter(param_name)
# save original weight # save original weight
if param_key not in modified_cached_weights and param_key not in modified_weights: if param_key not in original_weights:
if param_key in cached_weights: original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
modified_cached_weights.add(param_key)
else:
modified_weights[param_key] = module_param.detach().to(
device=TorchDevice.CPU_DEVICE, copy=True
)
if module_param.shape != lora_param_weight.shape: if module_param.shape != lora_param_weight.shape:
# TODO: debug on lycoris # TODO: debug on lycoris
lora_param_weight = lora_param_weight.reshape(module_param.shape) lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= lora_weight * layer_scale lora_param_weight *= lora_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype) module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE) layer.to(device=TorchDevice.CPU_DEVICE)
return modified_cached_weights, modified_weights
@staticmethod @staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
@ -67,29 +67,18 @@ class ExtensionsManager:
if self._is_canceled and self._is_canceled(): if self._is_canceled and self._is_canceled():
raise CanceledException raise CanceledException
modified_weights: Dict[str, torch.Tensor] = {} original_weights: Dict[str, torch.Tensor] = {}
modified_cached_weights: Set[str] = set() if cached_weights:
original_weights.update(cached_weights)
try: try:
with ExitStack() as exit_stack: with ExitStack() as exit_stack:
for ext in self._extensions: for ext in self._extensions:
ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context( exit_stack.enter_context(ext.patch_unet(unet, original_weights))
ext.patch_unet(unet, cached_weights)
)
modified_cached_weights.update(ext_modified_cached_weights)
# store only first returned weight for each key, because
# next extension which changes it, will work with already modified weight
for param_key, weight in ext_modified_weights.items():
if param_key in modified_weights:
continue
modified_weights[param_key] = weight
yield None yield None
finally: finally:
with torch.no_grad(): with torch.no_grad():
for param_key in modified_cached_weights: for param_key, weight in original_weights.items():
unet.get_parameter(param_key).copy_(cached_weights[param_key])
for param_key, weight in modified_weights.items():
unet.get_parameter(param_key).copy_(weight) unet.get_parameter(param_key).copy_(weight)