diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index c23dd3d908..1a95fff37f 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -9,7 +9,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_patcher import LoraModelPatcher
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -80,7 +81,8 @@ class CompelInvocation(BaseInvocation):
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            # ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
         ):
@@ -181,7 +183,8 @@ class SDXLPromptInvocationBase:
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            # ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
@@ -259,15 +262,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         c1, c1_pooled, ec1 = self.run_clip_compel(
-            context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
+            context, self.clip, self.prompt, False, "text_encoder", zero_on_empty=True
         )
         if self.style.strip() == "":
             c2, c2_pooled, ec2 = self.run_clip_compel(
-                context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
+                context, self.clip2, self.prompt, True, "text_encoder_2", zero_on_empty=True
             )
         else:
             c2, c2_pooled, ec2 = self.run_clip_compel(
-                context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
+                context, self.clip2, self.style, True, "text_encoder_2", zero_on_empty=True
             )
 
         original_size = (self.original_height, self.original_width)
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index bc79efdeba..d6993fbdbb 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -52,7 +52,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_patcher import LoraModelPatcher
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
@@ -739,7 +740,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
             set_seamless(unet_info.model, self.unet.seamless_axes),  # FIXME
             unet_info as unet,
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            # ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            LoraModelPatcher.apply_lora_to_unet(unet, _lora_loader()),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
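All of the call sites above follow the same pattern: the LoRA is applied through a context manager only after the base model has been moved to its target device, and is stripped again on exit so the cached base model is never left patched. A minimal sketch of that pattern, where `load_loras()`, `paths`, `weights`, and `unet` are illustrative assumptions standing in for the invocations' `_lora_loader()` generators and model-manager plumbing, not code from this diff:

```python
import torch

from invokeai.backend.lora_model_patcher import LoraModelPatcher
from invokeai.backend.lora_model_raw import LoRAModelRaw


def load_loras(paths, weights):
    # Hypothetical stand-in for the invocations' _lora_loader() generators:
    # lazily yield (LoRAModelRaw, weight) pairs so checkpoints are only read
    # when the patcher iterates them.
    for path, weight in zip(paths, weights):
        yield LoRAModelRaw.from_checkpoint(path, dtype=torch.float16), weight


def run_denoise(unet, paths, weights):
    # `unet`: a diffusers UNet2DConditionModel already moved to its target
    # device (patching after the move is faster).
    with LoraModelPatcher.apply_lora_to_unet(unet, load_loras(paths, weights)):
        ...  # run denoising while the PEFT LoRA layers are patched in
    # On exit, unload_lora_from_model() removes the PEFT layers again.
```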
diff --git a/invokeai/backend/lora_model_patcher.py b/invokeai/backend/lora_model_patcher.py
new file mode 100644
index 0000000000..afd87be311
--- /dev/null
+++ b/invokeai/backend/lora_model_patcher.py
@@ -0,0 +1,65 @@
+from contextlib import contextmanager
+from typing import Iterator, Tuple, Union
+
+from diffusers.loaders.lora import LoraLoaderMixin
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from diffusers.utils.peft_utils import recurse_remove_peft_layers
+from transformers import CLIPTextModel
+
+from invokeai.backend.lora_model_raw import LoRAModelRaw
+
+
+class LoraModelPatcher:
+    @classmethod
+    def unload_lora_from_model(cls, m: Union[UNet2DConditionModel, CLIPTextModel]):
+        """Unload all LoRA models from a UNet or Text Encoder.
+        This implementation is based on LoraLoaderMixin.unload_lora_weights().
+        """
+        recurse_remove_peft_layers(m)
+        if hasattr(m, "peft_config"):
+            del m.peft_config  # type: ignore
+        if hasattr(m, "_hf_peft_config_loaded"):
+            m._hf_peft_config_loaded = None  # type: ignore
+
+    @classmethod
+    @contextmanager
+    def apply_lora_to_unet(cls, unet: UNet2DConditionModel, loras: Iterator[Tuple[LoRAModelRaw, float]]):
+        try:
+            # TODO(ryand): Test speed of low_cpu_mem_usage=True.
+            for lora, lora_weight in loras:
+                LoraLoaderMixin.load_lora_into_unet(
+                    state_dict=lora.state_dict,
+                    network_alphas=lora.network_alphas,
+                    unet=unet,
+                    low_cpu_mem_usage=True,
+                    adapter_name=lora.name,
+                    _pipeline=None,
+                )
+            yield
+        finally:
+            cls.unload_lora_from_model(unet)
+
+    @classmethod
+    @contextmanager
+    def apply_lora_to_text_encoder(
+        cls, text_encoder: CLIPTextModel, loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str
+    ):
+        assert prefix in ["text_encoder", "text_encoder_2"]
+        try:
+            for lora, lora_weight in loras:
+                # Filter the state_dict to only include the keys that start with the prefix.
+                text_encoder_state_dict = {
+                    key: value for key, value in lora.state_dict.items() if key.startswith(prefix + ".")
+                }
+                if len(text_encoder_state_dict) > 0:
+                    LoraLoaderMixin.load_lora_into_text_encoder(
+                        state_dict=text_encoder_state_dict,
+                        network_alphas=lora.network_alphas,
+                        text_encoder=text_encoder,
+                        low_cpu_mem_usage=True,
+                        adapter_name=lora.name,
+                        _pipeline=None,
+                    )
+            yield
+        finally:
+            cls.unload_lora_from_model(text_encoder)
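The `key.startswith(prefix + ".")` filter above is what routes an SDXL LoRA's weights to the correct encoder, and it is why the call sites in compel.py switched from the `"lora_te1_"`/`"lora_te2_"` prefixes to the diffusers-style `"text_encoder"`/`"text_encoder_2"`. A small runnable illustration with made-up keys in the diffusers naming scheme:

```python
# Made-up example keys; only the shape of the names matters, not the values.
state_dict = {
    "unet.down_blocks.0.attentions.0.proj_in.lora_A.weight": 0,
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": 1,
    "text_encoder_2.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": 2,
}

prefix = "text_encoder"
filtered = {k: v for k, v in state_dict.items() if k.startswith(prefix + ".")}

# The trailing dot keeps "text_encoder" from also matching "text_encoder_2.*"
# keys: "text_encoder_2..." does not start with "text_encoder.". Without it,
# the first-encoder pass would also pull in the second encoder's weights.
assert list(filtered.values()) == [1]
```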
diff --git a/invokeai/backend/lora_model_raw.py b/invokeai/backend/lora_model_raw.py
new file mode 100644
index 0000000000..1f0ec71636
--- /dev/null
+++ b/invokeai/backend/lora_model_raw.py
@@ -0,0 +1,66 @@
+from pathlib import Path
+from typing import Optional, Union
+
+import torch
+from diffusers.loaders.lora import LoraLoaderMixin
+from typing_extensions import Self
+
+
+class LoRAModelRaw:
+    def __init__(
+        self,
+        name: str,
+        state_dict: dict[str, torch.Tensor],
+        network_alphas: Optional[dict[str, float]],
+    ):
+        self._name = name
+        self.state_dict = state_dict
+        self.network_alphas = network_alphas
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        for key, layer in self.state_dict.items():
+            self.state_dict[key] = layer.to(device=device, dtype=dtype)
+
+    def calc_size(self) -> int:
+        """Calculate the size of the model in bytes."""
+        model_size = 0
+        for layer in self.state_dict.values():
+            model_size += layer.numel() * layer.element_size()
+        return model_size
+
+    @classmethod
+    def from_checkpoint(
+        cls, file_path: Union[str, Path], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
+    ) -> Self:
+        """This function is based on diffusers LoraLoaderMixin.load_lora_weights()."""
+
+        file_path = Path(file_path)
+        if file_path.is_dir():
+            raise NotImplementedError("LoRA models from directories are not yet supported.")
+
+        dir_path = file_path.parent
+        file_name = file_path.name
+
+        state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
+            pretrained_model_name_or_path_or_dict=str(file_path), local_files_only=True, weight_name=str(file_name)
+        )
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        model = cls(
+            # TODO(ryand): Handle both files and directories here?
+            name=Path(file_path).stem,
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+        )
+
+        device = device or torch.device("cpu")
+        dtype = dtype or torch.float32
+        model.to(device=device, dtype=dtype)
+        return model
diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora_old.py
similarity index 100%
rename from invokeai/backend/lora.py
rename to invokeai/backend/lora_old.py
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 9836ee3167..78fd24e9be 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -32,7 +32,7 @@ from typing_extensions import Annotated, Any, Dict
 from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.textual_inversion import TextualInversionModelRaw
 
diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py
index 20a39e56c3..2f809147de 100644
--- a/invokeai/backend/model_manager/load/model_loaders/lora.py
+++ b/invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -6,7 +6,7 @@ from pathlib import Path
 from typing import Optional, Tuple
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -51,7 +51,6 @@ class LoRALoader(ModelLoader):
         model = LoRAModelRaw.from_checkpoint(
             file_path=model_path,
             dtype=self._torch_dtype,
-            base_model=self._model_base,
         )
         return model
 
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 76271fc025..8455e9e85b 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -17,7 +17,7 @@ from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 
-from .lora import LoRAModelRaw
+from .lora_model_raw import LoRAModelRaw
 from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
 
 """
diff --git a/tests/backend/model_manager/test_lora.py b/tests/backend/model_manager/test_lora.py
index 114a4cfdcf..c5c11fb85f 100644
--- a/tests/backend/model_manager/test_lora.py
+++ b/tests/backend/model_manager/test_lora.py
@@ -5,7 +5,8 @@
 import pytest
 import torch
 
-from invokeai.backend.lora import LoRALayer, LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora_old import LoRALayer
 from invokeai.backend.model_patcher import ModelPatcher
 
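For completeness, a sketch of loading a LoRA directly through the new `LoRAModelRaw` API; the checkpoint path below is hypothetical:

```python
import torch

from invokeai.backend.lora_model_raw import LoRAModelRaw

# Hypothetical path. from_checkpoint() rejects directories and raises
# ValueError unless every key in the loaded state dict contains "lora".
lora = LoRAModelRaw.from_checkpoint(
    "models/loras/my_style.safetensors",
    device=torch.device("cpu"),
    dtype=torch.float16,
)

print(lora.name)  # "my_style" (the file stem)
print(f"{lora.calc_size() / 2**20:.1f} MiB")  # numel * element_size, summed
```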