mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
faa88f72bf
commit
9e582563eb
@ -71,6 +71,9 @@ class LoRALayerBase:
|
|||||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||||
|
"""Log a warning if values contains unhandled keys."""
|
||||||
|
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
||||||
|
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
||||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||||
unknown_keys = set(values.keys()) - all_known_keys
|
unknown_keys = set(values.keys()) - all_known_keys
|
||||||
if unknown_keys:
|
if unknown_keys:
|
||||||
@ -232,7 +235,6 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
else:
|
else:
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
|
|
||||||
# Although lokr_t1 not used in algo, it still defined in LoKR weights
|
|
||||||
self.check_keys(
|
self.check_keys(
|
||||||
values,
|
values,
|
||||||
{
|
{
|
||||||
@ -242,7 +244,6 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
"lokr_w2",
|
"lokr_w2",
|
||||||
"lokr_w2_a",
|
"lokr_w2_a",
|
||||||
"lokr_w2_b",
|
"lokr_w2_b",
|
||||||
"lokr_t1",
|
|
||||||
"lokr_t2",
|
"lokr_t2",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
@ -56,5 +56,17 @@ class ExtensionBase:
|
|||||||
yield None
|
yield None
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
def patch_unet(
|
||||||
yield None
|
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
||||||
|
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||||
|
"""Apply patches to UNet model. This function responsible for restoring all changes except weights,
|
||||||
|
changed weights should only be reported in return.
|
||||||
|
Return contains 2 values:
|
||||||
|
- Set of cached weights, just keys from cached_weights dictionary
|
||||||
|
- Dict of not cached weights that should be copies on the cpu device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unet (UNet2DConditionModel): The UNet model on execution device to patch.
|
||||||
|
cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes.
|
||||||
|
"""
|
||||||
|
yield set(), {}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Dict, Optional
|
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
@ -21,7 +21,9 @@ class FreeUExt(ExtensionBase):
|
|||||||
self._freeu_config = freeu_config
|
self._freeu_config = freeu_config
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
def patch_unet(
|
||||||
|
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
||||||
|
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||||
unet.enable_freeu(
|
unet.enable_freeu(
|
||||||
b1=self._freeu_config.b1,
|
b1=self._freeu_config.b1,
|
||||||
b2=self._freeu_config.b2,
|
b2=self._freeu_config.b2,
|
||||||
@ -30,6 +32,6 @@ class FreeUExt(ExtensionBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield set(), {}
|
||||||
finally:
|
finally:
|
||||||
unet.disable_freeu()
|
unet.disable_freeu()
|
||||||
|
@ -28,7 +28,9 @@ class LoRAExt(ExtensionBase):
|
|||||||
self._weight = weight
|
self._weight = weight
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
def patch_unet(
|
||||||
|
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
||||||
|
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||||
lora_model = self._node_context.models.load(self._model_id).model
|
lora_model = self._node_context.models.load(self._model_id).model
|
||||||
modified_cached_weights, modified_weights = self.patch_model(
|
modified_cached_weights, modified_weights = self.patch_model(
|
||||||
model=unet,
|
model=unet,
|
||||||
@ -49,14 +51,14 @@ class LoRAExt(ExtensionBase):
|
|||||||
lora: LoRAModelRaw,
|
lora: LoRAModelRaw,
|
||||||
lora_weight: float,
|
lora_weight: float,
|
||||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
):
|
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Apply one or more LoRAs to a model.
|
Apply one or more LoRAs to a model.
|
||||||
:param model: The model to patch.
|
:param model: The model to patch.
|
||||||
:param lora: LoRA model to patch in.
|
:param lora: LoRA model to patch in.
|
||||||
:param lora_weight: LoRA patch weight.
|
:param lora_weight: LoRA patch weight.
|
||||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
:param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||||
"""
|
"""
|
||||||
if cached_weights is None:
|
if cached_weights is None:
|
||||||
cached_weights = {}
|
cached_weights = {}
|
||||||
|
@ -70,13 +70,12 @@ class ExtensionsManager:
|
|||||||
modified_weights: Dict[str, torch.Tensor] = {}
|
modified_weights: Dict[str, torch.Tensor] = {}
|
||||||
modified_cached_weights: Set[str] = set()
|
modified_cached_weights: Set[str] = set()
|
||||||
|
|
||||||
exit_stack = ExitStack()
|
|
||||||
try:
|
try:
|
||||||
|
with ExitStack() as exit_stack:
|
||||||
for ext in self._extensions:
|
for ext in self._extensions:
|
||||||
res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
|
ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
|
||||||
if res is None:
|
ext.patch_unet(unet, cached_weights)
|
||||||
continue
|
)
|
||||||
ext_modified_cached_weights, ext_modified_weights = res
|
|
||||||
|
|
||||||
modified_cached_weights.update(ext_modified_cached_weights)
|
modified_cached_weights.update(ext_modified_cached_weights)
|
||||||
# store only first returned weight for each key, because
|
# store only first returned weight for each key, because
|
||||||
@ -89,7 +88,6 @@ class ExtensionsManager:
|
|||||||
yield None
|
yield None
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
exit_stack.close()
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for param_key in modified_cached_weights:
|
for param_key in modified_cached_weights:
|
||||||
unet.get_parameter(param_key).copy_(cached_weights[param_key])
|
unet.get_parameter(param_key).copy_(cached_weights[param_key])
|
||||||
|
Loading…
Reference in New Issue
Block a user