mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Make lora as separate extensions
This commit is contained in:
parent
46c632e7cc
commit
faa88f72bf
@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
# apply all patches while the model is on the target device
|
# apply all patches while the model is on the target device
|
||||||
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
|
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||||
tokenizer_info as tokenizer,
|
tokenizer_info as tokenizer,
|
||||||
ModelPatcher.apply_lora_text_encoder(
|
ModelPatcher.apply_lora_text_encoder(
|
||||||
text_encoder,
|
text_encoder,
|
||||||
loras=_lora_loader(),
|
loras=_lora_loader(),
|
||||||
model_state_dict=model_state_dict,
|
cached_weights=cached_weights,
|
||||||
),
|
),
|
||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||||
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
||||||
@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
# apply all patches while the model is on the target device
|
# apply all patches while the model is on the target device
|
||||||
text_encoder_info.model_on_device() as (state_dict, text_encoder),
|
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||||
tokenizer_info as tokenizer,
|
tokenizer_info as tokenizer,
|
||||||
ModelPatcher.apply_lora(
|
ModelPatcher.apply_lora(
|
||||||
text_encoder,
|
text_encoder,
|
||||||
loras=_lora_loader(),
|
loras=_lora_loader(),
|
||||||
prefix=lora_prefix,
|
prefix=lora_prefix,
|
||||||
model_state_dict=state_dict,
|
cached_weights=cached_weights,
|
||||||
),
|
),
|
||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||||
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
||||||
|
@ -60,7 +60,7 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
|
|||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt
|
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
@ -836,13 +836,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
### lora
|
### lora
|
||||||
if self.unet.loras:
|
if self.unet.loras:
|
||||||
ext_manager.add_extension(
|
for lora_field in self.unet.loras:
|
||||||
LoRAPatcherExt(
|
ext_manager.add_extension(
|
||||||
node_context=context,
|
LoRAExt(
|
||||||
loras=self.unet.loras,
|
node_context=context,
|
||||||
prefix="lora_unet_",
|
model_id=lora_field.lora,
|
||||||
|
weight=lora_field.weight,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# context for loading additional models
|
# context for loading additional models
|
||||||
with ExitStack() as exit_stack:
|
with ExitStack() as exit_stack:
|
||||||
@ -924,14 +925,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
with (
|
with (
|
||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
unet_info.model_on_device() as (model_state_dict, unet),
|
unet_info.model_on_device() as (cached_weights, unet),
|
||||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||||
ModelPatcher.apply_lora_unet(
|
ModelPatcher.apply_lora_unet(
|
||||||
unet,
|
unet,
|
||||||
loras=_lora_loader(),
|
loras=_lora_loader(),
|
||||||
model_state_dict=model_state_dict,
|
cached_weights=cached_weights,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
assert isinstance(unet, UNet2DConditionModel)
|
assert isinstance(unet, UNet2DConditionModel)
|
||||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -17,8 +17,8 @@ from invokeai.backend.lora import LoRAModelRaw
|
|||||||
from invokeai.backend.model_manager import AnyModel
|
from invokeai.backend.model_manager import AnyModel
|
||||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
loras = [
|
loras = [
|
||||||
@ -85,13 +85,13 @@ class ModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
with cls.apply_lora(
|
with cls.apply_lora(
|
||||||
unet,
|
unet,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
prefix="lora_unet_",
|
prefix="lora_unet_",
|
||||||
model_state_dict=model_state_dict,
|
cached_weights=cached_weights,
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@ -101,9 +101,9 @@ class ModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
|
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -113,7 +113,7 @@ class ModelPatcher:
|
|||||||
model: AnyModel,
|
model: AnyModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
"""
|
"""
|
||||||
Apply one or more LoRAs to a model.
|
Apply one or more LoRAs to a model.
|
||||||
@ -121,66 +121,37 @@ class ModelPatcher:
|
|||||||
:param model: The model to patch.
|
:param model: The model to patch.
|
||||||
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
||||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||||
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||||
"""
|
"""
|
||||||
original_weights = {}
|
modified_cached_weights: Set[str] = set()
|
||||||
|
modified_weights: Dict[str, torch.Tensor] = {}
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
for lora_model, lora_weight in loras:
|
||||||
for lora, lora_weight in loras:
|
lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
|
||||||
# assert lora.device.type == "cpu"
|
model=model,
|
||||||
for layer_key, layer in lora.layers.items():
|
prefix=prefix,
|
||||||
if not layer_key.startswith(prefix):
|
lora=lora_model,
|
||||||
continue
|
lora_weight=lora_weight,
|
||||||
|
cached_weights=cached_weights,
|
||||||
|
)
|
||||||
|
del lora_model
|
||||||
|
|
||||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
modified_cached_weights.update(lora_modified_cached_weights)
|
||||||
# should be improved in the following ways:
|
# Store only first returned weight for each key, because
|
||||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
# next extension which changes it, will work with already modified weight
|
||||||
# LoRA model is applied.
|
for param_key, weight in lora_modified_weights.items():
|
||||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
if param_key in modified_weights:
|
||||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
continue
|
||||||
# weights to have valid keys.
|
modified_weights[param_key] = weight
|
||||||
assert isinstance(model, torch.nn.Module)
|
|
||||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
|
||||||
|
|
||||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
yield
|
||||||
# (Performance will be best if this is a CUDA device.)
|
|
||||||
device = module.weight.device
|
|
||||||
dtype = module.weight.dtype
|
|
||||||
|
|
||||||
if module_key not in original_weights:
|
|
||||||
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
|
|
||||||
original_weights[module_key] = model_state_dict[module_key + ".weight"]
|
|
||||||
else:
|
|
||||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
|
||||||
|
|
||||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
|
||||||
|
|
||||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
|
||||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
|
||||||
# same thing in a single call to '.to(...)'.
|
|
||||||
layer.to(device=device)
|
|
||||||
layer.to(dtype=torch.float32)
|
|
||||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
|
||||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
|
||||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
|
||||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
|
||||||
|
|
||||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
|
||||||
if module.weight.shape != layer_weight.shape:
|
|
||||||
# TODO: debug on lycoris
|
|
||||||
assert hasattr(layer_weight, "reshape")
|
|
||||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
|
||||||
|
|
||||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
|
||||||
module.weight += layer_weight.to(dtype=dtype)
|
|
||||||
|
|
||||||
yield # wait for context manager exit
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for module_key, weight in original_weights.items():
|
for param_key in modified_cached_weights:
|
||||||
model.get_submodule(module_key).weight.copy_(weight)
|
model.get_parameter(param_key).copy_(cached_weights[param_key])
|
||||||
|
for param_key, weight in modified_weights.items():
|
||||||
|
model.get_parameter(param_key).copy_(weight)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
145
invokeai/backend/stable_diffusion/extensions/lora.py
Normal file
145
invokeai/backend/stable_diffusion/extensions/lora.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||||
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.invocations.model import ModelIdentifierField
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAExt(ExtensionBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
node_context: InvocationContext,
|
||||||
|
model_id: ModelIdentifierField,
|
||||||
|
weight: float,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self._node_context = node_context
|
||||||
|
self._model_id = model_id
|
||||||
|
self._weight = weight
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||||
|
lora_model = self._node_context.models.load(self._model_id).model
|
||||||
|
modified_cached_weights, modified_weights = self.patch_model(
|
||||||
|
model=unet,
|
||||||
|
prefix="lora_unet_",
|
||||||
|
lora=lora_model,
|
||||||
|
lora_weight=self._weight,
|
||||||
|
cached_weights=cached_weights,
|
||||||
|
)
|
||||||
|
del lora_model
|
||||||
|
|
||||||
|
yield modified_cached_weights, modified_weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def patch_model(
|
||||||
|
cls,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
prefix: str,
|
||||||
|
lora: LoRAModelRaw,
|
||||||
|
lora_weight: float,
|
||||||
|
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Apply one or more LoRAs to a model.
|
||||||
|
:param model: The model to patch.
|
||||||
|
:param lora: LoRA model to patch in.
|
||||||
|
:param lora_weight: LoRA patch weight.
|
||||||
|
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||||
|
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||||
|
"""
|
||||||
|
if cached_weights is None:
|
||||||
|
cached_weights = {}
|
||||||
|
|
||||||
|
modified_weights: Dict[str, torch.Tensor] = {}
|
||||||
|
modified_cached_weights: Set[str] = set()
|
||||||
|
with torch.no_grad():
|
||||||
|
# assert lora.device.type == "cpu"
|
||||||
|
for layer_key, layer in lora.layers.items():
|
||||||
|
if not layer_key.startswith(prefix):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||||
|
# should be improved in the following ways:
|
||||||
|
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||||
|
# LoRA model is applied.
|
||||||
|
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
||||||
|
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||||
|
# weights to have valid keys.
|
||||||
|
assert isinstance(model, torch.nn.Module)
|
||||||
|
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||||
|
|
||||||
|
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||||
|
# (Performance will be best if this is a CUDA device.)
|
||||||
|
device = module.weight.device
|
||||||
|
dtype = module.weight.dtype
|
||||||
|
|
||||||
|
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||||
|
|
||||||
|
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||||
|
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||||
|
# same thing in a single call to '.to(...)'.
|
||||||
|
layer.to(device=device)
|
||||||
|
layer.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||||
|
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||||
|
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||||
|
param_key = module_key + "." + param_name
|
||||||
|
module_param = module.get_parameter(param_name)
|
||||||
|
|
||||||
|
# save original weight
|
||||||
|
if param_key not in modified_cached_weights and param_key not in modified_weights:
|
||||||
|
if param_key in cached_weights:
|
||||||
|
modified_cached_weights.add(param_key)
|
||||||
|
else:
|
||||||
|
modified_weights[param_key] = module_param.detach().to(
|
||||||
|
device=TorchDevice.CPU_DEVICE, copy=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if module_param.shape != lora_param_weight.shape:
|
||||||
|
# TODO: debug on lycoris
|
||||||
|
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||||
|
|
||||||
|
lora_param_weight *= lora_weight * layer_scale
|
||||||
|
module_param += lora_param_weight.to(dtype=dtype)
|
||||||
|
|
||||||
|
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||||
|
|
||||||
|
return modified_cached_weights, modified_weights
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||||
|
assert "." not in lora_key
|
||||||
|
|
||||||
|
if not lora_key.startswith(prefix):
|
||||||
|
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||||
|
|
||||||
|
module = model
|
||||||
|
module_key = ""
|
||||||
|
key_parts = lora_key[len(prefix) :].split("_")
|
||||||
|
|
||||||
|
submodule_name = key_parts.pop(0)
|
||||||
|
|
||||||
|
while len(key_parts) > 0:
|
||||||
|
try:
|
||||||
|
module = module.get_submodule(submodule_name)
|
||||||
|
module_key += "." + submodule_name
|
||||||
|
submodule_name = key_parts.pop(0)
|
||||||
|
except Exception:
|
||||||
|
submodule_name += "_" + key_parts.pop(0)
|
||||||
|
|
||||||
|
module = module.get_submodule(submodule_name)
|
||||||
|
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||||
|
|
||||||
|
return (module_key, module)
|
@ -1,172 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from diffusers import UNet2DConditionModel
|
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.invocations.model import LoRAField
|
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAPatcherExt(ExtensionBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
node_context: InvocationContext,
|
|
||||||
loras: List[LoRAField],
|
|
||||||
prefix: str,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._loras = loras
|
|
||||||
self._prefix = prefix
|
|
||||||
self._node_context = node_context
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
|
||||||
for lora in self._loras:
|
|
||||||
lora_info = self._node_context.models.load(lora.lora)
|
|
||||||
lora_model = lora_info.model
|
|
||||||
yield (lora_model, lora.weight)
|
|
||||||
del lora_info
|
|
||||||
return
|
|
||||||
|
|
||||||
yield self._patch_model(
|
|
||||||
model=unet,
|
|
||||||
prefix=self._prefix,
|
|
||||||
loras=_lora_loader(),
|
|
||||||
cached_weights=cached_weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def static_patch_model(
|
|
||||||
cls,
|
|
||||||
model: torch.nn.Module,
|
|
||||||
prefix: str,
|
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
|
||||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
|
||||||
):
|
|
||||||
modified_cached_weights, modified_weights = cls._patch_model(
|
|
||||||
model=model,
|
|
||||||
prefix=prefix,
|
|
||||||
loras=loras,
|
|
||||||
cached_weights=cached_weights,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
with torch.no_grad():
|
|
||||||
for param_key in modified_cached_weights:
|
|
||||||
model.get_parameter(param_key).copy_(cached_weights[param_key])
|
|
||||||
for param_key, weight in modified_weights.items():
|
|
||||||
model.get_parameter(param_key).copy_(weight)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _patch_model(
|
|
||||||
cls,
|
|
||||||
model: UNet2DConditionModel,
|
|
||||||
prefix: str,
|
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
|
||||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Apply one or more LoRAs to a model.
|
|
||||||
:param model: The model to patch.
|
|
||||||
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
|
||||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
|
||||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
|
||||||
"""
|
|
||||||
if cached_weights is None:
|
|
||||||
cached_weights = {}
|
|
||||||
|
|
||||||
modified_weights = {}
|
|
||||||
modified_cached_weights = set()
|
|
||||||
with torch.no_grad():
|
|
||||||
for lora, lora_weight in loras:
|
|
||||||
# assert lora.device.type == "cpu"
|
|
||||||
for layer_key, layer in lora.layers.items():
|
|
||||||
if not layer_key.startswith(prefix):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
|
||||||
# should be improved in the following ways:
|
|
||||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
|
||||||
# LoRA model is applied.
|
|
||||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
|
||||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
|
||||||
# weights to have valid keys.
|
|
||||||
assert isinstance(model, torch.nn.Module)
|
|
||||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
|
||||||
|
|
||||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
|
||||||
# (Performance will be best if this is a CUDA device.)
|
|
||||||
device = module.weight.device
|
|
||||||
dtype = module.weight.dtype
|
|
||||||
|
|
||||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
|
||||||
|
|
||||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
|
||||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
|
||||||
# same thing in a single call to '.to(...)'.
|
|
||||||
layer.to(device=device)
|
|
||||||
layer.to(dtype=torch.float32)
|
|
||||||
|
|
||||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
|
||||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
|
||||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
|
||||||
param_key = module_key + "." + param_name
|
|
||||||
module_param = module.get_parameter(param_name)
|
|
||||||
|
|
||||||
# save original weight
|
|
||||||
if param_key not in modified_cached_weights and param_key not in modified_weights:
|
|
||||||
if param_key in cached_weights:
|
|
||||||
modified_cached_weights.add(param_key)
|
|
||||||
else:
|
|
||||||
modified_weights[param_key] = module_param.detach().to(
|
|
||||||
device=TorchDevice.CPU_DEVICE, copy=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if module_param.shape != lora_param_weight.shape:
|
|
||||||
# TODO: debug on lycoris
|
|
||||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
|
||||||
|
|
||||||
lora_param_weight *= lora_weight * layer_scale
|
|
||||||
module_param += lora_param_weight.to(dtype=dtype)
|
|
||||||
|
|
||||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
|
||||||
|
|
||||||
return modified_cached_weights, modified_weights
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
|
||||||
assert "." not in lora_key
|
|
||||||
|
|
||||||
if not lora_key.startswith(prefix):
|
|
||||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
|
||||||
|
|
||||||
module = model
|
|
||||||
module_key = ""
|
|
||||||
key_parts = lora_key[len(prefix) :].split("_")
|
|
||||||
|
|
||||||
submodule_name = key_parts.pop(0)
|
|
||||||
|
|
||||||
while len(key_parts) > 0:
|
|
||||||
try:
|
|
||||||
module = module.get_submodule(submodule_name)
|
|
||||||
module_key += "." + submodule_name
|
|
||||||
submodule_name = key_parts.pop(0)
|
|
||||||
except Exception:
|
|
||||||
submodule_name += "_" + key_parts.pop(0)
|
|
||||||
|
|
||||||
module = module.get_submodule(submodule_name)
|
|
||||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
|
||||||
|
|
||||||
return (module_key, module)
|
|
Loading…
Reference in New Issue
Block a user