Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
WIP - still need to figure out why the keys are wrong.

commit e315fb9e7b (parent 827ac4b841)
@@ -9,7 +9,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_patcher import LoraModelPatcher
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -80,7 +81,8 @@ class CompelInvocation(BaseInvocation):
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            # ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
         ):
@@ -181,7 +183,8 @@ class SDXLPromptInvocationBase:
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            # ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
@@ -259,15 +262,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         c1, c1_pooled, ec1 = self.run_clip_compel(
-            context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
+            context, self.clip, self.prompt, False, "text_encoder", zero_on_empty=True
         )
         if self.style.strip() == "":
             c2, c2_pooled, ec2 = self.run_clip_compel(
-                context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
+                context, self.clip2, self.prompt, True, "text_encoder_2", zero_on_empty=True
             )
         else:
             c2, c2_pooled, ec2 = self.run_clip_compel(
-                context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
+                context, self.clip2, self.style, True, "text_encoder_2", zero_on_empty=True
             )
 
         original_size = (self.original_height, self.original_width)
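Note on the prefix arguments above: "lora_te1_" / "lora_te2_" are kohya-style key prefixes that the old ModelPatcher consumed, while the new LoraModelPatcher (added below) filters a diffusers-format state dict by the "text_encoder." / "text_encoder_2." namespaces. A minimal, self-contained sketch of that filter, with hypothetical keys not taken from the commit:

import torch

# Hypothetical diffusers-format LoRA keys, for illustration only.
state_dict = {
    "unet.down_blocks.0.attentions.0.proj_in.lora_A.weight": torch.zeros(4, 4),
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 4),
    "text_encoder_2.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 4),
}

prefix = "text_encoder"
filtered = {k: v for k, v in state_dict.items() if k.startswith(prefix + ".")}

# Only the "text_encoder." key survives. The trailing "." in the check matters:
# the "text_encoder_2..." key does not start with "text_encoder.", so the two
# text encoders' weights stay separate.
print(list(filtered.keys()))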
@@ -52,7 +52,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_patcher import LoraModelPatcher
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
@@ -739,7 +740,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
             set_seamless(unet_info.model, self.unet.seamless_axes),  # FIXME
             unet_info as unet,
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            # ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            LoraModelPatcher.apply_lora_to_unet(unet, _lora_loader()),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
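For context on the with-block above: stacked context managers are entered top to bottom and exited in reverse, which is why the LoRA patch can be applied after the UNet reaches its target device and is removed before the model is released. A generic, runnable sketch of the pattern with stand-in managers (not InvokeAI code):

from contextlib import contextmanager


@contextmanager
def managed(tag: str):
    # Stand-in for set_seamless / model loading / LoRA patching.
    print(f"enter {tag}")
    try:
        yield tag
    finally:
        print(f"exit {tag}")


# A parenthesized multi-manager `with` (Python 3.10+) enters left to right and
# exits in reverse order, so "lora" is unpatched before "unet" is released.
with (
    managed("seamless"),
    managed("unet") as unet,
    managed("lora"),
):
    print(f"denoise with {unet}")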
invokeai/backend/lora_model_patcher.py (new file)
@@ -0,0 +1,65 @@
+from contextlib import contextmanager
+from typing import Iterator, Tuple, Union
+
+from diffusers.loaders.lora import LoraLoaderMixin
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from diffusers.utils.peft_utils import recurse_remove_peft_layers
+from transformers import CLIPTextModel
+
+from invokeai.backend.lora_model_raw import LoRAModelRaw
+
+
+class LoraModelPatcher:
+    @classmethod
+    def unload_lora_from_model(cls, m: Union[UNet2DConditionModel, CLIPTextModel]):
+        """Unload all LoRA models from a UNet or Text Encoder.
+        This implementation is based on LoraLoaderMixin.unload_lora_weights().
+        """
+        recurse_remove_peft_layers(m)
+        if hasattr(m, "peft_config"):
+            del m.peft_config  # type: ignore
+        if hasattr(m, "_hf_peft_config_loaded"):
+            m._hf_peft_config_loaded = None  # type: ignore
+
+    @classmethod
+    @contextmanager
+    def apply_lora_to_unet(cls, unet: UNet2DConditionModel, loras: Iterator[Tuple[LoRAModelRaw, float]]):
+        try:
+            # TODO(ryand): Test speed of low_cpu_mem_usage=True.
+            for lora, lora_weight in loras:
+                LoraLoaderMixin.load_lora_into_unet(
+                    state_dict=lora.state_dict,
+                    network_alphas=lora.network_alphas,
+                    unet=unet,
+                    low_cpu_mem_usage=True,
+                    adapter_name=lora.name,
+                    _pipeline=None,
+                )
+            yield
+        finally:
+            cls.unload_lora_from_model(unet)
+
+    @classmethod
+    @contextmanager
+    def apply_lora_to_text_encoder(
+        cls, text_encoder: CLIPTextModel, loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str
+    ):
+        assert prefix in ["text_encoder", "text_encoder_2"]
+        try:
+            for lora, lora_weight in loras:
+                # Filter the state_dict to only include the keys that start with the prefix.
+                text_encoder_state_dict = {
+                    key: value for key, value in lora.state_dict.items() if key.startswith(prefix + ".")
+                }
+                if len(text_encoder_state_dict) > 0:
+                    LoraLoaderMixin.load_lora_into_text_encoder(
+                        state_dict=text_encoder_state_dict,
+                        network_alphas=lora.network_alphas,
+                        text_encoder=text_encoder,
+                        low_cpu_mem_usage=True,
+                        adapter_name=lora.name,
+                        _pipeline=None,
+                    )
+            yield
+        finally:
+            cls.unload_lora_from_model(text_encoder)
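Both apply_lora_* methods above are context managers: adapters are loaded on entry and the PEFT layers are stripped again on exit via unload_lora_from_model(). A minimal usage sketch, with a hypothetical checkpoint path (the model ID is the standard CLIP-L encoder); note that the per-LoRA weight in each (lora, weight) tuple is not yet consumed by this WIP patcher:

import torch
from transformers import CLIPTextModel

from invokeai.backend.lora_model_patcher import LoraModelPatcher
from invokeai.backend.lora_model_raw import LoRAModelRaw

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
# Hypothetical checkpoint path, for illustration only.
lora = LoRAModelRaw.from_checkpoint("/models/loras/example.safetensors")

# 0.75 rides along in the (lora, weight) tuple, but as written above,
# apply_lora_to_text_encoder never reads lora_weight.
with LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, iter([(lora, 0.75)]), "text_encoder"):
    pass  # run conditioning against the patched text encoder
# On exit, unload_lora_from_model() has removed the PEFT layers again.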
invokeai/backend/lora_model_raw.py (new file)
@@ -0,0 +1,66 @@
+from pathlib import Path
+from typing import Optional, Union
+
+import torch
+from diffusers.loaders.lora import LoraLoaderMixin
+from typing_extensions import Self
+
+
+class LoRAModelRaw:
+    def __init__(
+        self,
+        name: str,
+        state_dict: dict[str, torch.Tensor],
+        network_alphas: Optional[dict[str, float]],
+    ):
+        self._name = name
+        self.state_dict = state_dict
+        self.network_alphas = network_alphas
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        for key, layer in self.state_dict.items():
+            self.state_dict[key] = layer.to(device=device, dtype=dtype)
+
+    def calc_size(self) -> int:
+        """Calculate the size of the model in bytes."""
+        model_size = 0
+        for layer in self.state_dict.values():
+            model_size += layer.numel() * layer.element_size()
+        return model_size
+
+    @classmethod
+    def from_checkpoint(
+        cls, file_path: Union[str, Path], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
+    ) -> Self:
+        """This function is based on diffusers LoraLoaderMixin.load_lora_weights()."""
+
+        file_path = Path(file_path)
+        if file_path.is_dir():
+            raise NotImplementedError("LoRA models from directories are not yet supported.")
+
+        dir_path = file_path.parent
+        file_name = file_path.name
+
+        state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
+            pretrained_model_name_or_path_or_dict=str(file_path), local_files_only=True, weight_name=str(file_name)
+        )
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        model = cls(
+            # TODO(ryand): Handle both files and directories here?
+            name=Path(file_path).stem,
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+        )
+
+        device = device or torch.device("cpu")
+        dtype = dtype or torch.float32
+        model.to(device=device, dtype=dtype)
+
+        return model
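A short sketch of the lifecycle of the class above, with a hypothetical checkpoint path: from_checkpoint() parses the file via diffusers' LoraLoaderMixin.lora_state_dict(), rejects state dicts whose keys never mention "lora", and moves every tensor to the requested device/dtype (defaulting to CPU/float32):

import torch

from invokeai.backend.lora_model_raw import LoRAModelRaw

# Hypothetical checkpoint path, for illustration only.
lora = LoRAModelRaw.from_checkpoint(
    "/models/loras/detail_tweaker.safetensors",
    device=torch.device("cpu"),
    dtype=torch.float16,
)

print(lora.name)         # file stem, e.g. "detail_tweaker"
print(lora.calc_size())  # total bytes: sum of numel * element_size over all tensors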
@@ -32,7 +32,7 @@ from typing_extensions import Annotated, Any, Dict
 from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.textual_inversion import TextualInversionModelRaw
 
@@ -6,7 +6,7 @@ from pathlib import Path
 from typing import Optional, Tuple
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -51,7 +51,6 @@ class LoRALoader(ModelLoader):
         model = LoRAModelRaw.from_checkpoint(
             file_path=model_path,
             dtype=self._torch_dtype,
-            base_model=self._model_base,
         )
         return model
 
@@ -17,7 +17,7 @@ from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 
-from .lora import LoRAModelRaw
+from .lora_model_raw import LoRAModelRaw
 from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
 
 """
@@ -5,7 +5,7 @@
 import pytest
 import torch
 
-from invokeai.backend.lora import LoRALayer, LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRALayer, LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
 
 