Compare commits

...

3 Commits

14 changed files with 159 additions and 43 deletions

View File

@@ -9,7 +9,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_patcher import LoraModelPatcher
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -80,7 +81,8 @@ class CompelInvocation(BaseInvocation):
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            # ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
         ):
@@ -181,7 +183,8 @@ class SDXLPromptInvocationBase:
             ),
             text_encoder_info as text_encoder,
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            # ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
@@ -259,15 +262,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         c1, c1_pooled, ec1 = self.run_clip_compel(
-            context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
+            context, self.clip, self.prompt, False, "text_encoder", zero_on_empty=True
         )
         if self.style.strip() == "":
             c2, c2_pooled, ec2 = self.run_clip_compel(
-                context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
+                context, self.clip2, self.prompt, True, "text_encoder_2", zero_on_empty=True
             )
         else:
             c2, c2_pooled, ec2 = self.run_clip_compel(
-                context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
+                context, self.clip2, self.style, True, "text_encoder_2", zero_on_empty=True
             )
 
         original_size = (self.original_height, self.original_width)

View File

@@ -52,7 +52,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_patcher import LoraModelPatcher
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
@@ -739,7 +740,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
             set_seamless(unet_info.model, self.unet.seamless_axes),  # FIXME
             unet_info as unet,
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            # ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            LoraModelPatcher.apply_lora_to_unet(unet, _lora_loader()),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
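
Both patched call sites consume _lora_loader(); the patcher's type hints pin its contract down to an iterator of (LoRAModelRaw, weight) pairs. A hypothetical stand-in for illustration only — the real generator resolves models through InvokeAI's model manager, and the path and weight below are placeholders:

from typing import Iterator, Tuple

from invokeai.backend.lora_model_raw import LoRAModelRaw


def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
    # Placeholder inputs; the real loader iterates the invocation's LoRA fields.
    for path, weight in [("/path/to/lora.safetensors", 0.8)]:
        yield (LoRAModelRaw.from_checkpoint(path), weight)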

View File

@@ -9,7 +9,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
 
-from ..raw_model import RawModel
 from .resampler import Resampler
@@ -92,7 +91,7 @@ class MLPProjModel(torch.nn.Module):
         return clip_extra_context_tokens
 
 
-class IPAdapter(RawModel):
+class IPAdapter(torch.nn.Module):
     """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
 
     def __init__(

View File

@@ -0,0 +1,65 @@
+from contextlib import contextmanager
+from typing import Iterator, Tuple, Union
+
+from diffusers.loaders.lora import LoraLoaderMixin
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from diffusers.utils.peft_utils import recurse_remove_peft_layers
+from transformers import CLIPTextModel
+
+from invokeai.backend.lora_model_raw import LoRAModelRaw
+
+
+class LoraModelPatcher:
+    @classmethod
+    def unload_lora_from_model(cls, m: Union[UNet2DConditionModel, CLIPTextModel]):
+        """Unload all LoRA models from a UNet or Text Encoder.
+        This implementation is based on LoraLoaderMixin.unload_lora_weights().
+        """
+        recurse_remove_peft_layers(m)
+        if hasattr(m, "peft_config"):
+            del m.peft_config  # type: ignore
+        if hasattr(m, "_hf_peft_config_loaded"):
+            m._hf_peft_config_loaded = None  # type: ignore
+
+    @classmethod
+    @contextmanager
+    def apply_lora_to_unet(cls, unet: UNet2DConditionModel, loras: Iterator[Tuple[LoRAModelRaw, float]]):
+        try:
+            # TODO(ryand): Test speed of low_cpu_mem_usage=True.
+            for lora, lora_weight in loras:
+                LoraLoaderMixin.load_lora_into_unet(
+                    state_dict=lora.state_dict,
+                    network_alphas=lora.network_alphas,
+                    unet=unet,
+                    low_cpu_mem_usage=True,
+                    adapter_name=lora.name,
+                    _pipeline=None,
+                )
+            yield
+        finally:
+            cls.unload_lora_from_model(unet)
+
+    @classmethod
+    @contextmanager
+    def apply_lora_to_text_encoder(
+        cls, text_encoder: CLIPTextModel, loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str
+    ):
+        assert prefix in ["text_encoder", "text_encoder_2"]
+        try:
+            for lora, lora_weight in loras:
+                # Filter the state_dict to only include the keys that start with the prefix.
+                text_encoder_state_dict = {
+                    key: value for key, value in lora.state_dict.items() if key.startswith(prefix + ".")
+                }
+                if len(text_encoder_state_dict) > 0:
+                    LoraLoaderMixin.load_lora_into_text_encoder(
+                        state_dict=text_encoder_state_dict,
+                        network_alphas=lora.network_alphas,
+                        text_encoder=text_encoder,
+                        low_cpu_mem_usage=True,
+                        adapter_name=lora.name,
+                        _pipeline=None,
+                    )
+            yield
+        finally:
+            cls.unload_lora_from_model(text_encoder)
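
For context, a minimal usage sketch of the new patcher under stated assumptions: the default-config UNet and the checkpoint path are placeholders for illustration. The patch is scoped to the with block and stripped again on exit via unload_lora_from_model():

from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

from invokeai.backend.lora_model_patcher import LoraModelPatcher
from invokeai.backend.lora_model_raw import LoRAModelRaw

unet = UNet2DConditionModel()  # placeholder; normally a loaded SD UNet
lora = LoRAModelRaw.from_checkpoint("/path/to/lora.safetensors")  # placeholder path
with LoraModelPatcher.apply_lora_to_unet(unet, iter([(lora, 0.75)])):
    # Run denoising here; the UNet carries the PEFT LoRA layers inside this block.
    ...
# On exit, recurse_remove_peft_layers() restores the unpatched UNet.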

View File

@@ -0,0 +1,66 @@
+from pathlib import Path
+from typing import Optional, Union
+
+import torch
+from diffusers.loaders.lora import LoraLoaderMixin
+from typing_extensions import Self
+
+
+class LoRAModelRaw:
+    def __init__(
+        self,
+        name: str,
+        state_dict: dict[str, torch.Tensor],
+        network_alphas: Optional[dict[str, float]],
+    ):
+        self._name = name
+        self.state_dict = state_dict
+        self.network_alphas = network_alphas
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        for key, layer in self.state_dict.items():
+            self.state_dict[key] = layer.to(device=device, dtype=dtype)
+
+    def calc_size(self) -> int:
+        """Calculate the size of the model in bytes."""
+        model_size = 0
+        for layer in self.state_dict.values():
+            model_size += layer.numel() * layer.element_size()
+        return model_size
+
+    @classmethod
+    def from_checkpoint(
+        cls, file_path: Union[str, Path], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
+    ) -> Self:
+        """This function is based on diffusers LoraLoaderMixin.load_lora_weights()."""
+        file_path = Path(file_path)
+
+        if file_path.is_dir():
+            raise NotImplementedError("LoRA models from directories are not yet supported.")
+
+        dir_path = file_path.parent
+        file_name = file_path.name
+
+        state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
+            pretrained_model_name_or_path_or_dict=str(file_path), local_files_only=True, weight_name=str(file_name)
+        )
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        model = cls(
+            # TODO(ryand): Handle both files and directories here?
+            name=Path(file_path).stem,
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+        )
+
+        device = device or torch.device("cpu")
+        dtype = dtype or torch.float32
+        model.to(device=device, dtype=dtype)
+
+        return model
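
And a small sketch of loading the raw model directly; the file path is a placeholder and the device/dtype choices are arbitrary:

import torch

from invokeai.backend.lora_model_raw import LoRAModelRaw

lora = LoRAModelRaw.from_checkpoint(
    file_path="/path/to/lora.safetensors",  # placeholder path
    device=torch.device("cpu"),
    dtype=torch.float16,
)
# calc_size() sums numel * element_size over every tensor in the state dict.
print(f"{lora.name}: {lora.calc_size() / 1e6:.1f} MB in {len(lora.state_dict)} tensors")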

View File

@@ -11,8 +11,6 @@ from typing_extensions import Self
 from invokeai.backend.model_manager import BaseModelType
 
-from .raw_model import RawModel
 
 
 class LoRALayerBase:
     # rank: Optional[int]
@@ -368,7 +366,7 @@ class IA3Layer(LoRALayerBase):
 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
 
 
-class LoRAModelRaw(RawModel):  # (torch.nn.Module):
+class LoRAModelRaw(torch.nn.Module):
     _name: str
     layers: Dict[str, AnyLoRALayer]

View File

@@ -31,12 +31,13 @@ from typing_extensions import Annotated, Any, Dict
 from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
 from invokeai.app.util.misc import uuid_string
-from ..raw_model import RawModel
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.lora_model_raw import LoRAModelRaw
+from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
 
 # ModelMixin is the base class for all diffusers and transformers models
-# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
+AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, LoRAModelRaw, TextualInversionModelRaw, IAIOnnxRuntimeModel]
 
 
 class InvalidModelConfigException(Exception):

View File

@@ -6,7 +6,7 @@ from pathlib import Path
 from typing import Optional, Tuple
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -51,7 +51,6 @@ class LoRALoader(ModelLoader):
         model = LoRAModelRaw.from_checkpoint(
             file_path=model_path,
             dtype=self._torch_dtype,
-            base_model=self._model_base,
         )
         return model

View File

@@ -17,7 +17,7 @@ from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 
-from .lora import LoRAModelRaw
+from .lora_model_raw import LoRAModelRaw
 from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
 
 """

View File

@@ -6,17 +6,16 @@ from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import onnx
 import torch
 from onnx import numpy_helper
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
 
-from ..raw_model import RawModel
 
 ONNX_WEIGHTS_NAME = "model.onnx"
 
 
 # NOTE FROM LS: This was copied from Stalker's original implementation.
 # I have not yet gone through and fixed all the type hints
-class IAIOnnxRuntimeModel(RawModel):
+class IAIOnnxRuntimeModel(torch.nn.Module):
     class _tensor_access:
         def __init__(self, model):  # type: ignore
             self.model = model

View File

@@ -1,15 +0,0 @@
-"""Base class for 'Raw' models.
-
-The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
-and is used for type checking of calls to the model patcher. Its main purpose
-is to avoid circular import issues when lora.py tries to import BaseModelType
-from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
-from lora.py.
-
-The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
-that adds additional methods and attributes.
-"""
-
-
-class RawModel:
-    """Base class for 'Raw' model wrappers."""

View File

@@ -9,10 +9,8 @@ from safetensors.torch import load_file
 from transformers import CLIPTokenizer
 from typing_extensions import Self
 
-from .raw_model import RawModel
 
 
-class TextualInversionModelRaw(RawModel):
+class TextualInversionModelRaw(torch.nn.Module):
     embedding: torch.Tensor  # [n, 768]|[n, 1280]
     embedding_2: Optional[torch.Tensor] = None  # [n, 768]|[n, 1280] - for SDXL models

View File

@@ -44,6 +44,7 @@ dependencies = [
     "onnx==1.15.0",
     "onnxruntime==1.16.3",
     "opencv-python==4.9.0.80",
+    "peft==0.9.0",
     "pytorch-lightning==2.1.3",
     "safetensors==0.4.2",
     "timm==0.6.13",  # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
@@ -73,7 +74,7 @@ dependencies = [
     "easing-functions",
     "einops",
     "facexlib",
-    "matplotlib",  # needed for plotting of Penner easing functions
+    "matplotlib",  # needed for plotting of Penner easing functions
     "npyscreen",
     "omegaconf",
     "picklescan",

View File

@@ -5,7 +5,7 @@
 import pytest
 import torch
 
-from invokeai.backend.lora import LoRALayer, LoRAModelRaw
+from invokeai.backend.lora_model_raw import LoRALayer, LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher