Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

WIP - still need to figure out why the keys are wrong.

This commit is contained in: parent 827ac4b841, commit e315fb9e7b
@@ -9,7 +9,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_patcher import LoraModelPatcher
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -80,7 +81,8 @@ class CompelInvocation(BaseInvocation):
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            # ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
         ):
@@ -181,7 +183,8 @@ class SDXLPromptInvocationBase:
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            # ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
@@ -259,15 +262,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         c1, c1_pooled, ec1 = self.run_clip_compel(
-            context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
+            context, self.clip, self.prompt, False, "text_encoder", zero_on_empty=True
         )
         if self.style.strip() == "":
             c2, c2_pooled, ec2 = self.run_clip_compel(
-                context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
+                context, self.clip2, self.prompt, True, "text_encoder_2", zero_on_empty=True
             )
         else:
             c2, c2_pooled, ec2 = self.run_clip_compel(
-                context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
+                context, self.clip2, self.style, True, "text_encoder_2", zero_on_empty=True
             )

         original_size = (self.original_height, self.original_width)
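
Note: the replacement prefixes above line up with the key filtering done by the new LoraModelPatcher.apply_lora_to_text_encoder helper added later in this commit, which selects state-dict entries via key.startswith(prefix + "."). A minimal sketch of that selection; the example keys and values are hypothetical, only the prefix strings come from the diff:

# Hypothetical state-dict keys for illustration; the filter mirrors the one in
# LoraModelPatcher.apply_lora_to_text_encoder.
state_dict = {
    "text_encoder.text_model.layers.0.weight": 0.1,    # matched when prefix == "text_encoder"
    "text_encoder_2.text_model.layers.0.weight": 0.2,  # matched when prefix == "text_encoder_2"
    "unet.down_blocks.0.weight": 0.3,                  # never matched on the text-encoder path
}

prefix = "text_encoder"
selected = {k: v for k, v in state_dict.items() if k.startswith(prefix + ".")}
# Only the first key survives; old kohya-style keys such as "lora_te1_..." would match nothing.
print(sorted(selected))  # ['text_encoder.text_model.layers.0.weight']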
@@ -52,7 +52,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_patcher import LoraModelPatcher
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
@@ -739,7 +740,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
             set_seamless(unet_info.model, self.unet.seamless_axes),  # FIXME
             unet_info as unet,
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            # ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            LoraModelPatcher.apply_lora_to_unet(unet, _lora_loader()),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
invokeai/backend/lora_model_patcher.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from contextlib import contextmanager
from typing import Iterator, Tuple, Union

from diffusers.loaders.lora import LoraLoaderMixin
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.utils.peft_utils import recurse_remove_peft_layers
from transformers import CLIPTextModel

from invokeai.backend.lora_model_raw import LoRAModelRaw


class LoraModelPatcher:
    @classmethod
    def unload_lora_from_model(cls, m: Union[UNet2DConditionModel, CLIPTextModel]):
        """Unload all LoRA models from a UNet or Text Encoder.
        This implementation is based on LoraLoaderMixin.unload_lora_weights().
        """
        recurse_remove_peft_layers(m)
        if hasattr(m, "peft_config"):
            del m.peft_config  # type: ignore
        if hasattr(m, "_hf_peft_config_loaded"):
            m._hf_peft_config_loaded = None  # type: ignore

    @classmethod
    @contextmanager
    def apply_lora_to_unet(cls, unet: UNet2DConditionModel, loras: Iterator[Tuple[LoRAModelRaw, float]]):
        try:
            # TODO(ryand): Test speed of low_cpu_mem_usage=True.
            for lora, lora_weight in loras:
                LoraLoaderMixin.load_lora_into_unet(
                    state_dict=lora.state_dict,
                    network_alphas=lora.network_alphas,
                    unet=unet,
                    low_cpu_mem_usage=True,
                    adapter_name=lora.name,
                    _pipeline=None,
                )
            yield
        finally:
            cls.unload_lora_from_model(unet)

    @classmethod
    @contextmanager
    def apply_lora_to_text_encoder(
        cls, text_encoder: CLIPTextModel, loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str
    ):
        assert prefix in ["text_encoder", "text_encoder_2"]
        try:
            for lora, lora_weight in loras:
                # Filter the state_dict to only include the keys that start with the prefix.
                text_encoder_state_dict = {
                    key: value for key, value in lora.state_dict.items() if key.startswith(prefix + ".")
                }
                if len(text_encoder_state_dict) > 0:
                    LoraLoaderMixin.load_lora_into_text_encoder(
                        state_dict=text_encoder_state_dict,
                        network_alphas=lora.network_alphas,
                        text_encoder=text_encoder,
                        low_cpu_mem_usage=True,
                        adapter_name=lora.name,
                        _pipeline=None,
                    )
            yield
        finally:
            cls.unload_lora_from_model(text_encoder)
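
For orientation, a minimal sketch of how these context managers are used at the call sites changed earlier in this commit; unet, text_encoder, and loras are stand-ins for objects the invocation code provides (loras being an iterable of (LoRAModelRaw, weight) tuples, e.g. from _lora_loader()):

# Sketch only: names other than LoraModelPatcher are placeholders.
def run_with_patched_models(unet, text_encoder, loras):
    # LoRA layers are injected on entry and removed again on exit, because each
    # context manager calls unload_lora_from_model() in its finally block.
    with (
        LoraModelPatcher.apply_lora_to_unet(unet, iter(loras)),
        LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, iter(loras), "text_encoder"),
    ):
        ...  # run text encoding / denoising while the patches are active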
invokeai/backend/lora_model_raw.py (new file, 66 lines)
@@ -0,0 +1,66 @@
from pathlib import Path
from typing import Optional, Union

import torch
from diffusers.loaders.lora import LoraLoaderMixin
from typing_extensions import Self


class LoRAModelRaw:
    def __init__(
        self,
        name: str,
        state_dict: dict[str, torch.Tensor],
        network_alphas: Optional[dict[str, float]],
    ):
        self._name = name
        self.state_dict = state_dict
        self.network_alphas = network_alphas

    @property
    def name(self) -> str:
        return self._name

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        for key, layer in self.state_dict.items():
            self.state_dict[key] = layer.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        """Calculate the size of the model in bytes."""
        model_size = 0
        for layer in self.state_dict.values():
            model_size += layer.numel() * layer.element_size()
        return model_size

    @classmethod
    def from_checkpoint(
        cls, file_path: Union[str, Path], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
    ) -> Self:
        """This function is based on diffusers LoraLoaderMixin.load_lora_weights()."""

        file_path = Path(file_path)
        if file_path.is_dir():
            raise NotImplementedError("LoRA models from directories are not yet supported.")

        dir_path = file_path.parent
        file_name = file_path.name

        state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
            pretrained_model_name_or_path_or_dict=str(file_path), local_files_only=True, weight_name=str(file_name)
        )

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        model = cls(
            # TODO(ryand): Handle both files and directories here?
            name=Path(file_path).stem,
            state_dict=state_dict,
            network_alphas=network_alphas,
        )

        device = device or torch.device("cpu")
        dtype = dtype or torch.float32
        model.to(device=device, dtype=dtype)
        return model
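
A brief usage sketch for the new class; the checkpoint path below is hypothetical, everything else uses only names defined in the file above:

import torch

from invokeai.backend.lora_model_raw import LoRAModelRaw

# Hypothetical path. from_checkpoint() reads a diffusers-format LoRA state dict via
# LoraLoaderMixin.lora_state_dict() and moves the tensors to the requested device/dtype.
lora = LoRAModelRaw.from_checkpoint(
    file_path="/path/to/some_lora.safetensors",
    device=torch.device("cpu"),
    dtype=torch.float16,
)
print(lora.name)         # "some_lora" (the file stem)
print(lora.calc_size())  # total size of the state dict, in bytes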
@@ -32,7 +32,7 @@ from typing_extensions import Annotated, Any, Dict
 from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.textual_inversion import TextualInversionModelRaw

@@ -6,7 +6,7 @@ from pathlib import Path
 from typing import Optional, Tuple

 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -51,7 +51,6 @@ class LoRALoader(ModelLoader):
         model = LoRAModelRaw.from_checkpoint(
             file_path=model_path,
             dtype=self._torch_dtype,
-            base_model=self._model_base,
         )
         return model

@@ -17,7 +17,7 @@ from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel

-from .lora import LoRAModelRaw
+from .lora_model_raw import LoRAModelRaw
 from .textual_inversion import TextualInversionManager, TextualInversionModelRaw

 """
@@ -5,7 +5,7 @@
 import pytest
 import torch

-from invokeai.backend.lora import LoRALayer, LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRALayer, LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
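
Since LoRAModelRaw can now be built directly from a state dict, a small test sketch in the style of this file; the tensor names and shapes are made up:

import torch

from invokeai.backend.lora_model_raw import LoRAModelRaw


def test_lora_model_raw_calc_size():
    # Hypothetical state dict: two float32 tensors with 4 elements each -> 32 bytes total.
    state_dict = {
        "text_encoder.layer.lora_down.weight": torch.zeros(2, 2),
        "text_encoder.layer.lora_up.weight": torch.zeros(2, 2),
    }
    lora = LoRAModelRaw(name="test_lora", state_dict=state_dict, network_alphas=None)

    assert lora.name == "test_lora"
    assert lora.calc_size() == 2 * 4 * 4  # 2 tensors * 4 elements * 4 bytes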