mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Make lora as separate extensions
This commit is contained in:
parent
46c632e7cc
commit
faa88f72bf
@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
model_state_dict=model_state_dict,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
||||
@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (state_dict, text_encoder),
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
prefix=lora_prefix,
|
||||
model_state_dict=state_dict,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
||||
|
@ -60,7 +60,7 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||
from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
@ -836,13 +836,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
### lora
|
||||
if self.unet.loras:
|
||||
ext_manager.add_extension(
|
||||
LoRAPatcherExt(
|
||||
node_context=context,
|
||||
loras=self.unet.loras,
|
||||
prefix="lora_unet_",
|
||||
for lora_field in self.unet.loras:
|
||||
ext_manager.add_extension(
|
||||
LoRAExt(
|
||||
node_context=context,
|
||||
model_id=lora_field.lora,
|
||||
weight=lora_field.weight,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# context for loading additional models
|
||||
with ExitStack() as exit_stack:
|
||||
@ -924,14 +925,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
unet_info.model_on_device() as (model_state_dict, unet),
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(
|
||||
unet,
|
||||
loras=_lora_loader(),
|
||||
model_state_dict=model_state_dict,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -17,8 +17,8 @@ from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
"""
|
||||
loras = [
|
||||
@ -85,13 +85,13 @@ class ModelPatcher:
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(
|
||||
unet,
|
||||
loras=loras,
|
||||
prefix="lora_unet_",
|
||||
model_state_dict=model_state_dict,
|
||||
cached_weights=cached_weights,
|
||||
):
|
||||
yield
|
||||
|
||||
@ -101,9 +101,9 @@ class ModelPatcher:
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
|
||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@ -113,7 +113,7 @@ class ModelPatcher:
|
||||
model: AnyModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
@ -121,66 +121,37 @@ class ModelPatcher:
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
"""
|
||||
original_weights = {}
|
||||
modified_cached_weights: Set[str] = set()
|
||||
modified_weights: Dict[str, torch.Tensor] = {}
|
||||
try:
|
||||
with torch.no_grad():
|
||||
for lora, lora_weight in loras:
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
for lora_model, lora_weight in loras:
|
||||
lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
lora=lora_model,
|
||||
lora_weight=lora_weight,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||
# should be improved in the following ways:
|
||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||
# LoRA model is applied.
|
||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||
# weights to have valid keys.
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
modified_cached_weights.update(lora_modified_cached_weights)
|
||||
# Store only first returned weight for each key, because
|
||||
# next extension which changes it, will work with already modified weight
|
||||
for param_key, weight in lora_modified_weights.items():
|
||||
if param_key in modified_weights:
|
||||
continue
|
||||
modified_weights[param_key] = weight
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
if module_key not in original_weights:
|
||||
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
|
||||
original_weights[module_key] = model_state_dict[module_key + ".weight"]
|
||||
else:
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
assert hasattr(layer_weight, "reshape")
|
||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
module.weight += layer_weight.to(dtype=dtype)
|
||||
|
||||
yield # wait for context manager exit
|
||||
yield
|
||||
|
||||
finally:
|
||||
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||
with torch.no_grad():
|
||||
for module_key, weight in original_weights.items():
|
||||
model.get_submodule(module_key).weight.copy_(weight)
|
||||
for param_key in modified_cached_weights:
|
||||
model.get_parameter(param_key).copy_(cached_weights[param_key])
|
||||
for param_key, weight in modified_weights.items():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
|
145
invokeai/backend/stable_diffusion/extensions/lora.py
Normal file
145
invokeai/backend/stable_diffusion/extensions/lora.py
Normal file
@ -0,0 +1,145 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
|
||||
|
||||
class LoRAExt(ExtensionBase):
|
||||
def __init__(
|
||||
self,
|
||||
node_context: InvocationContext,
|
||||
model_id: ModelIdentifierField,
|
||||
weight: float,
|
||||
):
|
||||
super().__init__()
|
||||
self._node_context = node_context
|
||||
self._model_id = model_id
|
||||
self._weight = weight
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
lora_model = self._node_context.models.load(self._model_id).model
|
||||
modified_cached_weights, modified_weights = self.patch_model(
|
||||
model=unet,
|
||||
prefix="lora_unet_",
|
||||
lora=lora_model,
|
||||
lora_weight=self._weight,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
yield modified_cached_weights, modified_weights
|
||||
|
||||
@classmethod
|
||||
def patch_model(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
lora: LoRAModelRaw,
|
||||
lora_weight: float,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
:param model: The model to patch.
|
||||
:param lora: LoRA model to patch in.
|
||||
:param lora_weight: LoRA patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
"""
|
||||
if cached_weights is None:
|
||||
cached_weights = {}
|
||||
|
||||
modified_weights: Dict[str, torch.Tensor] = {}
|
||||
modified_cached_weights: Set[str] = set()
|
||||
with torch.no_grad():
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||
# should be improved in the following ways:
|
||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||
# LoRA model is applied.
|
||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||
# weights to have valid keys.
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = module_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
|
||||
# save original weight
|
||||
if param_key not in modified_cached_weights and param_key not in modified_weights:
|
||||
if param_key in cached_weights:
|
||||
modified_cached_weights.add(param_key)
|
||||
else:
|
||||
modified_weights[param_key] = module_param.detach().to(
|
||||
device=TorchDevice.CPU_DEVICE, copy=True
|
||||
)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||
|
||||
lora_param_weight *= lora_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
return modified_cached_weights, modified_weights
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix) :].split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
@ -1,172 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.model import LoRAField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
|
||||
|
||||
class LoRAPatcherExt(ExtensionBase):
|
||||
def __init__(
|
||||
self,
|
||||
node_context: InvocationContext,
|
||||
loras: List[LoRAField],
|
||||
prefix: str,
|
||||
):
|
||||
super().__init__()
|
||||
self._loras = loras
|
||||
self._prefix = prefix
|
||||
self._node_context = node_context
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self._loras:
|
||||
lora_info = self._node_context.models.load(lora.lora)
|
||||
lora_model = lora_info.model
|
||||
yield (lora_model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
yield self._patch_model(
|
||||
model=unet,
|
||||
prefix=self._prefix,
|
||||
loras=_lora_loader(),
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def static_patch_model(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
modified_cached_weights, modified_weights = cls._patch_model(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
loras=loras,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
try:
|
||||
yield
|
||||
|
||||
finally:
|
||||
with torch.no_grad():
|
||||
for param_key in modified_cached_weights:
|
||||
model.get_parameter(param_key).copy_(cached_weights[param_key])
|
||||
for param_key, weight in modified_weights.items():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@classmethod
|
||||
def _patch_model(
|
||||
cls,
|
||||
model: UNet2DConditionModel,
|
||||
prefix: str,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
"""
|
||||
if cached_weights is None:
|
||||
cached_weights = {}
|
||||
|
||||
modified_weights = {}
|
||||
modified_cached_weights = set()
|
||||
with torch.no_grad():
|
||||
for lora, lora_weight in loras:
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||
# should be improved in the following ways:
|
||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||
# LoRA model is applied.
|
||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||
# weights to have valid keys.
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = module_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
|
||||
# save original weight
|
||||
if param_key not in modified_cached_weights and param_key not in modified_weights:
|
||||
if param_key in cached_weights:
|
||||
modified_cached_weights.add(param_key)
|
||||
else:
|
||||
modified_weights[param_key] = module_param.detach().to(
|
||||
device=TorchDevice.CPU_DEVICE, copy=True
|
||||
)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||
|
||||
lora_param_weight *= lora_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
return modified_cached_weights, modified_weights
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix) :].split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
Loading…
Reference in New Issue
Block a user