mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
faa88f72bf
commit
9e582563eb
@ -71,6 +71,9 @@ class LoRALayerBase:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||
"""Log a warning if values contains unhandled keys."""
|
||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||
unknown_keys = set(values.keys()) - all_known_keys
|
||||
if unknown_keys:
|
||||
@ -232,7 +235,6 @@ class LoKRLayer(LoRALayerBase):
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
# Although lokr_t1 not used in algo, it still defined in LoKR weights
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
@ -242,7 +244,6 @@ class LoKRLayer(LoRALayerBase):
|
||||
"lokr_w2",
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t1",
|
||||
"lokr_t2",
|
||||
},
|
||||
)
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
@ -56,5 +56,17 @@ class ExtensionBase:
|
||||
yield None
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
yield None
|
||||
def patch_unet(
|
||||
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||
"""Apply patches to UNet model. This function responsible for restoring all changes except weights,
|
||||
changed weights should only be reported in return.
|
||||
Return contains 2 values:
|
||||
- Set of cached weights, just keys from cached_weights dictionary
|
||||
- Dict of not cached weights that should be copies on the cpu device
|
||||
|
||||
Args:
|
||||
unet (UNet2DConditionModel): The UNet model on execution device to patch.
|
||||
cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes.
|
||||
"""
|
||||
yield set(), {}
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
@ -21,7 +21,9 @@ class FreeUExt(ExtensionBase):
|
||||
self._freeu_config = freeu_config
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
def patch_unet(
|
||||
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||
unet.enable_freeu(
|
||||
b1=self._freeu_config.b1,
|
||||
b2=self._freeu_config.b2,
|
||||
@ -30,6 +32,6 @@ class FreeUExt(ExtensionBase):
|
||||
)
|
||||
|
||||
try:
|
||||
yield
|
||||
yield set(), {}
|
||||
finally:
|
||||
unet.disable_freeu()
|
||||
|
@ -28,7 +28,9 @@ class LoRAExt(ExtensionBase):
|
||||
self._weight = weight
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
def patch_unet(
|
||||
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||
lora_model = self._node_context.models.load(self._model_id).model
|
||||
modified_cached_weights, modified_weights = self.patch_model(
|
||||
model=unet,
|
||||
@ -49,14 +51,14 @@ class LoRAExt(ExtensionBase):
|
||||
lora: LoRAModelRaw,
|
||||
lora_weight: float,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
:param model: The model to patch.
|
||||
:param lora: LoRA model to patch in.
|
||||
:param lora_weight: LoRA patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
:param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
"""
|
||||
if cached_weights is None:
|
||||
cached_weights = {}
|
||||
|
@ -70,26 +70,24 @@ class ExtensionsManager:
|
||||
modified_weights: Dict[str, torch.Tensor] = {}
|
||||
modified_cached_weights: Set[str] = set()
|
||||
|
||||
exit_stack = ExitStack()
|
||||
try:
|
||||
for ext in self._extensions:
|
||||
res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
|
||||
if res is None:
|
||||
continue
|
||||
ext_modified_cached_weights, ext_modified_weights = res
|
||||
with ExitStack() as exit_stack:
|
||||
for ext in self._extensions:
|
||||
ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
|
||||
ext.patch_unet(unet, cached_weights)
|
||||
)
|
||||
|
||||
modified_cached_weights.update(ext_modified_cached_weights)
|
||||
# store only first returned weight for each key, because
|
||||
# next extension which changes it, will work with already modified weight
|
||||
for param_key, weight in ext_modified_weights.items():
|
||||
if param_key in modified_weights:
|
||||
continue
|
||||
modified_weights[param_key] = weight
|
||||
modified_cached_weights.update(ext_modified_cached_weights)
|
||||
# store only first returned weight for each key, because
|
||||
# next extension which changes it, will work with already modified weight
|
||||
for param_key, weight in ext_modified_weights.items():
|
||||
if param_key in modified_weights:
|
||||
continue
|
||||
modified_weights[param_key] = weight
|
||||
|
||||
yield None
|
||||
yield None
|
||||
|
||||
finally:
|
||||
exit_stack.close()
|
||||
with torch.no_grad():
|
||||
for param_key in modified_cached_weights:
|
||||
unet.get_parameter(param_key).copy_(cached_weights[param_key])
|
||||
|
Loading…
Reference in New Issue
Block a user