From 22c66cf55b233270f9ed9a9d51793495ba06ae80 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 4 Apr 2024 22:40:42 -0400 Subject: [PATCH] WIP --- invokeai/app/invocations/compel.py | 27 ++- .../backend/model_manager/any_model_type.py | 4 +- .../model_manager/load/model_loaders/lora.py | 4 +- invokeai/backend/peft/peft_format_utils.py | 85 ++++++++++ invokeai/backend/peft/peft_model.py | 36 ++-- invokeai/backend/peft/peft_model_patcher.py | 67 ++++++++ invokeai/backend/peft/sdxl_format_utils.py | 154 ++++++++++++++++++ 7 files changed, 337 insertions(+), 40 deletions(-) create mode 100644 invokeai/backend/peft/peft_format_utils.py create mode 100644 invokeai/backend/peft/peft_model_patcher.py create mode 100644 invokeai/backend/peft/sdxl_format_utils.py diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index c23dd3d908..6daa0f54ad 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -9,8 +9,9 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import generate_ti_list -from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.peft.peft_model import PeftModel +from invokeai.backend.peft.peft_model_patcher import PeftModelPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, @@ -61,15 +62,12 @@ class CompelInvocation(BaseInvocation): text_encoder_model = text_encoder_info.model assert isinstance(text_encoder_model, CLIPTextModel) - def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: + def _lora_loader() -> Iterator[Tuple[PeftModel, float]]: for lora in self.clip.loras: lora_info = context.models.load(lora.lora) - assert isinstance(lora_info.model, LoRAModelRaw) + assert isinstance(lora_info.model, PeftModel) yield (lora_info.model, lora.weight) del lora_info - return - - # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context) @@ -80,7 +78,7 @@ class CompelInvocation(BaseInvocation): ), text_encoder_info as text_encoder, # Apply the LoRA after text_encoder has been moved to its target device for faster patching. - ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), + PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), "text_encoder"), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers), ): @@ -161,16 +159,13 @@ class SDXLPromptInvocationBase: c_pooled = None return c, c_pooled, None - def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: + def _lora_loader() -> Iterator[Tuple[PeftModel, float]]: for lora in clip_field.loras: lora_info = context.models.load(lora.lora) lora_model = lora_info.model - assert isinstance(lora_model, LoRAModelRaw) + assert isinstance(lora_model, PeftModel) yield (lora_model, lora.weight) del lora_info - return - - # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context) @@ -181,7 +176,7 @@ class SDXLPromptInvocationBase: ), text_encoder_info as text_encoder, # Apply the LoRA after text_encoder has been moved to its target device for faster patching. - ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), + PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), lora_prefix), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers), ): @@ -259,15 +254,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: c1, c1_pooled, ec1 = self.run_clip_compel( - context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True + context, self.clip, self.prompt, False, "text_encoder", zero_on_empty=True ) if self.style.strip() == "": c2, c2_pooled, ec2 = self.run_clip_compel( - context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True + context, self.clip2, self.prompt, True, "text_encoder_2", zero_on_empty=True ) else: c2, c2_pooled, ec2 = self.run_clip_compel( - context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True + context, self.clip2, self.style, True, "text_encoder_2", zero_on_empty=True ) original_size = (self.original_height, self.original_width) diff --git a/invokeai/backend/model_manager/any_model_type.py b/invokeai/backend/model_manager/any_model_type.py index 837bb9dcd0..37961ef26a 100644 --- a/invokeai/backend/model_manager/any_model_type.py +++ b/invokeai/backend/model_manager/any_model_type.py @@ -4,9 +4,9 @@ import torch from diffusers.models.modeling_utils import ModelMixin from invokeai.backend.ip_adapter.ip_adapter import IPAdapter -from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.peft.peft_model import PeftModel from invokeai.backend.textual_inversion import TextualInversionModelRaw # ModelMixin is the base class for all diffusers and transformers models -AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, LoRAModelRaw, TextualInversionModelRaw, IAIOnnxRuntimeModel] +AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, PeftModel, TextualInversionModelRaw, IAIOnnxRuntimeModel] diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 8d353d4b71..3c14b2f1e6 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Optional from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import ( AnyModelConfig, BaseModelType, @@ -17,6 +16,7 @@ from invokeai.backend.model_manager import ( from invokeai.backend.model_manager.any_model_type import AnyModel from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from invokeai.backend.peft.peft_model import PeftModel from .. import ModelLoader, ModelLoaderRegistry @@ -47,7 +47,7 @@ class LoRALoader(ModelLoader): raise ValueError("There are no submodels in a LoRA model.") model_path = Path(config.path) assert self._model_base is not None - model = LoRAModelRaw.from_checkpoint( + model = PeftModel.from_checkpoint( file_path=model_path, dtype=self._torch_dtype, base_model=self._model_base, diff --git a/invokeai/backend/peft/peft_format_utils.py b/invokeai/backend/peft/peft_format_utils.py new file mode 100644 index 0000000000..673fd25e05 --- /dev/null +++ b/invokeai/backend/peft/peft_format_utils.py @@ -0,0 +1,85 @@ +import torch +from diffusers.utils.state_dict_utils import convert_state_dict + +KOHYA_SS_TO_PEFT = { + "lora_down": "lora_A", + "lora_up": "lora_B", + # This is not a comprehensive dict. See `convert_state_dict_to_peft(...)` for more info on the conversion. +} + + +def convert_state_dict_kohya_to_peft(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + # TODO(ryand): Check that state_dict is in Kohya format. + + peft_partial_state_dict = convert_state_dict(state_dict, KOHYA_SS_TO_PEFT) + + peft_state_dict: dict[str, torch.Tensor] = {} + for key, weight in peft_partial_state_dict.items(): + + + for kohya_key, weight in kohya_ss_partial_state_dict.items(): + if "text_encoder_2." in kohya_key: + kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.") + elif "text_encoder." in kohya_key: + kohya_key = kohya_key.replace("text_encoder.", "lora_te1.") + elif "unet" in kohya_key: + kohya_key = kohya_key.replace("unet", "lora_unet") + kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) + kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names + kohya_ss_state_dict[kohya_key] = weight + if "lora_down" in kohya_key: + alpha_key = f'{kohya_key.split(".")[0]}.alpha' + kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight)) +def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs): + r""" + Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc. + The method only supports the conversion from PEFT to Kohya for now. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + original_type (`StateDictType`, *optional*): + The original type of the state dict, if not provided, the method will try to infer it automatically. + kwargs (`dict`, *args*): + Additional arguments to pass to the method. + + - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended + with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in + `get_peft_model_state_dict` method: + https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 + but we add it here in case we don't want to rely on that method. + """ + + peft_adapter_name = kwargs.pop("adapter_name", None) + if peft_adapter_name is not None: + peft_adapter_name = "." + peft_adapter_name + else: + peft_adapter_name = "" + + if original_type is None: + if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): + original_type = StateDictType.PEFT + + if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys(): + raise ValueError(f"Original type {original_type} is not supported") + + # Use the convert_state_dict function with the appropriate mapping + kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT]) + kohya_ss_state_dict = {} + + # Additional logic for replacing header, alpha parameters `.` with `_` in all keys + for kohya_key, weight in kohya_ss_partial_state_dict.items(): + if "text_encoder_2." in kohya_key: + kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.") + elif "text_encoder." in kohya_key: + kohya_key = kohya_key.replace("text_encoder.", "lora_te1.") + elif "unet" in kohya_key: + kohya_key = kohya_key.replace("unet", "lora_unet") + kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) + kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names + kohya_ss_state_dict[kohya_key] = weight + if "lora_down" in kohya_key: + alpha_key = f'{kohya_key.split(".")[0]}.alpha' + kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight)) + + return kohya_ss_state_dict diff --git a/invokeai/backend/peft/peft_model.py b/invokeai/backend/peft/peft_model.py index c21890f831..707753ceb3 100644 --- a/invokeai/backend/peft/peft_model.py +++ b/invokeai/backend/peft/peft_model.py @@ -2,9 +2,11 @@ from pathlib import Path from typing import Optional, Union import torch -from safetensors.torch import load_file +from diffusers.loaders.lora_conversion_utils import _convert_kohya_lora_to_diffusers from invokeai.backend.model_manager.config import BaseModelType +from invokeai.backend.peft.sdxl_format_utils import convert_sdxl_keys_to_diffusers_format +from invokeai.backend.util.serialization import load_state_dict class PeftModel: @@ -14,17 +16,15 @@ class PeftModel: self, name: str, state_dict: dict[str, torch.Tensor], + network_alphas: dict[str, torch.Tensor], ): - self._name = name - self._state_dict = state_dict - - @property - def name(self) -> str: - return self._name + self.name = name + self.state_dict = state_dict + self.network_alphas = network_alphas def calc_size(self) -> int: model_size = 0 - for tensor in self._state_dict.values(): + for tensor in self.state_dict.values(): model_size += tensor.nelement() * tensor.element_size() return model_size @@ -41,16 +41,12 @@ class PeftModel: file_path = Path(file_path) - # TODO(ryand): Implement a helper function for this. This logic is duplicated repeatedly. - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path, device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") + state_dict = load_state_dict(file_path, device=str(device)) + # lora_unet_up_blocks_1_attentions_2_transformer_blocks_1_ff_net_2.lora_down.weight + if base_model == BaseModelType.StableDiffusionXL: + state_dict = convert_sdxl_keys_to_diffusers_format(state_dict) - # TODO(ryand): - # - Detect state_dict format - # - Convert state_dict to diffusers format if necessary - - # if base_model == BaseModelType.StableDiffusionXL: - # state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) - return cls(name=file_path.stem, state_dict=state_dict) + # TODO(ryand): We shouldn't be using an unexported function from diffusers here. Consider opening an upstream PR + # to move this function to state_dict_utils.py. + state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict) + return cls(name=file_path.stem, state_dict=state_dict, network_alphas=network_alphas) diff --git a/invokeai/backend/peft/peft_model_patcher.py b/invokeai/backend/peft/peft_model_patcher.py new file mode 100644 index 0000000000..0174fb9581 --- /dev/null +++ b/invokeai/backend/peft/peft_model_patcher.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import Iterator, Tuple + +import torch + +from invokeai.backend.peft.peft_model import PeftModel + + +class PeftModelPatcher: + @classmethod + @contextmanager + @torch.no_grad() + def apply_peft_patch( + cls, + model: torch.nn.Module, + peft_models: Iterator[Tuple[PeftModel, float]], + prefix: str, + ): + original_weights = {} + + model_state_dict = model.state_dict() + try: + for peft_model, peft_model_weight in peft_models: + for layer_key, layer in peft_model.state_dict.items(): + if not layer_key.startswith(prefix): + continue + + module_key = layer_key.replace(prefix + ".", "") + module_key = module_key.split + # TODO(ryand): Make this work. + module = model_state_dict[module_key] + + # All of the LoRA weight calculations will be done on the same device as the module weight. + # (Performance will be best if this is a CUDA device.) + device = module.weight.device + dtype = module.weight.dtype + + if module_key not in original_weights: + # TODO(ryand): Set non_blocking = True? + original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) + + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + + # We intentionally move to the target device first, then cast. Experimentally, this was found to + # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the + # same thing in a single call to '.to(...)'. + layer.to(device=device) + layer.to(dtype=torch.float32) + # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA + # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. + layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) + layer.to(device=torch.device("cpu")) + + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! + if module.weight.shape != layer_weight.shape: + # TODO: debug on lycoris + assert hasattr(layer_weight, "reshape") + layer_weight = layer_weight.reshape(module.weight.shape) + + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! + module.weight += layer_weight.to(dtype=dtype) + yield + finally: + for module_key, weight in original_weights.items(): + model.get_submodule(module_key).weight.copy_(weight) diff --git a/invokeai/backend/peft/sdxl_format_utils.py b/invokeai/backend/peft/sdxl_format_utils.py new file mode 100644 index 0000000000..31afb7c019 --- /dev/null +++ b/invokeai/backend/peft/sdxl_format_utils.py @@ -0,0 +1,154 @@ +import bisect + +import torch + + +def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Convert the keys of an SDXL LoRA state_dict to diffusers format. + + The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in + diffusers format, then this function will have no effect. + + This function is adapted from: + https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409 + + Args: + state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict. + + Raises: + ValueError: If state_dict contains an unrecognized key, or not all keys could be converted. + + Returns: + Dict[str, Tensor]: The diffusers-format state_dict. + """ + converted_count = 0 # The number of Stability AI keys converted to diffusers format. + not_converted_count = 0 # The number of keys that were not converted. + + # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes. + # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for + # `input_blocks_4_1_proj_in`. + stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) + stability_unet_keys.sort() + + new_state_dict = {} + for full_key, value in state_dict.items(): + if full_key.startswith("lora_unet_"): + search_key = full_key.replace("lora_unet_", "") + # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix. + position = bisect.bisect_right(stability_unet_keys, search_key) + map_key = stability_unet_keys[position - 1] + # Now, check if the map_key *actually* matches the search_key. + if search_key.startswith(map_key): + new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key]) + new_state_dict[new_key] = value + converted_count += 1 + else: + new_state_dict[full_key] = value + not_converted_count += 1 + elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): + # The CLIP text encoders have the same keys in both Stability AI and diffusers formats. + new_state_dict[full_key] = value + continue + else: + raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.") + + if converted_count > 0 and not_converted_count > 0: + raise ValueError( + f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count}," + f" not_converted={not_converted_count}" + ) + + return new_state_dict + + +# Code based on: +# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 +def make_sdxl_unet_conversion_map() -> list[tuple[str, str]]: + """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" + unet_conversion_map_layer: list[tuple[str, str]] = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map: list[tuple[str, str]] = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." + unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + return unet_conversion_map + + +# A mapping of state_dict key prefixes from Stability AI SDXL format to diffusers SDXL format. +SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = { + sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map() +}