mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes + simplify weights logic in patching
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
8500bac3ca
commit
2227a2357f
@ -490,6 +490,9 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
# lora and locon
|
||||
if "lora_up.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -123,34 +123,25 @@ class ModelPatcher:
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
"""
|
||||
modified_cached_weights: Set[str] = set()
|
||||
modified_weights: Dict[str, torch.Tensor] = {}
|
||||
original_weights: Dict[str, torch.Tensor] = {}
|
||||
if cached_weights:
|
||||
original_weights.update(cached_weights)
|
||||
try:
|
||||
for lora_model, lora_weight in loras:
|
||||
lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
|
||||
LoRAExt.patch_model(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
lora=lora_model,
|
||||
lora_weight=lora_weight,
|
||||
cached_weights=cached_weights,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
modified_cached_weights.update(lora_modified_cached_weights)
|
||||
# Store only first returned weight for each key, because
|
||||
# next extension which changes it, will work with already modified weight
|
||||
for param_key, weight in lora_modified_weights.items():
|
||||
if param_key in modified_weights:
|
||||
continue
|
||||
modified_weights[param_key] = weight
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
with torch.no_grad():
|
||||
for param_key in modified_cached_weights:
|
||||
model.get_parameter(param_key).copy_(cached_weights[param_key])
|
||||
for param_key, weight in modified_weights.items():
|
||||
for param_key, weight in original_weights.items():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@classmethod
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
@ -56,17 +56,17 @@ class ExtensionBase:
|
||||
yield None
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(
|
||||
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||
"""Apply patches to UNet model. This function responsible for restoring all changes except weights,
|
||||
changed weights should only be reported in return.
|
||||
Return contains 2 values:
|
||||
- Set of cached weights, just keys from cached_weights dictionary
|
||||
- Dict of not cached weights that should be copies on the cpu device
|
||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
|
||||
"""A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
|
||||
diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
|
||||
`original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
|
||||
All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
|
||||
context manager.
|
||||
|
||||
Args:
|
||||
unet (UNet2DConditionModel): The UNet model on execution device to patch.
|
||||
cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes.
|
||||
cached_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for
|
||||
unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
|
||||
access original weight value.
|
||||
"""
|
||||
yield set(), {}
|
||||
yield
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
@ -21,9 +21,7 @@ class FreeUExt(ExtensionBase):
|
||||
self._freeu_config = freeu_config
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(
|
||||
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
|
||||
unet.enable_freeu(
|
||||
b1=self._freeu_config.b1,
|
||||
b2=self._freeu_config.b2,
|
||||
@ -32,6 +30,6 @@ class FreeUExt(ExtensionBase):
|
||||
)
|
||||
|
||||
try:
|
||||
yield set(), {}
|
||||
yield
|
||||
finally:
|
||||
unet.disable_freeu()
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
@ -28,44 +28,38 @@ class LoRAExt(ExtensionBase):
|
||||
self._weight = weight
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(
|
||||
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
|
||||
lora_model = self._node_context.models.load(self._model_id).model
|
||||
modified_cached_weights, modified_weights = self.patch_model(
|
||||
self.patch_model(
|
||||
model=unet,
|
||||
prefix="lora_unet_",
|
||||
lora=lora_model,
|
||||
lora_weight=self._weight,
|
||||
cached_weights=cached_weights,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
yield modified_cached_weights, modified_weights
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
def patch_model(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
lora: LoRAModelRaw,
|
||||
lora_weight: float,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
||||
original_weights: Dict[str, torch.Tensor],
|
||||
):
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
:param model: The model to patch.
|
||||
:param lora: LoRA model to patch in.
|
||||
:param lora_weight: LoRA patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
:param original_weights: TODO:
|
||||
"""
|
||||
if cached_weights is None:
|
||||
cached_weights = {}
|
||||
|
||||
modified_weights: Dict[str, torch.Tensor] = {}
|
||||
modified_cached_weights: Set[str] = set()
|
||||
with torch.no_grad():
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
@ -101,13 +95,8 @@ class LoRAExt(ExtensionBase):
|
||||
module_param = module.get_parameter(param_name)
|
||||
|
||||
# save original weight
|
||||
if param_key not in modified_cached_weights and param_key not in modified_weights:
|
||||
if param_key in cached_weights:
|
||||
modified_cached_weights.add(param_key)
|
||||
else:
|
||||
modified_weights[param_key] = module_param.detach().to(
|
||||
device=TorchDevice.CPU_DEVICE, copy=True
|
||||
)
|
||||
if param_key not in original_weights:
|
||||
original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
@ -118,8 +107,6 @@ class LoRAExt(ExtensionBase):
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
return modified_cached_weights, modified_weights
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
@ -67,29 +67,18 @@ class ExtensionsManager:
|
||||
if self._is_canceled and self._is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
modified_weights: Dict[str, torch.Tensor] = {}
|
||||
modified_cached_weights: Set[str] = set()
|
||||
original_weights: Dict[str, torch.Tensor] = {}
|
||||
if cached_weights:
|
||||
original_weights.update(cached_weights)
|
||||
|
||||
try:
|
||||
with ExitStack() as exit_stack:
|
||||
for ext in self._extensions:
|
||||
ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
|
||||
ext.patch_unet(unet, cached_weights)
|
||||
)
|
||||
|
||||
modified_cached_weights.update(ext_modified_cached_weights)
|
||||
# store only first returned weight for each key, because
|
||||
# next extension which changes it, will work with already modified weight
|
||||
for param_key, weight in ext_modified_weights.items():
|
||||
if param_key in modified_weights:
|
||||
continue
|
||||
modified_weights[param_key] = weight
|
||||
exit_stack.enter_context(ext.patch_unet(unet, original_weights))
|
||||
|
||||
yield None
|
||||
|
||||
finally:
|
||||
with torch.no_grad():
|
||||
for param_key in modified_cached_weights:
|
||||
unet.get_parameter(param_key).copy_(cached_weights[param_key])
|
||||
for param_key, weight in modified_weights.items():
|
||||
for param_key, weight in original_weights.items():
|
||||
unet.get_parameter(param_key).copy_(weight)
|
||||
|
Loading…
Reference in New Issue
Block a user