A very primitive working version of PEFT patching. It is very slow. LoRAs don't get unloaded yet, so it can only be run once. And the results are *slightly* different from the old implementation. I suspect this is because the LoRA weight is not being applied to the UNet, but there could be other issues as well.

Ryan Dick 2024-04-05 12:02:05 -04:00
parent 22c66cf55b
commit f9fda503a3
3 changed files with 168 additions and 8 deletions
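Editor's note on the suspicion raised in the commit message: below is a minimal sketch (not part of this commit) of how the per-LoRA weight could be applied to the UNet after the adapter is injected. It assumes the `scale_lora_layers`/`unscale_lora_layers` helpers from diffusers (the same helper this commit already uses for the text encoder) also work on a UNet with PEFT layers injected; the helper name `_scale_unet_adapter` is hypothetical.

# Hypothetical sketch (not in this commit): scale the injected UNet LoRA layers by the
# requested weight, mirroring what apply_peft_model_to_text_encoder does with
# scale_lora_layers for the text encoder.
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers

def _scale_unet_adapter(unet, peft_model_weight: float):
    # Multiply every injected LoRA layer's contribution by the LoRA weight.
    scale_lora_layers(unet, weight=peft_model_weight)
    # Return an undo callback so the scaling can be reverted when the patch context exits.
    return lambda: unscale_lora_layers(unet, weight=peft_model_weight)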

View File

@@ -78,7 +78,7 @@ class CompelInvocation(BaseInvocation):
            ),
            text_encoder_info as text_encoder,
            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), "text_encoder"),
+            PeftModelPatcher.apply_peft_model_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"),
            # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
            ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
        ):
@@ -176,7 +176,7 @@ class SDXLPromptInvocationBase:
            ),
            text_encoder_info as text_encoder,
            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), lora_prefix),
+            PeftModelPatcher.apply_peft_model_to_text_encoder(text_encoder, _lora_loader(), lora_prefix),
            # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
            ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
        ):

View File

@@ -48,9 +48,10 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.peft.peft_model import PeftModel
+from invokeai.backend.peft.peft_model_patcher import PeftModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -714,13 +715,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, unet_config.base)

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[PeftModel, float]]:
            for lora in self.unet.loras:
                lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, PeftModel)
                yield (lora_info.model, lora.weight)
                del lora_info
-            return

        unet_info = context.models.load(self.unet.unet)
        assert isinstance(unet_info.model, UNet2DConditionModel)
@@ -730,7 +730,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
            set_seamless(unet_info.model, self.unet.seamless_axes),  # FIXME
            unet_info as unet,
            # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            PeftModelPatcher.apply_peft_model_to_unet(unet, _lora_loader()),
        ):
            assert isinstance(unet, UNet2DConditionModel)
            latents = latents.to(device=unet.device, dtype=unet.dtype)

View File

@@ -4,11 +4,171 @@ from contextlib import contextmanager
from typing import Iterator, Tuple

import torch
+from diffusers.models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from diffusers.utils.peft_utils import get_peft_kwargs, scale_lora_layers
+from diffusers.utils.state_dict_utils import convert_state_dict_to_peft, convert_unet_state_dict_to_peft
+from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

from invokeai.backend.peft.peft_model import PeftModel

+UNET_NAME = "unet"
+

class PeftModelPatcher:
+    @classmethod
+    @contextmanager
+    @torch.no_grad()
+    def apply_peft_model_to_text_encoder(
+        cls,
+        text_encoder: torch.nn.Module,
+        peft_models: Iterator[Tuple[PeftModel, float]],
+        prefix: str,
+    ):
+        original_weights = {}
+        try:
+            for peft_model, peft_model_weight in peft_models:
+                keys = list(peft_model.state_dict.keys())
+
+                # Load the layers corresponding to text encoder and make necessary adjustments.
+                text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+                text_encoder_lora_state_dict = {
+                    k.replace(f"{prefix}.", ""): v for k, v in peft_model.state_dict.items() if k in text_encoder_keys
+                }
+
+                if len(text_encoder_lora_state_dict) == 0:
+                    continue
+
+                if peft_model.name in getattr(text_encoder, "peft_config", {}):
+                    raise ValueError(f"Adapter name {peft_model.name} already in use in the text encoder ({prefix}).")
+
+                rank = {}
+                # TODO(ryand): Is this necessary?
+                # text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+                for name, _ in text_encoder_attn_modules(text_encoder):
+                    rank_key = f"{name}.out_proj.lora_B.weight"
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+                patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+                if patch_mlp:
+                    for name, _ in text_encoder_mlp_modules(text_encoder):
+                        rank_key_fc1 = f"{name}.fc1.lora_B.weight"
+                        rank_key_fc2 = f"{name}.fc2.lora_B.weight"
+                        rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
+                        rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
+
+                network_alphas = peft_model.network_alphas
+                if network_alphas is not None:
+                    alpha_keys = [
+                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+                    ]
+                    network_alphas = {
+                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                    }
+
+                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+                lora_config_kwargs["inference_mode"] = True
+                lora_config = LoraConfig(**lora_config_kwargs)
+
+                new_text_encoder = inject_adapter_in_model(lora_config, text_encoder, peft_model.name)
+                incompatible_keys = set_peft_model_state_dict(
+                    new_text_encoder, text_encoder_lora_state_dict, peft_model.name
+                )
+
+                if incompatible_keys is not None:
+                    # check only for unexpected keys
+                    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                    if unexpected_keys:
+                        raise ValueError(f"Failed to inject unexpected PEFT keys: {unexpected_keys}")
+
+                # inject LoRA layers and load the state dict
+                # in transformers we automatically check whether the adapter name is already in use or not
+                # text_encoder.load_adapter(
+                #     adapter_name=adapter_name,
+                #     adapter_state_dict=text_encoder_lora_state_dict,
+                #     peft_config=lora_config,
+                # )
+
+                scale_lora_layers(text_encoder, weight=peft_model_weight)
+
+                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+            yield
+        finally:
+            # TODO
+            pass
+            # for module_key, weight in original_weights.items():
+            #     model.get_submodule(module_key).weight.copy_(weight)
+    @classmethod
+    @contextmanager
+    @torch.no_grad()
+    def apply_peft_model_to_unet(
+        cls,
+        unet: UNet2DConditionModel,
+        peft_models: Iterator[Tuple[PeftModel, float]],
+    ):
+        try:
+            for peft_model, peft_model_weight in peft_models:
+                keys = list(peft_model.state_dict.keys())
+
+                unet_keys = [k for k in keys if k.startswith(UNET_NAME)]
+                state_dict = {
+                    k.replace(f"{UNET_NAME}.", ""): v for k, v in peft_model.state_dict.items() if k in unet_keys
+                }
+
+                network_alphas = peft_model.network_alphas
+                if network_alphas is not None:
+                    alpha_keys = [k for k in network_alphas.keys() if k.startswith(UNET_NAME)]
+                    network_alphas = {
+                        k.replace(f"{UNET_NAME}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                    }
+
+                if len(state_dict) == 0:
+                    continue
+
+                if peft_model.name in getattr(unet, "peft_config", {}):
+                    raise ValueError(f"Adapter name {peft_model.name} already in use in the Unet.")
+
+                state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+                if network_alphas is not None:
+                    # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
+                    # `convert_unet_state_dict_to_peft` method.
+                    network_alphas = convert_unet_state_dict_to_peft(network_alphas)
+
+                rank = {}
+                for key, val in state_dict.items():
+                    if "lora_B" in key:
+                        rank[key] = val.shape[1]
+
+                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+                lora_config_kwargs["inference_mode"] = True
+                lora_config = LoraConfig(**lora_config_kwargs)
+
+                inject_adapter_in_model(lora_config, unet, adapter_name=peft_model.name)
+                incompatible_keys = set_peft_model_state_dict(unet, state_dict, peft_model.name)
+
+                if incompatible_keys is not None:
+                    # check only for unexpected keys
+                    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                    if unexpected_keys:
+                        raise ValueError(f"Failed to inject unexpected PEFT keys: {unexpected_keys}")
+
+                # TODO(ryand): What does this do?
+                unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=True)
+
+                # TODO(ryand): Apply the lora weight. Where does diffusers do this? They don't seem to do it when they
+                # patch the UNet.
+
+            yield
+        finally:
+            # TODO
+            pass
+            # for module_key, weight in original_weights.items():
+            #     model.get_submodule(module_key).weight.copy_(weight)
    @classmethod
    @contextmanager
    @torch.no_grad()
@@ -28,8 +188,8 @@ class PeftModelPatcher:
                        continue

                    module_key = layer_key.replace(prefix + ".", "")
                    module_key = module_key.split
                    # TODO(ryand): Make this work.
                    module = model_state_dict[module_key]

                    # All of the LoRA weight calculations will be done on the same device as the module weight.
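Editor's note: the `finally:` blocks in both new methods are still TODO, which is why LoRAs can only be applied once per process. One possible direction, sketched under the assumption that diffusers' `delete_adapter_layers` helper can remove adapters injected via `inject_adapter_in_model`; the helper `_unload_peft_adapter` below is hypothetical and not part of this commit.

# Hypothetical cleanup sketch for the TODO finally blocks (not in this commit).
from diffusers.utils.peft_utils import delete_adapter_layers, unscale_lora_layers

def _unload_peft_adapter(model, adapter_name: str, weight: float):
    # Undo the output scaling applied at patch time, then strip the injected LoRA
    # layers so the model can be patched again on the next invocation.
    unscale_lora_layers(model, weight=weight)
    delete_adapter_layers(model, adapter_name)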