Ryan Dick 2024-04-04 22:40:42 -04:00
parent 8260252aeb
commit 22c66cf55b
7 changed files with 337 additions and 40 deletions

View File

@@ -9,8 +9,9 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.peft_model import PeftModel
from invokeai.backend.peft.peft_model_patcher import PeftModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
@@ -61,15 +62,12 @@ class CompelInvocation(BaseInvocation):
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[PeftModel, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, PeftModel)
yield (lora_info.model, lora.weight)
del lora_info
return
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
@@ -80,7 +78,7 @@ class CompelInvocation(BaseInvocation):
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), "text_encoder"),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
):
@@ -161,16 +159,13 @@ class SDXLPromptInvocationBase:
c_pooled = None
return c, c_pooled, None
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[PeftModel, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(lora.lora)
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
assert isinstance(lora_model, PeftModel)
yield (lora_model, lora.weight)
del lora_info
return
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
@@ -181,7 +176,7 @@ class SDXLPromptInvocationBase:
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
):
@@ -259,15 +254,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
context, self.clip, self.prompt, False, "text_encoder", zero_on_empty=True
)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
context, self.clip2, self.prompt, True, "text_encoder_2", zero_on_empty=True
)
else:
c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
context, self.clip2, self.style, True, "text_encoder_2", zero_on_empty=True
)
original_size = (self.original_height, self.original_width)

View File

@@ -4,9 +4,9 @@ import torch
from diffusers.models.modeling_utils import ModelMixin
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.peft.peft_model import PeftModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
# ModelMixin is the base class for all diffusers and transformers models
AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, LoRAModelRaw, TextualInversionModelRaw, IAIOnnxRuntimeModel]
AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, PeftModel, TextualInversionModelRaw, IAIOnnxRuntimeModel]

View File

@@ -6,7 +6,6 @@ from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
@@ -17,6 +16,7 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.any_model_type import AnyModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.peft.peft_model import PeftModel
from .. import ModelLoader, ModelLoaderRegistry
@@ -47,7 +47,7 @@ class LoRALoader(ModelLoader):
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
model = LoRAModelRaw.from_checkpoint(
model = PeftModel.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,

View File

@@ -0,0 +1,85 @@
import torch
from diffusers.utils.state_dict_utils import KOHYA_STATE_DICT_MAPPINGS, StateDictType, convert_state_dict

KOHYA_SS_TO_PEFT = {
    "lora_down": "lora_A",
    "lora_up": "lora_B",
    # This is not a comprehensive dict. See `convert_state_dict_to_peft(...)` for more info on the conversion.
}
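# Minimal illustration (editor's sketch, not part of this commit) of the substring renaming that
# `convert_state_dict(...)` performs with the mapping above (toy key and tensor):
#
# sd = {"lora_unet_down_blocks_0.lora_down.weight": torch.zeros(8, 320)}
# renamed = convert_state_dict(sd, KOHYA_SS_TO_PEFT)
# assert "lora_unet_down_blocks_0.lora_A.weight" in renamed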
def convert_state_dict_kohya_to_peft(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # TODO(ryand): Check that state_dict is in Kohya format.
    peft_partial_state_dict = convert_state_dict(state_dict, KOHYA_SS_TO_PEFT)

    # After the substring renaming above, the model-part prefixes still use Kohya naming. Map them to the
    # PEFT/diffusers module names (the inverse of the prefix handling in `convert_state_dict_to_kohya` below).
    peft_state_dict: dict[str, torch.Tensor] = {}
    for key, weight in peft_partial_state_dict.items():
        if key.endswith(".alpha"):
            # Kohya `.alpha` scaling tensors have no direct PEFT state_dict equivalent.
            # TODO(ryand): Capture the alpha/rank scaling instead of dropping it.
            continue
        if key.startswith("lora_te2_"):
            key = "text_encoder_2." + key[len("lora_te2_") :]
        elif key.startswith("lora_te1_"):
            key = "text_encoder." + key[len("lora_te1_") :]
        elif key.startswith("lora_unet_"):
            key = "unet." + key[len("lora_unet_") :]
        # TODO(ryand): Kohya flattens module paths with underscores (e.g. "down_blocks_0_..."); the dotted module
        # path is not restored here yet.
        peft_state_dict[key] = weight
    return peft_state_dict
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
    r"""
    Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
    The method only supports the conversion from PEFT to Kohya for now.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *args*):
            Additional arguments to pass to the method.

            - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
              with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
              `get_peft_model_state_dict` method:
              https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
              but we add it here in case we don't want to rely on that method.
    """
    peft_adapter_name = kwargs.pop("adapter_name", None)
    if peft_adapter_name is not None:
        peft_adapter_name = "." + peft_adapter_name
    else:
        peft_adapter_name = ""

    if original_type is None:
        if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
            original_type = StateDictType.PEFT

    if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
        raise ValueError(f"Original type {original_type} is not supported")

    # Use the convert_state_dict function with the appropriate mapping
    kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
    kohya_ss_state_dict = {}

    # Additionally, replace "." with "_" in the module-path portion of each key, strip the adapter name, and add
    # Kohya "alpha" parameters.
    for kohya_key, weight in kohya_ss_partial_state_dict.items():
        if "text_encoder_2." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
        elif "text_encoder." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
        elif "unet" in kohya_key:
            kohya_key = kohya_key.replace("unet", "lora_unet")
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_key = kohya_key.replace(peft_adapter_name, "")  # Kohya doesn't take names
        kohya_ss_state_dict[kohya_key] = weight

        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))

    return kohya_ss_state_dict
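if __name__ == "__main__":
    # Smoke-test sketch (editor's example, not part of the original commit). The key names are toy stand-ins for
    # real Kohya LoRA keys; shapes are arbitrary. Expected output under the provisional prefix handling above:
    # ['unet.down_blocks_0.lora_A.weight', 'unet.down_blocks_0.lora_B.weight']
    kohya_sd = {
        "lora_unet_down_blocks_0.lora_down.weight": torch.zeros(8, 320),
        "lora_unet_down_blocks_0.lora_up.weight": torch.zeros(320, 8),
        "lora_unet_down_blocks_0.alpha": torch.tensor(8.0),
    }
    print(sorted(convert_state_dict_kohya_to_peft(kohya_sd).keys()))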

View File

@@ -2,9 +2,11 @@ from pathlib import Path
from typing import Optional, Union
import torch
from safetensors.torch import load_file
from diffusers.loaders.lora_conversion_utils import _convert_kohya_lora_to_diffusers
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.peft.sdxl_format_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.util.serialization import load_state_dict
class PeftModel:
@@ -14,17 +16,15 @@ class PeftModel:
self,
name: str,
state_dict: dict[str, torch.Tensor],
network_alphas: dict[str, torch.Tensor],
):
self._name = name
self._state_dict = state_dict
@property
def name(self) -> str:
return self._name
self.name = name
self.state_dict = state_dict
self.network_alphas = network_alphas
def calc_size(self) -> int:
model_size = 0
for tensor in self._state_dict.values():
for tensor in self.state_dict.values():
model_size += tensor.nelement() * tensor.element_size()
return model_size
@@ -41,16 +41,12 @@ class PeftModel:
file_path = Path(file_path)
# TODO(ryand): Implement a helper function for this. This logic is duplicated repeatedly.
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path, device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = load_state_dict(file_path, device=str(device))
# lora_unet_up_blocks_1_attentions_2_transformer_blocks_1_ff_net_2.lora_down.weight
if base_model == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
# TODO(ryand):
# - Detect state_dict format
# - Convert state_dict to diffusers format if necessary
# if base_model == BaseModelType.StableDiffusionXL:
# state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
return cls(name=file_path.stem, state_dict=state_dict)
# TODO(ryand): We shouldn't be using an unexported function from diffusers here. Consider opening an upstream PR
# to move this function to state_dict_utils.py.
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
return cls(name=file_path.stem, state_dict=state_dict, network_alphas=network_alphas)
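# Hypothetical usage sketch (editor's illustration, not part of this commit). The keyword arguments mirror the
# call site in LoRALoader above; the exact signature of `from_checkpoint` is defined earlier in this file:
#
# peft_model = PeftModel.from_checkpoint(
#     file_path=Path("path/to/lora.safetensors"),
#     dtype=torch.float16,
#     base_model=BaseModelType.StableDiffusionXL,
# )
# print(peft_model.name, peft_model.calc_size())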

View File

@@ -0,0 +1,67 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Iterator, Tuple

import torch

from invokeai.backend.peft.peft_model import PeftModel


class PeftModelPatcher:
    @classmethod
    @contextmanager
    @torch.no_grad()
    def apply_peft_patch(
        cls,
        model: torch.nn.Module,
        peft_models: Iterator[Tuple[PeftModel, float]],
        prefix: str,
    ):
        original_weights = {}
        try:
            for peft_model, peft_model_weight in peft_models:
                for layer_key, layer in peft_model.state_dict.items():
                    if not layer_key.startswith(prefix):
                        continue

                    # Strip the model-part prefix and the trailing parameter suffix (assumed here to be of the form
                    # ".lora_A.weight" / ".lora_B.weight") to get the dotted path of the target submodule.
                    module_key = layer_key.replace(prefix + ".", "")
                    module_key = module_key.rsplit(".", 2)[0]
                    # TODO(ryand): Make this work. `layer` is currently a raw tensor from `PeftModel.state_dict`, but
                    # the code below assumes a LoRA layer object exposing `.alpha`, `.rank`, `.to(...)`, and
                    # `.get_weight(...)` (as in the pre-PEFT LoRA layer classes).
                    module = model.get_submodule(module_key)

                    # All of the LoRA weight calculations will be done on the same device as the module weight.
                    # (Performance will be best if this is a CUDA device.)
                    device = module.weight.device
                    dtype = module.weight.dtype

                    if module_key not in original_weights:
                        # TODO(ryand): Set non_blocking = True?
                        original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)

                    layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0

                    # We intentionally move to the target device first, then cast. Experimentally, this was found to
                    # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                    # same thing in a single call to '.to(...)'.
                    layer.to(device=device)
                    layer.to(dtype=torch.float32)
                    # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                    # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                    layer_weight = layer.get_weight(module.weight) * (peft_model_weight * layer_scale)
                    layer.to(device=torch.device("cpu"))

                    assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                    if module.weight.shape != layer_weight.shape:
                        # TODO: debug on lycoris
                        layer_weight = layer_weight.reshape(module.weight.shape)

                    module.weight += layer_weight.to(dtype=dtype)

            yield
        finally:
            for module_key, weight in original_weights.items():
                model.get_submodule(module_key).weight.copy_(weight)
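# Hypothetical usage sketch (editor's illustration, not part of this commit). `_lora_loader()` stands in for the
# generators defined in the Compel invocations, which yield (PeftModel, weight) tuples:
#
# with PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), "text_encoder"):
#     ...  # run inference; the patched weights are live inside this block
# # On exit, the original weights are restored from the CPU copies in `original_weights`.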

View File

@@ -0,0 +1,154 @@
import bisect

import torch


def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Convert the keys of an SDXL LoRA state_dict to diffusers format.

    The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
    diffusers format, then this function will have no effect.

    This function is adapted from:
    https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409

    Args:
        state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.

    Raises:
        ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.

    Returns:
        Dict[str, Tensor]: The diffusers-format state_dict.
    """
    converted_count = 0  # The number of Stability AI keys converted to diffusers format.
    not_converted_count = 0  # The number of keys that were not converted.

    # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
    # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
    # `input_blocks_4_1_proj_in`.
    stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
    stability_unet_keys.sort()

    new_state_dict = {}
    for full_key, value in state_dict.items():
        if full_key.startswith("lora_unet_"):
            search_key = full_key.replace("lora_unet_", "")
            # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
            position = bisect.bisect_right(stability_unet_keys, search_key)
            map_key = stability_unet_keys[position - 1]
            # Now, check if the map_key *actually* matches the search_key.
            if search_key.startswith(map_key):
                new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
                new_state_dict[new_key] = value
                converted_count += 1
            else:
                new_state_dict[full_key] = value
                not_converted_count += 1
        elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
            # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
            new_state_dict[full_key] = value
            continue
        else:
            raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")

    if converted_count > 0 and not_converted_count > 0:
        raise ValueError(
            f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
            f" not_converted={not_converted_count}"
        )

    return new_state_dict
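# Illustrative example (editor's sketch, not part of this commit). With i=1, j=0, the conversion map below contains
# ("input_blocks_4_1", "down_blocks_1_attentions_0"), so:
#
# sd = {"lora_unet_input_blocks_4_1_proj_in.lora_down.weight": torch.zeros(8, 320)}
# converted = convert_sdxl_keys_to_diffusers_format(sd)
# assert "lora_unet_down_blocks_1_attentions_0_proj_in.lora_down.weight" in converted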
# Code based on:
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> list[tuple[str, str]]:
    """Create a list of (Stability AI, diffusers) state_dict key-prefix pairs mapping Stability AI SDXL format to
    diffusers SDXL format."""
    unet_conversion_map_layer: list[tuple[str, str]] = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # if i > 0: commentout for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map: list[tuple[str, str]] = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map


# A mapping of state_dict key prefixes from Stability AI SDXL format to diffusers SDXL format.
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}
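# Quick sanity checks against the generated map (editor's examples, not part of this commit):
#
# assert SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP["middle_block_1"] == "mid_block_attentions_0"
# assert SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP["time_embed_0"] == "time_embedding_linear_1"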