Suggested changes + simplify weights logic in patching

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov 2024-07-30 00:34:37 +03:00
parent 8500bac3ca
commit 2227a2357f
6 changed files with 76 additions and 108 deletions

View File

@ -490,6 +490,9 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
import numpy as np
import torch
@ -123,34 +123,25 @@ class ModelPatcher:
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
modified_cached_weights: Set[str] = set()
modified_weights: Dict[str, torch.Tensor] = {}
original_weights: Dict[str, torch.Tensor] = {}
if cached_weights:
original_weights.update(cached_weights)
try:
for lora_model, lora_weight in loras:
lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
LoRAExt.patch_model(
model=model,
prefix=prefix,
lora=lora_model,
lora_weight=lora_weight,
cached_weights=cached_weights,
original_weights=original_weights,
)
del lora_model
modified_cached_weights.update(lora_modified_cached_weights)
# Store only first returned weight for each key, because
# next extension which changes it, will work with already modified weight
for param_key, weight in lora_modified_weights.items():
if param_key in modified_weights:
continue
modified_weights[param_key] = weight
yield
finally:
with torch.no_grad():
for param_key in modified_cached_weights:
model.get_parameter(param_key).copy_(cached_weights[param_key])
for param_key, weight in modified_weights.items():
for param_key, weight in original_weights.items():
model.get_parameter(param_key).copy_(weight)
@classmethod

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Callable, Dict, List
import torch
from diffusers import UNet2DConditionModel
@ -56,17 +56,17 @@ class ExtensionBase:
yield None
@contextmanager
def patch_unet(
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
"""Apply patches to UNet model. This function responsible for restoring all changes except weights,
changed weights should only be reported in return.
Return contains 2 values:
- Set of cached weights, just keys from cached_weights dictionary
- Dict of not cached weights that should be copies on the cpu device
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
"""A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
`original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
context manager.
Args:
unet (UNet2DConditionModel): The UNet model on execution device to patch.
cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes.
cached_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for
unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
access original weight value.
"""
yield set(), {}
yield

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict
import torch
from diffusers import UNet2DConditionModel
@ -21,9 +21,7 @@ class FreeUExt(ExtensionBase):
self._freeu_config = freeu_config
@contextmanager
def patch_unet(
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
unet.enable_freeu(
b1=self._freeu_config.b1,
b2=self._freeu_config.b2,
@ -32,6 +30,6 @@ class FreeUExt(ExtensionBase):
)
try:
yield set(), {}
yield
finally:
unet.disable_freeu()

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Tuple
import torch
from diffusers import UNet2DConditionModel
@ -28,97 +28,84 @@ class LoRAExt(ExtensionBase):
self._weight = weight
@contextmanager
def patch_unet(
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
lora_model = self._node_context.models.load(self._model_id).model
modified_cached_weights, modified_weights = self.patch_model(
self.patch_model(
model=unet,
prefix="lora_unet_",
lora=lora_model,
lora_weight=self._weight,
cached_weights=cached_weights,
original_weights=original_weights,
)
del lora_model
yield modified_cached_weights, modified_weights
yield
@classmethod
@torch.no_grad()
def patch_model(
cls,
model: torch.nn.Module,
prefix: str,
lora: LoRAModelRaw,
lora_weight: float,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
original_weights: Dict[str, torch.Tensor],
):
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param lora: LoRA model to patch in.
:param lora_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
:param original_weights: TODO:
"""
if cached_weights is None:
cached_weights = {}
modified_weights: Dict[str, torch.Tensor] = {}
modified_cached_weights: Set[str] = set()
with torch.no_grad():
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# save original weight
if param_key not in modified_cached_weights and param_key not in modified_weights:
if param_key in cached_weights:
modified_cached_weights.add(param_key)
else:
modified_weights[param_key] = module_param.detach().to(
device=TorchDevice.CPU_DEVICE, copy=True
)
# save original weight
if param_key not in original_weights:
original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
if module_param.shape != lora_param_weight.shape:
# TODO: debug on lycoris
lora_param_weight = lora_param_weight.reshape(module_param.shape)
if module_param.shape != lora_param_weight.shape:
# TODO: debug on lycoris
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= lora_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
lora_param_weight *= lora_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
return modified_cached_weights, modified_weights
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch
from diffusers import UNet2DConditionModel
@ -67,29 +67,18 @@ class ExtensionsManager:
if self._is_canceled and self._is_canceled():
raise CanceledException
modified_weights: Dict[str, torch.Tensor] = {}
modified_cached_weights: Set[str] = set()
original_weights: Dict[str, torch.Tensor] = {}
if cached_weights:
original_weights.update(cached_weights)
try:
with ExitStack() as exit_stack:
for ext in self._extensions:
ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
ext.patch_unet(unet, cached_weights)
)
modified_cached_weights.update(ext_modified_cached_weights)
# store only first returned weight for each key, because
# next extension which changes it, will work with already modified weight
for param_key, weight in ext_modified_weights.items():
if param_key in modified_weights:
continue
modified_weights[param_key] = weight
exit_stack.enter_context(ext.patch_unet(unet, original_weights))
yield None
finally:
with torch.no_grad():
for param_key in modified_cached_weights:
unet.get_parameter(param_key).copy_(cached_weights[param_key])
for param_key, weight in modified_weights.items():
for param_key, weight in original_weights.items():
unet.get_parameter(param_key).copy_(weight)