WIP - still need to figure out why the keys are wrong.

This commit is contained in:
Ryan Dick 2024-03-28 17:14:47 -04:00
parent 827ac4b841
commit e315fb9e7b
9 changed files with 148 additions and 13 deletions

View File

@ -9,7 +9,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora_model_patcher import LoraModelPatcher
from invokeai.backend.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@ -80,7 +81,8 @@ class CompelInvocation(BaseInvocation):
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
):
@ -181,7 +183,8 @@ class SDXLPromptInvocationBase:
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
):
@ -259,15 +262,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
context, self.clip, self.prompt, False, "text_encoder", zero_on_empty=True
)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
context, self.clip2, self.prompt, True, "text_encoder_2", zero_on_empty=True
)
else:
c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
context, self.clip2, self.style, True, "text_encoder_2", zero_on_empty=True
)
original_size = (self.original_height, self.original_width)

View File

@ -52,7 +52,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora_model_patcher import LoraModelPatcher
from invokeai.backend.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
@ -739,7 +740,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
unet_info as unet,
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
# ModelPatcher.apply_lora_unet(unet, _lora_loader()),
LoraModelPatcher.apply_lora_to_unet(unet, _lora_loader()),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)

View File

@ -0,0 +1,65 @@
from contextlib import contextmanager
from typing import Iterator, Tuple, Union
from diffusers.loaders.lora import LoraLoaderMixin
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.utils.peft_utils import recurse_remove_peft_layers
from transformers import CLIPTextModel
from invokeai.backend.lora_model_raw import LoRAModelRaw
class LoraModelPatcher:
    """Context managers that temporarily load LoRA weights into a UNet or text encoder.

    Loading is delegated to diffusers' ``LoraLoaderMixin``. On exit (normal or
    exceptional) all LoRA layers are removed again via ``unload_lora_from_model()``.
    """

    @classmethod
    def unload_lora_from_model(cls, m: Union[UNet2DConditionModel, CLIPTextModel]):
        """Unload all LoRA models from a UNet or Text Encoder.

        This implementation is based on LoraLoaderMixin.unload_lora_weights().

        Args:
            m: The model to strip LoRA layers from. It is modified in place.
        """
        recurse_remove_peft_layers(m)
        if hasattr(m, "peft_config"):
            del m.peft_config  # type: ignore
        if hasattr(m, "_hf_peft_config_loaded"):
            m._hf_peft_config_loaded = None  # type: ignore

    @classmethod
    @contextmanager
    def apply_lora_to_unet(
        cls, unet: UNet2DConditionModel, loras: Iterator[Tuple[LoRAModelRaw, float]]
    ) -> Iterator[None]:
        """Load every LoRA in ``loras`` into ``unet`` for the duration of the ``with`` block.

        Args:
            unet: The UNet to patch. It is modified in place and restored on exit.
            loras: Iterable of ``(lora, weight)`` pairs. Each LoRA is registered under
                ``lora.name`` as its adapter name.
        """
        try:
            # TODO(ryand): Test speed of low_cpu_mem_usage=True.
            # TODO: the per-LoRA weight is currently ignored; every adapter is loaded at
            # whatever scale diffusers applies by default.
            for lora, _lora_weight in loras:
                LoraLoaderMixin.load_lora_into_unet(
                    state_dict=lora.state_dict,
                    network_alphas=lora.network_alphas,
                    unet=unet,
                    low_cpu_mem_usage=True,
                    adapter_name=lora.name,
                    _pipeline=None,
                )
            yield
        finally:
            # Runs even if loading failed part-way, so no partially-applied LoRA is left behind.
            cls.unload_lora_from_model(unet)

    @classmethod
    @contextmanager
    def apply_lora_to_text_encoder(
        cls, text_encoder: CLIPTextModel, loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str
    ) -> Iterator[None]:
        """Load every LoRA in ``loras`` into ``text_encoder`` for the duration of the ``with`` block.

        Args:
            text_encoder: The text encoder to patch. It is modified in place and restored on exit.
            loras: Iterable of ``(lora, weight)`` pairs. Each LoRA is registered under
                ``lora.name`` as its adapter name.
            prefix: State-dict key prefix selecting which encoder's weights to load.
                Must be ``"text_encoder"`` or ``"text_encoder_2"``.
        """
        assert prefix in ["text_encoder", "text_encoder_2"]
        key_prefix = prefix + "."
        try:
            # TODO: the per-LoRA weight is currently ignored (see apply_lora_to_unet).
            for lora, _lora_weight in loras:
                # Filter the state_dict to only include the keys that start with the prefix.
                text_encoder_state_dict = {
                    key: value for key, value in lora.state_dict.items() if key.startswith(key_prefix)
                }
                if text_encoder_state_dict:
                    LoraLoaderMixin.load_lora_into_text_encoder(
                        state_dict=text_encoder_state_dict,
                        network_alphas=lora.network_alphas,
                        text_encoder=text_encoder,
                        # Pass the prefix through explicitly. When it is omitted, diffusers
                        # falls back to its default text-encoder prefix, so keys under
                        # "text_encoder_2." would not be matched and nothing would be loaded.
                        prefix=prefix,
                        low_cpu_mem_usage=True,
                        adapter_name=lora.name,
                        _pipeline=None,
                    )
            yield
        finally:
            # Runs even if loading failed part-way, so no partially-applied LoRA is left behind.
            cls.unload_lora_from_model(text_encoder)

View File

@ -0,0 +1,66 @@
from pathlib import Path
from typing import Optional, Union
import torch
from diffusers.loaders.lora import LoraLoaderMixin
from typing_extensions import Self
class LoRAModelRaw:
def __init__(
self,
name: str,
state_dict: dict[str, torch.Tensor],
network_alphas: Optional[dict[str, float]],
):
self._name = name
self.state_dict = state_dict
self.network_alphas = network_alphas
@property
def name(self) -> str:
return self._name
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
for key, layer in self.state_dict.items():
self.state_dict[key] = layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
"""Calculate the size of the model in bytes."""
model_size = 0
for layer in self.state_dict.values():
model_size += layer.numel() * layer.element_size()
return model_size
@classmethod
def from_checkpoint(
cls, file_path: Union[str, Path], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> Self:
"""This function is based on diffusers LoraLoaderMixin.load_lora_weights()."""
file_path = Path(file_path)
if file_path.is_dir():
raise NotImplementedError("LoRA models from directories are not yet supported.")
dir_path = file_path.parent
file_name = file_path.name
state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
pretrained_model_name_or_path_or_dict=str(file_path), local_files_only=True, weight_name=str(file_name)
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
model = cls(
# TODO(ryand): Handle both files and directories here?
name=Path(file_path).stem,
state_dict=state_dict,
network_alphas=network_alphas,
)
device = device or torch.device("cpu")
dtype = dtype or torch.float32
model.to(device=device, dtype=dtype)
return model

View File

@ -32,7 +32,7 @@ from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora_model_raw import LoRAModelRaw
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw

View File

@ -6,7 +6,7 @@ from pathlib import Path
from typing import Optional, Tuple
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@ -51,7 +51,6 @@ class LoRALoader(ModelLoader):
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,
)
return model

View File

@ -17,7 +17,7 @@ from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from .lora import LoRAModelRaw
from .lora_model_raw import LoRAModelRaw
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
"""

View File

@ -5,7 +5,7 @@
import pytest
import torch
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
from invokeai.backend.lora_model_raw import LoRALayer, LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher