Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov 2024-07-27 04:25:15 +03:00
parent faa88f72bf
commit 9e582563eb
5 changed files with 41 additions and 26 deletions

View File

@ -71,6 +71,9 @@ class LoRALayerBase:
self.bias = self.bias.to(device=device, dtype=dtype) self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]): def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"} all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys: if unknown_keys:
@ -232,7 +235,6 @@ class LoKRLayer(LoRALayerBase):
else: else:
self.rank = None # unscaled self.rank = None # unscaled
# Although lokr_t1 not used in algo, it still defined in LoKR weights
self.check_keys( self.check_keys(
values, values,
{ {
@ -242,7 +244,6 @@ class LoKRLayer(LoRALayerBase):
"lokr_w2", "lokr_w2",
"lokr_w2_a", "lokr_w2_a",
"lokr_w2_b", "lokr_w2_b",
"lokr_t1",
"lokr_t2", "lokr_t2",
}, },
) )

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
@ -56,5 +56,17 @@ class ExtensionBase:
yield None yield None
@contextmanager @contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): def patch_unet(
yield None self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
"""Apply patches to UNet model. This function responsible for restoring all changes except weights,
changed weights should only be reported in return.
Return contains 2 values:
- Set of cached weights, just keys from cached_weights dictionary
- Dict of not cached weights that should be copies on the cpu device
Args:
unet (UNet2DConditionModel): The UNet model on execution device to patch.
cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes.
"""
yield set(), {}

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
@ -21,7 +21,9 @@ class FreeUExt(ExtensionBase):
self._freeu_config = freeu_config self._freeu_config = freeu_config
@contextmanager @contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): def patch_unet(
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
unet.enable_freeu( unet.enable_freeu(
b1=self._freeu_config.b1, b1=self._freeu_config.b1,
b2=self._freeu_config.b2, b2=self._freeu_config.b2,
@ -30,6 +32,6 @@ class FreeUExt(ExtensionBase):
) )
try: try:
yield yield set(), {}
finally: finally:
unet.disable_freeu() unet.disable_freeu()

View File

@ -28,7 +28,9 @@ class LoRAExt(ExtensionBase):
self._weight = weight self._weight = weight
@contextmanager @contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): def patch_unet(
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
lora_model = self._node_context.models.load(self._model_id).model lora_model = self._node_context.models.load(self._model_id).model
modified_cached_weights, modified_weights = self.patch_model( modified_cached_weights, modified_weights = self.patch_model(
model=unet, model=unet,
@ -49,14 +51,14 @@ class LoRAExt(ExtensionBase):
lora: LoRAModelRaw, lora: LoRAModelRaw,
lora_weight: float, lora_weight: float,
cached_weights: Optional[Dict[str, torch.Tensor]] = None, cached_weights: Optional[Dict[str, torch.Tensor]] = None,
): ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
""" """
Apply one or more LoRAs to a model. Apply one or more LoRAs to a model.
:param model: The model to patch. :param model: The model to patch.
:param lora: LoRA model to patch in. :param lora: LoRA model to patch in.
:param lora_weight: LoRA patch weight. :param lora_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers. :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. :param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
""" """
if cached_weights is None: if cached_weights is None:
cached_weights = {} cached_weights = {}

View File

@ -70,13 +70,12 @@ class ExtensionsManager:
modified_weights: Dict[str, torch.Tensor] = {} modified_weights: Dict[str, torch.Tensor] = {}
modified_cached_weights: Set[str] = set() modified_cached_weights: Set[str] = set()
exit_stack = ExitStack()
try: try:
with ExitStack() as exit_stack:
for ext in self._extensions: for ext in self._extensions:
res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
if res is None: ext.patch_unet(unet, cached_weights)
continue )
ext_modified_cached_weights, ext_modified_weights = res
modified_cached_weights.update(ext_modified_cached_weights) modified_cached_weights.update(ext_modified_cached_weights)
# store only first returned weight for each key, because # store only first returned weight for each key, because
@ -89,7 +88,6 @@ class ExtensionsManager:
yield None yield None
finally: finally:
exit_stack.close()
with torch.no_grad(): with torch.no_grad():
for param_key in modified_cached_weights: for param_key in modified_cached_weights:
unet.get_parameter(param_key).copy_(cached_weights[param_key]) unet.get_parameter(param_key).copy_(cached_weights[param_key])