A very primitive working version of peft patching. It is very slow, and LoRAs don't get unloaded yet, so it can only be run once. The results are also *slightly* different from the old implementation; I suspect this is because the LoRA weight is not being applied to the UNet, but there could be other issues as well.
This commit is contained in:
parent 22c66cf55b
commit f9fda503a3
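The commit message above suspects that the LoRA weight is never applied to the UNet, and the diff below leaves that step open (see the "TODO(ryand): Apply the lora weight" note in apply_peft_model_to_unet). A minimal, hedged sketch of what the missing step might look like, reusing the scale_lora_layers helper that the text-encoder path already calls; the function name _apply_lora_weight_to_unet is illustrative and not part of this commit:

import torch
from diffusers.utils.peft_utils import scale_lora_layers


@torch.no_grad()
def _apply_lora_weight_to_unet(unet: torch.nn.Module, lora_weight: float) -> None:
    # Hypothetical follow-up: once inject_adapter_in_model() and set_peft_model_state_dict()
    # have loaded the adapter into the UNet, scale the injected LoRA layers by the requested
    # weight, mirroring the scale_lora_layers() call in apply_peft_model_to_text_encoder().
    scale_lora_layers(unet, weight=lora_weight)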
@@ -78,7 +78,7 @@ class CompelInvocation(BaseInvocation):
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), "text_encoder"),
+            PeftModelPatcher.apply_peft_model_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
         ):
@@ -176,7 +176,7 @@ class SDXLPromptInvocationBase:
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            PeftModelPatcher.apply_peft_patch(text_encoder, _lora_loader(), lora_prefix),
+            PeftModelPatcher.apply_peft_model_to_text_encoder(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
@@ -48,9 +48,10 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.peft.peft_model import PeftModel
+from invokeai.backend.peft.peft_model_patcher import PeftModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
 from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -714,13 +715,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState) -> None:
             context.util.sd_step_callback(state, unet_config.base)
 
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[PeftModel, float]]:
             for lora in self.unet.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, PeftModel)
                 yield (lora_info.model, lora.weight)
                 del lora_info
-            return
 
         unet_info = context.models.load(self.unet.unet)
         assert isinstance(unet_info.model, UNet2DConditionModel)
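For context, each object yielded by _lora_loader() is consumed by PeftModelPatcher below through just three attributes. The real invokeai.backend.peft.peft_model.PeftModel class is not part of this diff, so the following is only a hedged sketch of the interface the patcher appears to assume:

from dataclasses import dataclass
from typing import Dict, Optional

import torch


@dataclass
class PeftModelSketch:
    """Illustrative stand-in for invokeai.backend.peft.peft_model.PeftModel (not shown in this diff)."""

    # Adapter name; must be unique per patched module (checked against the model's peft_config).
    name: str
    # Raw LoRA weights, keyed with a "text_encoder." / "unet." prefix that the patcher strips.
    state_dict: Dict[str, torch.Tensor]
    # Optional per-layer alpha values with the same key structure, or None.
    network_alphas: Optional[Dict[str, float]]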
@@ -730,7 +730,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             set_seamless(unet_info.model, self.unet.seamless_axes),  # FIXME
             unet_info as unet,
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            PeftModelPatcher.apply_peft_model_to_unet(unet, _lora_loader()),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
@@ -4,11 +4,171 @@ from contextlib import contextmanager
 from typing import Iterator, Tuple
 
 import torch
+from diffusers.models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from diffusers.utils.peft_utils import get_peft_kwargs, scale_lora_layers
+from diffusers.utils.state_dict_utils import convert_state_dict_to_peft, convert_unet_state_dict_to_peft
+from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
 
 from invokeai.backend.peft.peft_model import PeftModel
 
+UNET_NAME = "unet"
+
 
 class PeftModelPatcher:
+    @classmethod
+    @contextmanager
+    @torch.no_grad()
+    def apply_peft_model_to_text_encoder(
+        cls,
+        text_encoder: torch.nn.Module,
+        peft_models: Iterator[Tuple[PeftModel, float]],
+        prefix: str,
+    ):
+        original_weights = {}
+
+        try:
+            for peft_model, peft_model_weight in peft_models:
+                keys = list(peft_model.state_dict.keys())
+
+                # Load the layers corresponding to text encoder and make necessary adjustments.
+                text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+                text_encoder_lora_state_dict = {
+                    k.replace(f"{prefix}.", ""): v for k, v in peft_model.state_dict.items() if k in text_encoder_keys
+                }
+
+                if len(text_encoder_lora_state_dict) == 0:
+                    continue
+
+                if peft_model.name in getattr(text_encoder, "peft_config", {}):
+                    raise ValueError(f"Adapter name {peft_model.name} already in use in the text encoder ({prefix}).")
+
+                rank = {}
+                # TODO(ryand): Is this necessary?
+                # text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+                for name, _ in text_encoder_attn_modules(text_encoder):
+                    rank_key = f"{name}.out_proj.lora_B.weight"
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+                patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+                if patch_mlp:
+                    for name, _ in text_encoder_mlp_modules(text_encoder):
+                        rank_key_fc1 = f"{name}.fc1.lora_B.weight"
+                        rank_key_fc2 = f"{name}.fc2.lora_B.weight"
+
+                        rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
+                        rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
+
+                network_alphas = peft_model.network_alphas
+                if network_alphas is not None:
+                    alpha_keys = [
+                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+                    ]
+                    network_alphas = {
+                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                    }
+
+                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+                lora_config_kwargs["inference_mode"] = True
+                lora_config = LoraConfig(**lora_config_kwargs)
+
+                new_text_encoder = inject_adapter_in_model(lora_config, text_encoder, peft_model.name)
+                incompatible_keys = set_peft_model_state_dict(
+                    new_text_encoder, text_encoder_lora_state_dict, peft_model.name
+                )
+                if incompatible_keys is not None:
+                    # check only for unexpected keys
+                    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                    if unexpected_keys:
+                        raise ValueError(f"Failed to inject unexpected PEFT keys: {unexpected_keys}")
+
+                # inject LoRA layers and load the state dict
+                # in transformers we automatically check whether the adapter name is already in use or not
+                # text_encoder.load_adapter(
+                #     adapter_name=adapter_name,
+                #     adapter_state_dict=text_encoder_lora_state_dict,
+                #     peft_config=lora_config,
+                # )
+
+                scale_lora_layers(text_encoder, weight=peft_model_weight)
+                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+            yield
+        finally:
+            # TODO
+            pass
+            # for module_key, weight in original_weights.items():
+            #     model.get_submodule(module_key).weight.copy_(weight)
+
+    @classmethod
+    @contextmanager
+    @torch.no_grad()
+    def apply_peft_model_to_unet(
+        cls,
+        unet: UNet2DConditionModel,
+        peft_models: Iterator[Tuple[PeftModel, float]],
+    ):
+        try:
+            for peft_model, peft_model_weight in peft_models:
+                keys = list(peft_model.state_dict.keys())
+
+                unet_keys = [k for k in keys if k.startswith(UNET_NAME)]
+                state_dict = {
+                    k.replace(f"{UNET_NAME}.", ""): v for k, v in peft_model.state_dict.items() if k in unet_keys
+                }
+
+                network_alphas = peft_model.network_alphas
+                if network_alphas is not None:
+                    alpha_keys = [k for k in network_alphas.keys() if k.startswith(UNET_NAME)]
+                    network_alphas = {
+                        k.replace(f"{UNET_NAME}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                    }
+
+                if len(state_dict) == 0:
+                    continue
+
+                if peft_model.name in getattr(unet, "peft_config", {}):
+                    raise ValueError(f"Adapter name {peft_model.name} already in use in the Unet.")
+
+                state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+                if network_alphas is not None:
+                    # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
+                    # `convert_unet_state_dict_to_peft` method.
+                    network_alphas = convert_unet_state_dict_to_peft(network_alphas)
+
+                rank = {}
+                for key, val in state_dict.items():
+                    if "lora_B" in key:
+                        rank[key] = val.shape[1]
+
+                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+                lora_config_kwargs["inference_mode"] = True
+                lora_config = LoraConfig(**lora_config_kwargs)
+
+                inject_adapter_in_model(lora_config, unet, adapter_name=peft_model.name)
+                incompatible_keys = set_peft_model_state_dict(unet, state_dict, peft_model.name)
+                if incompatible_keys is not None:
+                    # check only for unexpected keys
+                    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                    if unexpected_keys:
+                        raise ValueError(f"Failed to inject unexpected PEFT keys: {unexpected_keys}")
+
+                # TODO(ryand): What does this do?
+                unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=True)
+
+                # TODO(ryand): Apply the lora weight. Where does diffusers do this? They don't seem to do it when they
+                # patch the UNet.
+            yield
+        finally:
+            # TODO
+            pass
+            # for module_key, weight in original_weights.items():
+            #     model.get_submodule(module_key).weight.copy_(weight)
+
     @classmethod
     @contextmanager
     @torch.no_grad()
@@ -28,8 +188,8 @@ class PeftModelPatcher:
                    continue

                module_key = layer_key.replace(prefix + ".", "")
-                module_key = module_key.split
                # TODO(ryand): Make this work.

                module = model_state_dict[module_key]

                # All of the LoRA weight calculations will be done on the same device as the module weight.
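The finally blocks in both new context managers are still TODO, which is why LoRAs are never unloaded and the graph can currently only be run once. A hedged sketch of one possible cleanup path, assuming diffusers' unscale_lora_layers and delete_adapter_layers helpers from diffusers.utils.peft_utils behave as their names suggest; the helper name _unpatch_peft_model is illustrative and not part of this commit:

import torch
from diffusers.utils.peft_utils import delete_adapter_layers, unscale_lora_layers


@torch.no_grad()
def _unpatch_peft_model(model: torch.nn.Module, adapter_name: str, lora_weight: float) -> None:
    # Undo the scale_lora_layers(...) call made while patching (assumed diffusers helper).
    unscale_lora_layers(model, lora_scale=lora_weight)
    # Remove the injected LoRA layers for this adapter so the model can be patched again later
    # (assumed diffusers helper that strips peft tuner layers by adapter name).
    delete_adapter_layers(model, adapter_name)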