mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
Consolidate all LoRA patching logic in the LoRAPatcher.
This commit is contained in:
@ -20,6 +20,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
@ -81,9 +82,10 @@ class CompelInvocation(BaseInvocation):
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
model=text_encoder,
|
||||
patches=_lora_loader(),
|
||||
prefix="lora_te_",
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
@ -176,9 +178,9 @@ class SDXLPromptInvocationBase:
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
patches=_lora_loader(),
|
||||
prefix=lora_prefix,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
|
@ -37,6 +37,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||
@ -979,9 +980,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(
|
||||
unet,
|
||||
loras=_lora_loader(),
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
model=unet,
|
||||
patches=_lora_loader(),
|
||||
prefix="lora_unet_",
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
|
@ -23,7 +23,7 @@ from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||
MultiDiffusionPipeline,
|
||||
@ -204,7 +204,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
|
||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
unet_info as unet,
|
||||
LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
|
@ -1,5 +1,5 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Iterator, Optional, Tuple
|
||||
from typing import Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -9,17 +9,16 @@ from invokeai.backend.util.original_weights_storage import OriginalWeightsStorag
|
||||
|
||||
|
||||
class LoRAPatcher:
|
||||
@classmethod
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_lora_patches(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
patches: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
patches: Iterable[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Apply one or more LoRA patches to a model.
|
||||
"""Apply one or more LoRA patches to a model within a context manager.
|
||||
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
|
||||
@ -30,23 +29,23 @@ class LoRAPatcher:
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
cls._apply_lora_patch(
|
||||
LoRAPatcher.apply_lora_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del patch
|
||||
|
||||
yield
|
||||
finally:
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _apply_lora_patch(
|
||||
cls,
|
||||
def apply_lora_patch(
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
patch: LoRAModelRaw,
|
||||
@ -54,7 +53,7 @@ class LoRAPatcher:
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
"""
|
||||
Apply one a LoRA to a model.
|
||||
Apply a single LoRA patch to a model.
|
||||
:param model: The model to patch.
|
||||
:param patch: LoRA model to patch in.
|
||||
:param patch_weight: LoRA patch weight.
|
||||
@ -65,11 +64,21 @@ class LoRAPatcher:
|
||||
if patch_weight == 0:
|
||||
return
|
||||
|
||||
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
|
||||
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
|
||||
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
|
||||
# without searching, but some legacy code still uses flattened keys.
|
||||
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
|
||||
|
||||
prefix_len = len(prefix)
|
||||
|
||||
for layer_key, layer in patch.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module = model.get_submodule(layer_key)
|
||||
module_key, module = LoRAPatcher._get_submodule(
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
@ -87,7 +96,7 @@ class LoRAPatcher:
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = layer_key + "." + param_name
|
||||
param_key = module_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
|
||||
# Save original weight
|
||||
@ -100,3 +109,40 @@ class LoRAPatcher:
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
@staticmethod
|
||||
def _get_submodule(
|
||||
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
|
||||
) -> tuple[str, torch.nn.Module]:
|
||||
"""Get the submodule corresponding to the given layer key.
|
||||
:param model: The model to search.
|
||||
:param layer_key: The layer key to search for.
|
||||
:param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
|
||||
with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
|
||||
searching, but some legacy code still uses flattened keys.
|
||||
:return: A tuple containing the module key and the submodule.
|
||||
"""
|
||||
if not layer_key_is_flattened:
|
||||
return layer_key, model.get_submodule(layer_key)
|
||||
|
||||
# Handle flattened keys.
|
||||
assert "." not in layer_key
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = layer_key.split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return module_key, module
|
||||
|
@ -5,32 +5,18 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
|
||||
from diffusers import UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
"""
|
||||
loras = [
|
||||
(lora_model1, 0.7),
|
||||
(lora_model2, 0.4),
|
||||
]
|
||||
with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
# unet with applied loras
|
||||
# unmodified unet
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
@ -54,95 +40,6 @@ class ModelPatcher:
|
||||
finally:
|
||||
unet.set_attn_processor(unet_orig_processors)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix) :].split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(
|
||||
unet,
|
||||
loras=loras,
|
||||
prefix="lora_unet_",
|
||||
cached_weights=cached_weights,
|
||||
):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: AnyModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
"""
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
for lora_model, lora_weight in loras:
|
||||
LoRAExt.patch_model(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
lora=lora_model,
|
||||
lora_weight=lora_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
with torch.no_grad():
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_ti(
|
||||
@ -282,26 +179,6 @@ class ModelPatcher:
|
||||
|
||||
|
||||
class ONNXModelPatcher:
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: OnnxRuntimeModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
# based on
|
||||
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
|
||||
@classmethod
|
||||
|
@ -1,14 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
@ -31,107 +30,14 @@ class LoRAExt(ExtensionBase):
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
|
||||
lora_model = self._node_context.models.load(self._model_id).model
|
||||
self.patch_model(
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
LoRAPatcher.apply_lora_patch(
|
||||
model=unet,
|
||||
prefix="lora_unet_",
|
||||
lora=lora_model,
|
||||
lora_weight=self._weight,
|
||||
patch=lora_model,
|
||||
patch_weight=self._weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
def patch_model(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
lora: LoRAModelRaw,
|
||||
lora_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
:param model: The model to patch.
|
||||
:param lora: LoRA model to patch in.
|
||||
:param lora_weight: LoRA patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
|
||||
"""
|
||||
|
||||
if lora_weight == 0:
|
||||
return
|
||||
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||
# should be improved in the following ways:
|
||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||
# LoRA model is applied.
|
||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||
# weights to have valid keys.
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = module_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
|
||||
# save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||
|
||||
lora_param_weight *= lora_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix) :].split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
||||
|
@ -3,7 +3,7 @@ import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -41,7 +41,7 @@ def test_apply_lora(device: str):
|
||||
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
|
||||
expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)
|
||||
|
||||
with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
|
||||
with LoRAPatcher.apply_lora_patches(model=model, patches=[(lora, lora_weight)], prefix=""):
|
||||
# After patching, all LoRA layer weights should have been moved back to the cpu.
|
||||
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
|
||||
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
|
||||
@ -83,7 +83,7 @@ def test_apply_lora_change_device():
|
||||
|
||||
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
|
||||
|
||||
with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
|
||||
with LoRAPatcher.apply_lora_patches(model=model, patches=[(lora, 0.5)], prefix=""):
|
||||
# After patching, all LoRA layer weights should have been moved back to the cpu.
|
||||
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
|
||||
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
|
Reference in New Issue
Block a user