Suggested changes + simplify weights logic in patching

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
Sergey Borisov 2024-07-30 00:34:37 +03:00
parent 8500bac3ca
commit 2227a2357f
6 changed files with 76 additions and 108 deletions


@@ -490,6 +490,9 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

         for layer_key, values in state_dict.items():
+            # Detect layers according to LyCORIS detection logic(`weight_list_det`)
+            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
+            # lora and locon
             if "lora_up.weight" in values:
                 layer: AnyLoRALayer = LoRALayer(layer_key, values)

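The detection approach above generalizes beyond plain LoRA/LoCon: each LyCORIS format stores a characteristic set of tensors per layer, so the layer type can be inferred from which keys are present. A minimal sketch of that key-based dispatch (the helper name and the exact key set are assumptions for illustration, not the project's code):

from typing import Dict

import torch


def detect_layer_type(values: Dict[str, torch.Tensor]) -> str:
    # Classify a layer's sub-dict by the tensor keys it contains.
    if "lora_up.weight" in values:
        return "lora/locon"
    if "hada_w1_a" in values:
        return "loha"
    if "lokr_w1" in values or "lokr_w1_a" in values:
        return "lokr"
    raise ValueError(f"Unsupported LyCORIS layer keys: {sorted(values)}")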

@@ -5,7 +5,7 @@ from __future__ import annotations

 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch
@@ -123,34 +123,25 @@ class ModelPatcher:
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
         :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        modified_cached_weights: Set[str] = set()
-        modified_weights: Dict[str, torch.Tensor] = {}
+        original_weights: Dict[str, torch.Tensor] = {}
+        if cached_weights:
+            original_weights.update(cached_weights)
         try:
             for lora_model, lora_weight in loras:
-                lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
+                LoRAExt.patch_model(
                     model=model,
                     prefix=prefix,
                     lora=lora_model,
                     lora_weight=lora_weight,
-                    cached_weights=cached_weights,
+                    original_weights=original_weights,
                 )
                 del lora_model
-                modified_cached_weights.update(lora_modified_cached_weights)
-                # Store only first returned weight for each key, because
-                # next extension which changes it, will work with already modified weight
-                for param_key, weight in lora_modified_weights.items():
-                    if param_key in modified_weights:
-                        continue
-                    modified_weights[param_key] = weight

             yield
         finally:
             with torch.no_grad():
-                for param_key in modified_cached_weights:
-                    model.get_parameter(param_key).copy_(cached_weights[param_key])
-                for param_key, weight in modified_weights.items():
+                for param_key, weight in original_weights.items():
                     model.get_parameter(param_key).copy_(weight)

     @classmethod

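The net effect of this hunk: instead of tracking a set of cached keys plus a separate dict of freshly saved weights, everything funnels into one `original_weights` dict that is written on first modification and restored in a single loop. A self-contained toy version of the pattern (an illustration, not InvokeAI code):

import torch

model = torch.nn.Linear(2, 2)
original_weights: dict = {}

def patch(param_key: str, delta: torch.Tensor) -> None:
    param = model.get_parameter(param_key)
    # First writer wins: only the first patch of a parameter snapshots it,
    # so later patches stack on top while the snapshot stays pristine.
    if param_key not in original_weights:
        original_weights[param_key] = param.detach().cpu().clone()
    with torch.no_grad():
        param += delta

patch("weight", torch.ones(2, 2))
patch("weight", torch.ones(2, 2))  # stacks; the snapshot is untouched

# Unpatching is one loop over the single dict.
with torch.no_grad():
    for key, weight in original_weights.items():
        model.get_parameter(key).copy_(weight)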

@@ -2,7 +2,7 @@ from __future__ import annotations

 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List

 import torch
 from diffusers import UNet2DConditionModel
@@ -56,17 +56,17 @@ class ExtensionBase:
             yield None

     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
-        """Apply patches to UNet model. This function responsible for restoring all changes except weights,
-        changed weights should only be reported in return.
-        Return contains 2 values:
-        - Set of cached weights, just keys from cached_weights dictionary
-        - Dict of not cached weights that should be copies on the cpu device
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+        """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
+        diffusion process. Weight unpatching is handled upstream, and is achieved by recording original weights in
+        the `original_weights` dict before they are modified. Note that this enables some performance optimizations
+        by avoiding redundant save/restore operations. All other patches (e.g. changes to tensor shapes, function
+        monkey-patches, etc.) should be unpatched by this context manager itself.

         Args:
             unet (UNet2DConditionModel): The UNet model on execution device to patch.
-            cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes.
+            original_weights (Dict[str, torch.Tensor]): A CPU copy of the model's original weights, used for
+                unpatching. Before modifying a tensor, an extension should save it here if it has not been saved
+                yet; it can also read the original value of an already-patched weight from this dict.
         """
-        yield set(), {}
+        yield

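To make the new contract concrete, here is a hedged sketch of a weight-modifying extension. The class, the chosen parameter key, and the scale factor are invented for illustration; it assumes the `ExtensionBase` defined above is importable. It records the original tensor before the first modification and simply yields, relying on the manager to restore weights:

from contextlib import contextmanager
from typing import Dict

import torch
from diffusers import UNet2DConditionModel


class ScaleConvInExt(ExtensionBase):  # hypothetical extension
    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
        key = "conv_in.weight"
        param = unet.get_parameter(key)
        if key not in original_weights:  # snapshot before first modification
            original_weights[key] = param.detach().to(device="cpu", copy=True)
        with torch.no_grad():
            param *= 1.1  # hypothetical weight tweak; undone upstream via original_weights
        yield  # nothing to clean up here: the weight restore is handled by the caller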

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict

 import torch
 from diffusers import UNet2DConditionModel
@@ -21,9 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config

     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
@@ -32,6 +30,6 @@ class FreeUExt(ExtensionBase):
         )

         try:
-            yield set(), {}
+            yield
         finally:
             unet.disable_freeu()

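For context, `enable_freeu`/`disable_freeu` are standard diffusers `UNet2DConditionModel` methods, and the try/finally above guarantees the forward-pass modifications are removed even if denoising raises. Standalone usage looks roughly like this (the model ID is illustrative, and the b/s values follow the commonly cited SD 1.5 settings):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
# Scale backbone (b1, b2) and skip-connection (s1, s2) features during forward.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
# ... run the denoising loop ...
unet.disable_freeu()  # restore the original forward behavior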

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Tuple

 import torch
 from diffusers import UNet2DConditionModel
@@ -28,44 +28,38 @@ class LoRAExt(ExtensionBase):
         self._weight = weight

     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
         lora_model = self._node_context.models.load(self._model_id).model
-        modified_cached_weights, modified_weights = self.patch_model(
+        self.patch_model(
             model=unet,
             prefix="lora_unet_",
             lora=lora_model,
             lora_weight=self._weight,
-            cached_weights=cached_weights,
+            original_weights=original_weights,
         )
         del lora_model

-        yield modified_cached_weights, modified_weights
+        yield

     @classmethod
+    @torch.no_grad()
     def patch_model(
         cls,
         model: torch.nn.Module,
         prefix: str,
         lora: LoRAModelRaw,
         lora_weight: float,
-        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+        original_weights: Dict[str, torch.Tensor],
+    ):
         """
         Apply one or more LoRAs to a model.
         :param model: The model to patch.
         :param lora: LoRA model to patch in.
         :param lora_weight: LoRA patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :param original_weights: Dict in which the original value of each patched weight is recorded (first write
+            wins), used by the caller for unpatching.
         """
-        if cached_weights is None:
-            cached_weights = {}
-
-        modified_weights: Dict[str, torch.Tensor] = {}
-        modified_cached_weights: Set[str] = set()
-        with torch.no_grad():
-            # assert lora.device.type == "cpu"
-            for layer_key, layer in lora.layers.items():
-                if not layer_key.startswith(prefix):
+        # assert lora.device.type == "cpu"
+        for layer_key, layer in lora.layers.items():
+            if not layer_key.startswith(prefix):
@@ -101,13 +95,8 @@ class LoRAExt(ExtensionBase):
                 module_param = module.get_parameter(param_name)

                 # save original weight
-                if param_key not in modified_cached_weights and param_key not in modified_weights:
-                    if param_key in cached_weights:
-                        modified_cached_weights.add(param_key)
-                    else:
-                        modified_weights[param_key] = module_param.detach().to(
-                            device=TorchDevice.CPU_DEVICE, copy=True
-                        )
+                if param_key not in original_weights:
+                    original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)

                 if module_param.shape != lora_param_weight.shape:
                     # TODO: debug on lycoris
@@ -118,8 +107,6 @@ class LoRAExt(ExtensionBase):
             layer.to(device=TorchDevice.CPU_DEVICE)

-        return modified_cached_weights, modified_weights
-
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
         assert "." not in lora_key

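For reference, the weight patch each LoRA layer applies reduces to adding a scaled low-rank delta in place. A simplified sketch of the math (the helper name and signature are assumptions; the real layer classes also handle reshaping, dtype casting, and device placement):

import torch


@torch.no_grad()
def merge_lora_weight(
    param: torch.Tensor,   # the module's weight, [out_features, in_features]
    up: torch.Tensor,      # lora_up.weight,      [out_features, rank]
    down: torch.Tensor,    # lora_down.weight,    [rank, in_features]
    alpha: float,
    lora_weight: float,
) -> None:
    rank = down.shape[0]
    # Classic LoRA delta: (up @ down) scaled by alpha / rank, then by the
    # user-facing patch weight.
    delta = (up @ down) * (alpha / rank) * lora_weight
    param += delta.to(dtype=param.dtype, device=param.device)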

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional

 import torch
 from diffusers import UNet2DConditionModel
@@ -67,29 +67,18 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException

-        modified_weights: Dict[str, torch.Tensor] = {}
-        modified_cached_weights: Set[str] = set()
+        original_weights: Dict[str, torch.Tensor] = {}
+        if cached_weights:
+            original_weights.update(cached_weights)

         try:
             with ExitStack() as exit_stack:
                 for ext in self._extensions:
-                    ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
-                        ext.patch_unet(unet, cached_weights)
-                    )
-                    modified_cached_weights.update(ext_modified_cached_weights)
-                    # store only first returned weight for each key, because
-                    # next extension which changes it, will work with already modified weight
-                    for param_key, weight in ext_modified_weights.items():
-                        if param_key in modified_weights:
-                            continue
-                        modified_weights[param_key] = weight
+                    exit_stack.enter_context(ext.patch_unet(unet, original_weights))

                 yield None

         finally:
             with torch.no_grad():
-                for param_key in modified_cached_weights:
-                    unet.get_parameter(param_key).copy_(cached_weights[param_key])
-                for param_key, weight in modified_weights.items():
+                for param_key, weight in original_weights.items():
                     unet.get_parameter(param_key).copy_(weight)
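Finally, the `ExitStack` usage is what lets a variable number of extension context managers stay open for the whole denoising run: each `enter_context` call registers a manager, and the stack unwinds them in reverse order on exit, even when an exception is raised. A small runnable illustration with invented names:

from contextlib import ExitStack, contextmanager

@contextmanager
def patch(name: str, log: list):
    log.append(f"enter {name}")
    try:
        yield
    finally:
        log.append(f"exit {name}")

log: list = []
with ExitStack() as stack:
    for name in ["lora", "freeu", "inpaint"]:
        stack.enter_context(patch(name, log))
    log.append("denoise")

print(log)
# ['enter lora', 'enter freeu', 'enter inpaint', 'denoise',
#  'exit inpaint', 'exit freeu', 'exit lora']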