Consolidate all LoRA patching logic in the LoRAPatcher.

This commit is contained in:
Ryan Dick
2024-09-10 14:45:40 +00:00
committed by Kent Keirsey
parent ee5d8f6caf
commit fef26a5f2f
7 changed files with 87 additions and 250 deletions

View File

@ -20,6 +20,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@ -81,9 +82,10 @@ class CompelInvocation(BaseInvocation):
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
LoRAPatcher.apply_lora_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@ -176,9 +178,9 @@ class SDXLPromptInvocationBase:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
LoRAPatcher.apply_lora_patches(
text_encoder,
loras=_lora_loader(),
patches=_lora_loader(),
prefix=lora_prefix,
cached_weights=cached_weights,
),

View File

@ -37,6 +37,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
@ -979,9 +980,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
LoRAPatcher.apply_lora_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
cached_weights=cached_weights,
),
):

View File

@ -23,7 +23,7 @@ from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
@ -204,7 +204,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
with (
ExitStack() as exit_stack,
unet_info as unet,
LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:

View File

@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Dict, Iterator, Optional, Tuple
from typing import Dict, Iterable, Optional, Tuple
import torch
@ -9,17 +9,16 @@ from invokeai.backend.util.original_weights_storage import OriginalWeightsStorag
class LoRAPatcher:
@classmethod
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_patches(
cls,
model: torch.nn.Module,
patches: Iterator[Tuple[LoRAModelRaw, float]],
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply one or more LoRA patches to a model.
"""Apply one or more LoRA patches to a model within a context manager.
:param model: The model to patch.
:param loras: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
@ -30,23 +29,23 @@ class LoRAPatcher:
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
cls._apply_lora_patch(
LoRAPatcher.apply_lora_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
)
del patch
yield
finally:
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@classmethod
@staticmethod
@torch.no_grad()
def _apply_lora_patch(
cls,
def apply_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
@ -54,7 +53,7 @@ class LoRAPatcher:
original_weights: OriginalWeightsStorage,
):
"""
Apply one a LoRA to a model.
Apply a single LoRA patch to a model.
:param model: The model to patch.
:param patch: LoRA model to patch in.
:param patch_weight: LoRA patch weight.
@ -65,11 +64,21 @@ class LoRAPatcher:
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module = model.get_submodule(layer_key)
module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
@ -87,7 +96,7 @@ class LoRAPatcher:
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = layer_key + "." + param_name
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# Save original weight
@ -100,3 +109,40 @@ class LoRAPatcher:
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
def _get_submodule(
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
) -> tuple[str, torch.nn.Module]:
"""Get the submodule corresponding to the given layer key.
:param model: The model to search.
:param layer_key: The layer key to search for.
:param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
searching, but some legacy code still uses flattened keys.
:return: A tuple containing the module key and the submodule.
"""
if not layer_key_is_flattened:
return layer_key, model.get_submodule(layer_key)
# Handle flattened keys.
assert "." not in layer_key
module = model
module_key = ""
key_parts = layer_key.split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return module_key, module

View File

@ -5,32 +5,18 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
"""
loras = [
(lora_model1, 0.7),
(lora_model2, 0.4),
]
with LoRAHelper.apply_lora_unet(unet, loras):
# unet with applied loras
# unmodified unet
"""
class ModelPatcher:
@ -54,95 +40,6 @@ class ModelPatcher:
finally:
unet.set_attn_processor(unet_orig_processors)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(
unet,
loras=loras,
prefix="lora_unet_",
cached_weights=cached_weights,
):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
yield
@classmethod
@contextmanager
def apply_lora(
cls,
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for lora_model, lora_weight in loras:
LoRAExt.patch_model(
model=model,
prefix=prefix,
lora=lora_model,
lora_weight=lora_weight,
original_weights=original_weights,
)
del lora_model
yield
finally:
with torch.no_grad():
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@classmethod
@contextmanager
def apply_ti(
@ -282,26 +179,6 @@ class ModelPatcher:
class ONNXModelPatcher:
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: OnnxRuntimeModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
# based on
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
@classmethod

View File

@ -1,14 +1,13 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
@ -31,107 +30,14 @@ class LoRAExt(ExtensionBase):
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
self.patch_model(
assert isinstance(lora_model, LoRAModelRaw)
LoRAPatcher.apply_lora_patch(
model=unet,
prefix="lora_unet_",
lora=lora_model,
lora_weight=self._weight,
patch=lora_model,
patch_weight=self._weight,
original_weights=original_weights,
)
del lora_model
yield
@classmethod
@torch.no_grad()
def patch_model(
cls,
model: torch.nn.Module,
prefix: str,
lora: LoRAModelRaw,
lora_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param lora: LoRA model to patch in.
:param lora_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
"""
if lora_weight == 0:
return
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
# TODO: debug on lycoris
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= lora_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)

View File

@ -3,7 +3,7 @@ import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.lora.lora_patcher import LoRAPatcher
@pytest.mark.parametrize(
@ -41,7 +41,7 @@ def test_apply_lora(device: str):
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)
with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
with LoRAPatcher.apply_lora_patches(model=model, patches=[(lora, lora_weight)], prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
@ -83,7 +83,7 @@ def test_apply_lora_change_device():
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
with LoRAPatcher.apply_lora_patches(model=model, patches=[(lora, 0.5)], prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
assert lora_layers["linear_layer_1"].down.device.type == "cpu"