Handle loras in modular denoise

This commit is contained in:
Sergey Borisov 2024-07-24 05:07:29 +03:00
parent 7c975f0d00
commit ab0bfa709a
4 changed files with 227 additions and 4 deletions

View File

@ -60,6 +60,7 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@ -833,6 +834,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
if self.unet.freeu_config: if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config)) ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
### lora
if self.unet.loras:
ext_manager.add_extension(
LoRAPatcherExt(
node_context=context,
loras=self.unet.loras,
prefix="lora_unet_",
)
)
# context for loading additional models # context for loading additional models
with ExitStack() as exit_stack: with ExitStack() as exit_stack:
# later should be smth like: # later should be smth like:

View File

@ -49,6 +49,9 @@ class LoRALayerBase:
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError() raise NotImplementedError()
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
raise NotImplementedError()
def calc_size(self) -> int: def calc_size(self) -> int:
model_size = 0 model_size = 0
for val in [self.bias]: for val in [self.bias]:
@ -93,6 +96,9 @@ class LoRALayer(LoRALayerBase):
return weight return weight
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
return {"weight": self.get_weight(orig_module.weight)}
def calc_size(self) -> int: def calc_size(self) -> int:
model_size = super().calc_size() model_size = super().calc_size()
for val in [self.up, self.mid, self.down]: for val in [self.up, self.mid, self.down]:
@ -149,6 +155,9 @@ class LoHALayer(LoRALayerBase):
return weight return weight
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
return {"weight": self.get_weight(orig_module.weight)}
def calc_size(self) -> int: def calc_size(self) -> int:
model_size = super().calc_size() model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
@ -241,6 +250,9 @@ class LoKRLayer(LoRALayerBase):
return weight return weight
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
return {"weight": self.get_weight(orig_module.weight)}
def calc_size(self) -> int: def calc_size(self) -> int:
model_size = super().calc_size() model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
@ -293,6 +305,9 @@ class FullLayer(LoRALayerBase):
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
return self.weight return self.weight
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
return {"weight": self.get_weight(orig_module.weight)}
def calc_size(self) -> int: def calc_size(self) -> int:
model_size = super().calc_size() model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size() model_size += self.weight.nelement() * self.weight.element_size()
@ -327,6 +342,9 @@ class IA3Layer(LoRALayerBase):
assert orig_weight is not None assert orig_weight is not None
return orig_weight * weight return orig_weight * weight
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
return {"weight": self.get_weight(orig_module.weight)}
def calc_size(self) -> int: def calc_size(self) -> int:
model_size = super().calc_size() model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size() model_size += self.weight.nelement() * self.weight.element_size()

View File

@ -0,0 +1,172 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import LoRAField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
class LoRAPatcherExt(ExtensionBase):
def __init__(
self,
node_context: InvocationContext,
loras: List[LoRAField],
prefix: str,
):
super().__init__()
self._loras = loras
self._prefix = prefix
self._node_context = node_context
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self._loras:
lora_info = self._node_context.models.load(lora.lora)
lora_model = lora_info.model
yield (lora_model, lora.weight)
del lora_info
return
yield self._patch_model(
model=unet,
prefix=self._prefix,
loras=_lora_loader(),
cached_weights=cached_weights,
)
@classmethod
@contextmanager
def static_patch_model(
cls,
model: torch.nn.Module,
prefix: str,
loras: Iterator[Tuple[LoRAModelRaw, float]],
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
modified_cached_weights, modified_weights = cls._patch_model(
model=model,
prefix=prefix,
loras=loras,
cached_weights=cached_weights,
)
try:
yield
finally:
with torch.no_grad():
for param_key in modified_cached_weights:
model.get_parameter(param_key).copy_(cached_weights[param_key])
for param_key, weight in modified_weights.items():
model.get_parameter(param_key).copy_(weight)
@classmethod
def _patch_model(
cls,
model: UNet2DConditionModel,
prefix: str,
loras: Iterator[Tuple[LoRAModelRaw, float]],
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
if cached_weights is None:
cached_weights = {}
modified_weights = {}
modified_cached_weights = set()
with torch.no_grad():
for lora, lora_weight in loras:
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# save original weight
if param_key not in modified_cached_weights and param_key not in modified_weights:
if param_key in cached_weights:
modified_cached_weights.add(param_key)
else:
modified_weights[param_key] = module_param.detach().to(
device=TorchDevice.CPU_DEVICE, copy=True
)
if module_param.shape != lora_param_weight.shape:
# TODO: debug on lycoris
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= (lora_weight * layer_scale)
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
return modified_cached_weights, modified_weights
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
@ -67,9 +67,31 @@ class ExtensionsManager:
if self._is_canceled and self._is_canceled(): if self._is_canceled and self._is_canceled():
raise CanceledException raise CanceledException
# TODO: create weight patch logic in PR with extension which uses it modified_weights: Dict[str, torch.Tensor] = {}
with ExitStack() as exit_stack: modified_cached_weights: Set[str] = set()
exit_stack = ExitStack()
try:
for ext in self._extensions: for ext in self._extensions:
exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
if res is None:
continue
ext_modified_cached_weights, ext_modified_weights = res
modified_cached_weights.update(ext_modified_cached_weights)
# store only first returned weight for each key, because
# next extension which changes it, will work with already modified weight
for param_key, weight in ext_modified_weights.items():
if param_key in modified_weights:
continue
modified_weights[param_key] = weight
yield None yield None
finally:
exit_stack.close()
with torch.no_grad():
for param_key in modified_cached_weights:
unet.get_parameter(param_key).copy_(cached_weights[param_key])
for param_key, weight in modified_weights.items():
unet.get_parameter(param_key).copy_(weight)