mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes + simplify weights logic in patching
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
8500bac3ca
commit
2227a2357f
@ -490,6 +490,9 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
|||||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||||
|
|
||||||
for layer_key, values in state_dict.items():
|
for layer_key, values in state_dict.items():
|
||||||
|
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||||
|
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||||
|
|
||||||
# lora and locon
|
# lora and locon
|
||||||
if "lora_up.weight" in values:
|
if "lora_up.weight" in values:
|
||||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
|
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -123,34 +123,25 @@ class ModelPatcher:
|
|||||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||||
"""
|
"""
|
||||||
modified_cached_weights: Set[str] = set()
|
original_weights: Dict[str, torch.Tensor] = {}
|
||||||
modified_weights: Dict[str, torch.Tensor] = {}
|
if cached_weights:
|
||||||
|
original_weights.update(cached_weights)
|
||||||
try:
|
try:
|
||||||
for lora_model, lora_weight in loras:
|
for lora_model, lora_weight in loras:
|
||||||
lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
|
LoRAExt.patch_model(
|
||||||
model=model,
|
model=model,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
lora=lora_model,
|
lora=lora_model,
|
||||||
lora_weight=lora_weight,
|
lora_weight=lora_weight,
|
||||||
cached_weights=cached_weights,
|
original_weights=original_weights,
|
||||||
)
|
)
|
||||||
del lora_model
|
del lora_model
|
||||||
|
|
||||||
modified_cached_weights.update(lora_modified_cached_weights)
|
|
||||||
# Store only first returned weight for each key, because
|
|
||||||
# next extension which changes it, will work with already modified weight
|
|
||||||
for param_key, weight in lora_modified_weights.items():
|
|
||||||
if param_key in modified_weights:
|
|
||||||
continue
|
|
||||||
modified_weights[param_key] = weight
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for param_key in modified_cached_weights:
|
for param_key, weight in original_weights.items():
|
||||||
model.get_parameter(param_key).copy_(cached_weights[param_key])
|
|
||||||
for param_key, weight in modified_weights.items():
|
|
||||||
model.get_parameter(param_key).copy_(weight)
|
model.get_parameter(param_key).copy_(weight)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Callable, Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
@ -56,17 +56,17 @@ class ExtensionBase:
|
|||||||
yield None
|
yield None
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(
|
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
|
||||||
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
"""A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
|
||||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
|
||||||
"""Apply patches to UNet model. This function responsible for restoring all changes except weights,
|
`original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
|
||||||
changed weights should only be reported in return.
|
All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
|
||||||
Return contains 2 values:
|
context manager.
|
||||||
- Set of cached weights, just keys from cached_weights dictionary
|
|
||||||
- Dict of not cached weights that should be copies on the cpu device
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
unet (UNet2DConditionModel): The UNet model on execution device to patch.
|
unet (UNet2DConditionModel): The UNet model on execution device to patch.
|
||||||
cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes.
|
cached_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for
|
||||||
|
unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
|
||||||
|
access original weight value.
|
||||||
"""
|
"""
|
||||||
yield set(), {}
|
yield
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
@ -21,9 +21,7 @@ class FreeUExt(ExtensionBase):
|
|||||||
self._freeu_config = freeu_config
|
self._freeu_config = freeu_config
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(
|
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
|
||||||
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
|
||||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
|
||||||
unet.enable_freeu(
|
unet.enable_freeu(
|
||||||
b1=self._freeu_config.b1,
|
b1=self._freeu_config.b1,
|
||||||
b2=self._freeu_config.b2,
|
b2=self._freeu_config.b2,
|
||||||
@ -32,6 +30,6 @@ class FreeUExt(ExtensionBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield set(), {}
|
yield
|
||||||
finally:
|
finally:
|
||||||
unet.disable_freeu()
|
unet.disable_freeu()
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
@ -28,44 +28,38 @@ class LoRAExt(ExtensionBase):
|
|||||||
self._weight = weight
|
self._weight = weight
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(
|
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
|
||||||
self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
|
|
||||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
|
||||||
lora_model = self._node_context.models.load(self._model_id).model
|
lora_model = self._node_context.models.load(self._model_id).model
|
||||||
modified_cached_weights, modified_weights = self.patch_model(
|
self.patch_model(
|
||||||
model=unet,
|
model=unet,
|
||||||
prefix="lora_unet_",
|
prefix="lora_unet_",
|
||||||
lora=lora_model,
|
lora=lora_model,
|
||||||
lora_weight=self._weight,
|
lora_weight=self._weight,
|
||||||
cached_weights=cached_weights,
|
original_weights=original_weights,
|
||||||
)
|
)
|
||||||
del lora_model
|
del lora_model
|
||||||
|
|
||||||
yield modified_cached_weights, modified_weights
|
yield
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@torch.no_grad()
|
||||||
def patch_model(
|
def patch_model(
|
||||||
cls,
|
cls,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
lora: LoRAModelRaw,
|
lora: LoRAModelRaw,
|
||||||
lora_weight: float,
|
lora_weight: float,
|
||||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
original_weights: Dict[str, torch.Tensor],
|
||||||
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
|
):
|
||||||
"""
|
"""
|
||||||
Apply one or more LoRAs to a model.
|
Apply one or more LoRAs to a model.
|
||||||
:param model: The model to patch.
|
:param model: The model to patch.
|
||||||
:param lora: LoRA model to patch in.
|
:param lora: LoRA model to patch in.
|
||||||
:param lora_weight: LoRA patch weight.
|
:param lora_weight: LoRA patch weight.
|
||||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||||
:param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
:param original_weights: TODO:
|
||||||
"""
|
"""
|
||||||
if cached_weights is None:
|
|
||||||
cached_weights = {}
|
|
||||||
|
|
||||||
modified_weights: Dict[str, torch.Tensor] = {}
|
|
||||||
modified_cached_weights: Set[str] = set()
|
|
||||||
with torch.no_grad():
|
|
||||||
# assert lora.device.type == "cpu"
|
# assert lora.device.type == "cpu"
|
||||||
for layer_key, layer in lora.layers.items():
|
for layer_key, layer in lora.layers.items():
|
||||||
if not layer_key.startswith(prefix):
|
if not layer_key.startswith(prefix):
|
||||||
@ -101,13 +95,8 @@ class LoRAExt(ExtensionBase):
|
|||||||
module_param = module.get_parameter(param_name)
|
module_param = module.get_parameter(param_name)
|
||||||
|
|
||||||
# save original weight
|
# save original weight
|
||||||
if param_key not in modified_cached_weights and param_key not in modified_weights:
|
if param_key not in original_weights:
|
||||||
if param_key in cached_weights:
|
original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
|
||||||
modified_cached_weights.add(param_key)
|
|
||||||
else:
|
|
||||||
modified_weights[param_key] = module_param.detach().to(
|
|
||||||
device=TorchDevice.CPU_DEVICE, copy=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if module_param.shape != lora_param_weight.shape:
|
if module_param.shape != lora_param_weight.shape:
|
||||||
# TODO: debug on lycoris
|
# TODO: debug on lycoris
|
||||||
@ -118,8 +107,6 @@ class LoRAExt(ExtensionBase):
|
|||||||
|
|
||||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||||
|
|
||||||
return modified_cached_weights, modified_weights
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||||
assert "." not in lora_key
|
assert "." not in lora_key
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import ExitStack, contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
@ -67,29 +67,18 @@ class ExtensionsManager:
|
|||||||
if self._is_canceled and self._is_canceled():
|
if self._is_canceled and self._is_canceled():
|
||||||
raise CanceledException
|
raise CanceledException
|
||||||
|
|
||||||
modified_weights: Dict[str, torch.Tensor] = {}
|
original_weights: Dict[str, torch.Tensor] = {}
|
||||||
modified_cached_weights: Set[str] = set()
|
if cached_weights:
|
||||||
|
original_weights.update(cached_weights)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with ExitStack() as exit_stack:
|
with ExitStack() as exit_stack:
|
||||||
for ext in self._extensions:
|
for ext in self._extensions:
|
||||||
ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
|
exit_stack.enter_context(ext.patch_unet(unet, original_weights))
|
||||||
ext.patch_unet(unet, cached_weights)
|
|
||||||
)
|
|
||||||
|
|
||||||
modified_cached_weights.update(ext_modified_cached_weights)
|
|
||||||
# store only first returned weight for each key, because
|
|
||||||
# next extension which changes it, will work with already modified weight
|
|
||||||
for param_key, weight in ext_modified_weights.items():
|
|
||||||
if param_key in modified_weights:
|
|
||||||
continue
|
|
||||||
modified_weights[param_key] = weight
|
|
||||||
|
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for param_key in modified_cached_weights:
|
for param_key, weight in original_weights.items():
|
||||||
unet.get_parameter(param_key).copy_(cached_weights[param_key])
|
|
||||||
for param_key, weight in modified_weights.items():
|
|
||||||
unet.get_parameter(param_key).copy_(weight)
|
unet.get_parameter(param_key).copy_(weight)
|
||||||
|
Loading…
Reference in New Issue
Block a user