From f9fda503a307289dcb08ab4b47389b3ac114aa44 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 5 Apr 2024 12:02:05 -0400 Subject: [PATCH] A very primitive working version of peft patching. It is very slow. LoRAs don't get unloaded yet, so can only be run once. And the results are *slightly* different than the old implementation. I suspect this is because the lora weight is not being applied to the UNet, but there could be other issues as well. --- invokeai/app/invocations/compel.py | 4 +- invokeai/app/invocations/latent.py | 10 +- invokeai/backend/peft/peft_model_patcher.py | 162 +++++++++++++++++++- 3 files changed, 168 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 6daa0f54ad..6b66587b1c 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -78,7 +78,7 @@ class CompelInvocation(BaseInvocation): ), text_encoder_info as text_encoder, # Apply the LoRA after text_encoder has been moved to its target device for faster patching. - PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), "text_encoder"), + PeftModelPatcher.apply_peft_model_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers), ): @@ -176,7 +176,7 @@ class SDXLPromptInvocationBase: ), text_encoder_info as text_encoder, # Apply the LoRA after text_encoder has been moved to its target device for faster patching. - PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), lora_prefix), + PeftModelPatcher.apply_peft_model_to_text_encoder(text_encoder, _lora_loader(), lora_prefix), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers), ): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 3c66b7014f..3defa8778a 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -48,9 +48,10 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus -from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType, LoadedModel from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.peft.peft_model import PeftModel +from invokeai.backend.peft.peft_model_patcher import PeftModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo from invokeai.backend.util.silence_warnings import SilenceWarnings @@ -714,13 +715,12 @@ class DenoiseLatentsInvocation(BaseInvocation): def step_callback(state: PipelineIntermediateState) -> None: context.util.sd_step_callback(state, unet_config.base) - def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: + def _lora_loader() -> Iterator[Tuple[PeftModel, float]]: for lora in self.unet.loras: lora_info = context.models.load(lora.lora) - assert isinstance(lora_info.model, LoRAModelRaw) + assert isinstance(lora_info.model, PeftModel) yield (lora_info.model, lora.weight) del lora_info - return unet_info = context.models.load(self.unet.unet) assert isinstance(unet_info.model, UNet2DConditionModel) @@ -730,7 +730,7 @@ class DenoiseLatentsInvocation(BaseInvocation): set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME unet_info as unet, # Apply the LoRA after unet has been moved to its target device for faster patching. - ModelPatcher.apply_lora_unet(unet, _lora_loader()), + PeftModelPatcher.apply_peft_model_to_unet(unet, _lora_loader()), ): assert isinstance(unet, UNet2DConditionModel) latents = latents.to(device=unet.device, dtype=unet.dtype) diff --git a/invokeai/backend/peft/peft_model_patcher.py b/invokeai/backend/peft/peft_model_patcher.py index 0174fb9581..3159dc8c0f 100644 --- a/invokeai/backend/peft/peft_model_patcher.py +++ b/invokeai/backend/peft/peft_model_patcher.py @@ -4,11 +4,171 @@ from contextlib import contextmanager from typing import Iterator, Tuple import torch +from diffusers.models.lora import text_encoder_attn_modules, text_encoder_mlp_modules +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from diffusers.utils.peft_utils import get_peft_kwargs, scale_lora_layers +from diffusers.utils.state_dict_utils import convert_state_dict_to_peft, convert_unet_state_dict_to_peft +from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from invokeai.backend.peft.peft_model import PeftModel +UNET_NAME = "unet" + class PeftModelPatcher: + @classmethod + @contextmanager + @torch.no_grad() + def apply_peft_model_to_text_encoder( + cls, + text_encoder: torch.nn.Module, + peft_models: Iterator[Tuple[PeftModel, float]], + prefix: str, + ): + original_weights = {} + + try: + for peft_model, peft_model_weight in peft_models: + keys = list(peft_model.state_dict.keys()) + + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in peft_model.state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) == 0: + continue + + if peft_model.name in getattr(text_encoder, "peft_config", {}): + raise ValueError(f"Adapter name {peft_model.name} already in use in the text encoder ({prefix}).") + + rank = {} + # TODO(ryand): Is this necessary? + # text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + rank_key = f"{name}.out_proj.lora_B.weight" + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) + if patch_mlp: + for name, _ in text_encoder_mlp_modules(text_encoder): + rank_key_fc1 = f"{name}.fc1.lora_B.weight" + rank_key_fc2 = f"{name}.fc2.lora_B.weight" + + rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] + rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] + + network_alphas = peft_model.network_alphas + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config_kwargs["inference_mode"] = True + lora_config = LoraConfig(**lora_config_kwargs) + + new_text_encoder = inject_adapter_in_model(lora_config, text_encoder, peft_model.name) + incompatible_keys = set_peft_model_state_dict( + new_text_encoder, text_encoder_lora_state_dict, peft_model.name + ) + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + raise ValueError(f"Failed to inject unexpected PEFT keys: {unexpected_keys}") + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + # text_encoder.load_adapter( + # adapter_name=adapter_name, + # adapter_state_dict=text_encoder_lora_state_dict, + # peft_config=lora_config, + # ) + + scale_lora_layers(text_encoder, weight=peft_model_weight) + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + yield + finally: + # TODO + pass + # for module_key, weight in original_weights.items(): + # model.get_submodule(module_key).weight.copy_(weight) + + @classmethod + @contextmanager + @torch.no_grad() + def apply_peft_model_to_unet( + cls, + unet: UNet2DConditionModel, + peft_models: Iterator[Tuple[PeftModel, float]], + ): + try: + for peft_model, peft_model_weight in peft_models: + keys = list(peft_model.state_dict.keys()) + + unet_keys = [k for k in keys if k.startswith(UNET_NAME)] + state_dict = { + k.replace(f"{UNET_NAME}.", ""): v for k, v in peft_model.state_dict.items() if k in unet_keys + } + + network_alphas = peft_model.network_alphas + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(UNET_NAME)] + network_alphas = { + k.replace(f"{UNET_NAME}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + if len(state_dict) == 0: + continue + + if peft_model.name in getattr(unet, "peft_config", {}): + raise ValueError(f"Adapter name {peft_model.name} already in use in the Unet.") + + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if network_alphas is not None: + # The alphas state dict have the same structure as Unet, thus we convert it to peft format using + # `convert_unet_state_dict_to_peft` method. + network_alphas = convert_unet_state_dict_to_peft(network_alphas) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) + lora_config_kwargs["inference_mode"] = True + lora_config = LoraConfig(**lora_config_kwargs) + + inject_adapter_in_model(lora_config, unet, adapter_name=peft_model.name) + incompatible_keys = set_peft_model_state_dict(unet, state_dict, peft_model.name) + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + raise ValueError(f"Failed to inject unexpected PEFT keys: {unexpected_keys}") + + # TODO(ryand): What does this do? + unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=True) + + # TODO(ryand): Apply the lora weight. Where does diffusers do this? They don't seem to do it when they + # patch the UNet. + yield + finally: + # TODO + pass + # for module_key, weight in original_weights.items(): + # model.get_submodule(module_key).weight.copy_(weight) + @classmethod @contextmanager @torch.no_grad() @@ -28,8 +188,8 @@ class PeftModelPatcher: continue module_key = layer_key.replace(prefix + ".", "") - module_key = module_key.split # TODO(ryand): Make this work. + module = model_state_dict[module_key] # All of the LoRA weight calculations will be done on the same device as the module weight.