diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 2a8b2df705..4f7c72eea6 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -10,6 +10,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
 from invokeai.backend.lora.lora_model import LoRAModelRaw
+from invokeai.backend.lora.lora_model_patcher import LoraModelPatcher
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -80,7 +81,7 @@ class CompelInvocation(BaseInvocation):
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            LoraModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
         ):
@@ -181,7 +182,7 @@ class SDXLPromptInvocationBase:
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            LoraModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index fec4c7ca1f..f383a533d5 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -49,6 +49,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 from invokeai.backend.lora.lora_model import LoRAModelRaw
+from invokeai.backend.lora.lora_model_patcher import LoraModelPatcher
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
@@ -730,7 +731,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             set_seamless(unet_info.model, self.unet.seamless_axes),  # FIXME
             unet_info as unet,
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            LoraModelPatcher.apply_lora_unet(unet, _lora_loader()),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
diff --git a/invokeai/backend/lora/lora_model_patcher.py b/invokeai/backend/lora/lora_model_patcher.py
new file mode 100644
index 0000000000..633f487a92
--- /dev/null
+++ b/invokeai/backend/lora/lora_model_patcher.py
@@ -0,0 +1,141 @@
+from contextlib import contextmanager
+from typing import Iterator, List, Tuple
+
+import torch
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from transformers import CLIPTextModel
+
+from invokeai.backend.lora.lora_model import LoRAModelRaw
+from invokeai.backend.model_manager.any_model_type import AnyModel
+
+
+class LoraModelPatcher:
+    @staticmethod
+    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
+        assert "." not in lora_key
+
+        if not lora_key.startswith(prefix):
+            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
+
+        module = model
+        module_key = ""
+        key_parts = lora_key[len(prefix) :].split("_")
+
+        submodule_name = key_parts.pop(0)
+
+        while len(key_parts) > 0:
+            try:
+                module = module.get_submodule(submodule_name)
+                module_key += "." + submodule_name
+                submodule_name = key_parts.pop(0)
+            except Exception:
+                submodule_name += "_" + key_parts.pop(0)
+
+        module = module.get_submodule(submodule_name)
+        module_key = (module_key + "." + submodule_name).lstrip(".")
+
+        return (module_key, module)
+
+    @classmethod
+    @contextmanager
+    def apply_lora_unet(
+        cls,
+        unet: UNet2DConditionModel,
+        loras: Iterator[Tuple[LoRAModelRaw, float]],
+    ) -> None:
+        with cls.apply_lora(unet, loras, "lora_unet_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_lora_text_encoder(
+        cls,
+        text_encoder: CLIPTextModel,
+        loras: Iterator[Tuple[LoRAModelRaw, float]],
+    ) -> None:
+        with cls.apply_lora(text_encoder, loras, "lora_te_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_sdxl_lora_text_encoder(
+        cls,
+        text_encoder: CLIPTextModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+    ) -> None:
+        with cls.apply_lora(text_encoder, loras, "lora_te1_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_sdxl_lora_text_encoder2(
+        cls,
+        text_encoder: CLIPTextModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+    ) -> None:
+        with cls.apply_lora(text_encoder, loras, "lora_te2_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_lora(
+        cls,
+        model: AnyModel,
+        loras: Iterator[Tuple[LoRAModelRaw, float]],
+        prefix: str,
+    ) -> None:
+        original_weights = {}
+        try:
+            with torch.no_grad():
+                for lora, lora_weight in loras:
+                    # assert lora.device.type == "cpu"
+                    for layer_key, layer in lora.layers.items():
+                        if not layer_key.startswith(prefix):
+                            continue
+
+                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+                        # should be improved in the following ways:
+                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+                        #    LoRA model is applied.
+                        # 2. From an API perspective, there's no reason that the `LoraModelPatcher` should be aware of
+                        #    the intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+                        #    weights to have valid keys.
+                        assert isinstance(model, torch.nn.Module)
+                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+                        # All of the LoRA weight calculations will be done on the same device as the module weight.
+                        # (Performance will be best if this is a CUDA device.)
+                        device = module.weight.device
+                        dtype = module.weight.dtype
+
+                        if module_key not in original_weights:
+                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
+
+                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+
+                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
+                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                        # same thing in a single call to '.to(...)'.
+                        layer.to(device=device)
+                        layer.to(dtype=torch.float32)
+                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
+                        layer.to(device=torch.device("cpu"))
+
+                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
+                        if module.weight.shape != layer_weight.shape:
+                            # TODO: debug on lycoris
+                            assert hasattr(layer_weight, "reshape")
+                            layer_weight = layer_weight.reshape(module.weight.shape)
+
+                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
+                        module.weight += layer_weight.to(dtype=dtype)
+
+            yield  # wait for context manager exit
+
+        finally:
+            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
+            with torch.no_grad():
+                for module_key, weight in original_weights.items():
+                    model.get_submodule(module_key).weight.copy_(weight)
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index f59074db28..de28b0ab24 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -14,156 +14,13 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
 
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.lora.lora_model import LoRAModelRaw
-from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 
 from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
 
-"""
-loras = [
-    (lora_model1, 0.7),
-    (lora_model2, 0.4),
-]
-with LoRAHelper.apply_lora_unet(unet, loras):
-    # unet with applied loras
-# unmodified unet
-"""
-
-
-# TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
-    @staticmethod
-    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
-        assert "." not in lora_key
-
-        if not lora_key.startswith(prefix):
-            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
-
-        module = model
-        module_key = ""
-        key_parts = lora_key[len(prefix) :].split("_")
-
-        submodule_name = key_parts.pop(0)
-
-        while len(key_parts) > 0:
-            try:
-                module = module.get_submodule(submodule_name)
-                module_key += "." + submodule_name
-                submodule_name = key_parts.pop(0)
-            except Exception:
-                submodule_name += "_" + key_parts.pop(0)
-
-        module = module.get_submodule(submodule_name)
-        module_key = (module_key + "." + submodule_name).lstrip(".")
-
-        return (module_key, module)
-
-    @classmethod
-    @contextmanager
-    def apply_lora_unet(
-        cls,
-        unet: UNet2DConditionModel,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(unet, loras, "lora_unet_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_lora_text_encoder(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_sdxl_lora_text_encoder(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te1_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_sdxl_lora_text_encoder2(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te2_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_lora(
-        cls,
-        model: AnyModel,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-        prefix: str,
-    ) -> None:
-        original_weights = {}
-        try:
-            with torch.no_grad():
-                for lora, lora_weight in loras:
-                    # assert lora.device.type == "cpu"
-                    for layer_key, layer in lora.layers.items():
-                        if not layer_key.startswith(prefix):
-                            continue
-
-                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                        # should be improved in the following ways:
-                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                        #    LoRA model is applied.
-                        # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                        #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                        #    weights to have valid keys.
-                        assert isinstance(model, torch.nn.Module)
-                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-                        # All of the LoRA weight calculations will be done on the same device as the module weight.
-                        # (Performance will be best if this is a CUDA device.)
-                        device = module.weight.device
-                        dtype = module.weight.dtype
-
-                        if module_key not in original_weights:
-                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
-
-                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                        # same thing in a single call to '.to(...)'.
-                        layer.to(device=device)
-                        layer.to(dtype=torch.float32)
-                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=torch.device("cpu"))
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        if module.weight.shape != layer_weight.shape:
-                            # TODO: debug on lycoris
-                            assert hasattr(layer_weight, "reshape")
-                            layer_weight = layer_weight.reshape(module.weight.shape)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype)
-
-            yield  # wait for context manager exit
-
-        finally:
-            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
-            with torch.no_grad():
-                for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
-
     @classmethod
     @contextmanager
     def apply_ti(
diff --git a/tests/backend/model_manager/test_lora.py b/tests/backend/model_manager/test_lora.py
index 6e8c911e12..b885f4d8ff 100644
--- a/tests/backend/model_manager/test_lora.py
+++ b/tests/backend/model_manager/test_lora.py
@@ -7,7 +7,7 @@ import torch
 
 from invokeai.backend.lora.lora_layer import LoRALayer
 from invokeai.backend.lora.lora_model import LoRAModelRaw
-from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.lora.lora_model_patcher import LoraModelPatcher
 
 
 @pytest.mark.parametrize(
@@ -45,7 +45,7 @@ def test_apply_lora(device):
     orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
     expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)
 
-    with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
+    with LoraModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
         # After patching, all LoRA layer weights should have been moved back to the cpu.
         assert lora_layers["linear_layer_1"].up.device.type == "cpu"
         assert lora_layers["linear_layer_1"].down.device.type == "cpu"
@@ -87,7 +87,7 @@ def test_apply_lora_change_device():
 
     orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
 
-    with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
+    with LoraModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
         # After patching, all LoRA layer weights should have been moved back to the cpu.
         assert lora_layers["linear_layer_1"].up.device.type == "cpu"
         assert lora_layers["linear_layer_1"].down.device.type == "cpu"
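
Usage sketch (illustration only, not part of the diff): a minimal example of how the relocated patcher is invoked, mirroring the example docstring removed from model_patcher.py above. The names `unet`, `lora_model1`, and `lora_model2` are placeholders assumed to be loaded elsewhere (e.g. via the model manager).

    from invokeai.backend.lora.lora_model_patcher import LoraModelPatcher

    # `unet` is assumed to be a loaded UNet2DConditionModel; `lora_model1` and
    # `lora_model2` are assumed to be LoRAModelRaw instances (placeholder names).
    loras = [
        (lora_model1, 0.7),
        (lora_model2, 0.4),
    ]
    with LoraModelPatcher.apply_lora_unet(unet, loras):
        ...  # inside the context, `unet` has the LoRA weights patched in
    # on exit, the original unet weights are restored from the saved CPU copies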