Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Improve LoRA patching speed (#5017)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [x] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Have you discussed this change with the InvokeAI team?

- [x] Yes
- [ ] No, because:

## Have you updated all relevant documentation?

- [x] Yes
- [ ] No

## Description

Improve LoRA patching speed with the following changes:

- Calculate LoRA layer weights on the same device as the target model. Prior to this change, weights were always calculated on the CPU. If the target model is on the GPU, this significantly improves performance.
- Move models to their target devices _before_ applying LoRA patches.
- Improve the ordering of Tensor copy / cast operations.

A minimal sketch of the resulting patching pattern is shown below.

## QA Instructions, Screenshots, Recordings

Tests:

- [x] Tested with a CUDA GPU; saw savings of ~10 secs with 1 LoRA applied to an SDXL model.
- [x] No regression in a CPU-only environment.
- [ ] No regression (and possible improvement?) on Mac with MPS.
- [x] Weights get restored correctly after using a LoRA.
- [x] Stacking multiple LoRAs.

Please hammer away with a variety of LoRAs in case there is some edge case that I've missed.

## Added/updated tests?

- [x] Yes (Added some minimal unit tests. Definitely would benefit from more, but it's a step in the right direction.)
- [ ] No
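The following is a minimal, hypothetical sketch of the patching pattern described above (compute the LoRA delta on the module's own device, and move the tensor to the device before casting its dtype). It is illustrative only: the names `patch_module_weight`, `lora_up`, and `lora_down` are not part of the PR, and the real logic lives in `ModelPatcher.apply_lora(...)` in the diff below.

```python
import torch


def patch_module_weight(
    module: torch.nn.Linear,
    lora_up: torch.Tensor,    # shape: (out_features, rank)
    lora_down: torch.Tensor,  # shape: (rank, in_features)
    lora_weight: float,
) -> torch.Tensor:
    """Hypothetical sketch: patch one module in place and return a CPU copy of the original weight."""
    # Do the LoRA weight calculation on the same device as the target module.
    # This is fast when the module has already been moved to a CUDA device.
    device = module.weight.device
    dtype = module.weight.dtype

    # Keep a CPU copy of the original weight so it can be restored when unpatching.
    original_weight = module.weight.detach().to(device="cpu", copy=True)

    with torch.no_grad():
        # Move to the target device first, then cast. Per the PR, this ordering was measured to be
        # faster for 16-bit CPU tensors headed to CUDA than a single combined .to(device, dtype) call.
        lora_up = lora_up.to(device=device).to(dtype=torch.float32)
        lora_down = lora_down.to(device=device).to(dtype=torch.float32)

        layer_weight = (lora_up @ lora_down) * lora_weight
        module.weight += layer_weight.to(dtype=dtype)

    return original_weight
```

Unpatching then amounts to copying `original_weight` back into `module.weight`, which is what the restored-weights tests below verify.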
This commit is contained in: commit a4a7b601a1
@@ -108,13 +108,14 @@ class CompelInvocation(BaseInvocation):
                 print(f'Warn: trigger: "{trigger}" not found')
 
         with (
-            ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
             ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
             ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
             text_encoder_info as text_encoder,
+            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -229,13 +230,14 @@ class SDXLPromptInvocationBase:
                 print(f'Warn: trigger: "{trigger}" not found')
 
         with (
-            ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
             ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
             ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
             text_encoder_info as text_encoder,
+            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -710,9 +710,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         )
         with (
             ExitStack() as exit_stack,
-            ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
             set_seamless(unet_info.context.model, self.unet.seamless_axes),
             unet_info as unet,
+            # Apply the LoRA after unet has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
         ):
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
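A note on the ordering change in the three `with (...)` blocks above: the context managers in a parenthesised `with` statement are entered top to bottom, so listing `ModelPatcher.apply_lora*` after `text_encoder_info as text_encoder` / `unet_info as unet` guarantees the model has already been moved to its target device before the LoRA weights are calculated. A toy sketch of that ordering guarantee (the names here are illustrative, not InvokeAI APIs):

```python
from contextlib import contextmanager


@contextmanager
def step(name: str, log: list):
    # Record enter/exit order so the ordering guarantee is visible.
    log.append(f"enter {name}")
    yield name
    log.append(f"exit {name}")


log = []
with (
    step("load_model_to_target_device", log) as model,  # entered first
    step("apply_lora_patches", log),  # entered second: the model is already on its device
):
    assert model == "load_model_to_target_device"

# Entered in listed order, exited in reverse:
# ['enter load_model_to_target_device', 'enter apply_lora_patches',
#  'exit apply_lora_patches', 'exit load_model_to_target_device']
print(log)
```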
@@ -54,24 +54,6 @@ class ModelPatcher:
 
         return (module_key, module)
 
-    @staticmethod
-    def _lora_forward_hook(
-        applied_loras: List[Tuple[LoRAModel, float]],
-        layer_name: str,
-    ):
-        def lora_forward(module, input_h, output):
-            if len(applied_loras) == 0:
-                return output
-
-            for lora, weight in applied_loras:
-                layer = lora.layers.get(layer_name, None)
-                if layer is None:
-                    continue
-                output += layer.forward(module, input_h, weight)
-            return output
-
-        return lora_forward
-
     @classmethod
     @contextmanager
     def apply_lora_unet(
@@ -129,21 +111,40 @@ class ModelPatcher:
                     if not layer_key.startswith(prefix):
                         continue
 
+                    # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+                    # should be improved in the following ways:
+                    # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+                    #    LoRA model is applied.
+                    # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+                    #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+                    #    weights to have valid keys.
                     module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+                    # All of the LoRA weight calculations will be done on the same device as the module weight.
+                    # (Performance will be best if this is a CUDA device.)
+                    device = module.weight.device
+                    dtype = module.weight.dtype
+
                     if module_key not in original_weights:
                         original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
 
-                    # enable autocast to calc fp16 loras on cpu
-                    # with torch.autocast(device_type="cpu"):
-                    layer.to(dtype=torch.float32)
                     layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-                    layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
+
+                    # We intentionally move to the target device first, then cast. Experimentally, this was found to
+                    # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                    # same thing in a single call to '.to(...)'.
+                    layer.to(device=device)
+                    layer.to(dtype=torch.float32)
+                    # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+                    # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+                    layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
+                    layer.to(device="cpu")
 
                     if module.weight.shape != layer_weight.shape:
                         # TODO: debug on lycoris
                         layer_weight = layer_weight.reshape(module.weight.shape)
 
-                    module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
+                    module.weight += layer_weight.to(dtype=dtype)
 
         yield  # wait for context manager exit
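To make the "move first, then cast" comment above concrete, here is a hedged micro-benchmark sketch. The helper name, tensor size, and structure are illustrative only (not from the PR), and the measured difference will vary by hardware and driver:

```python
import time

import torch


def time_to_cuda_fp32(t: torch.Tensor, two_step: bool) -> float:
    """Illustrative timing helper: copy a CPU tensor to CUDA and cast it to float32."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    if two_step:
        # Move to the device first, then cast the dtype (the ordering used by the patched code).
        _ = t.to(device="cuda").to(dtype=torch.float32)
    else:
        # Single combined call.
        _ = t.to(device="cuda", dtype=torch.float32)
    torch.cuda.synchronize()
    return time.perf_counter() - start


if torch.cuda.is_available():
    cpu_fp16 = torch.ones((4096, 4096), dtype=torch.float16)  # a 16-bit CPU tensor
    print("move-then-cast:", time_to_cuda_fp32(cpu_fp16, two_step=True))
    print("combined .to():", time_to_cuda_fp32(cpu_fp16, two_step=False))
```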
@@ -196,7 +197,9 @@ class ModelPatcher:
 
                 if model_embeddings.weight.data[token_id].shape != embedding.shape:
                     raise ValueError(
-                        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
+                        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
+                        f" {embedding.shape[0]}, but the current model has token dimension"
+                        f" {model_embeddings.weight.data[token_id].shape[0]}."
                     )
 
                 model_embeddings.weight.data[token_id] = embedding.to(
@@ -257,7 +260,8 @@ class TextualInversionModel:
         if "string_to_param" in state_dict:
             if len(state_dict["string_to_param"]) > 1:
                 print(
-                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
+                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first'
+                    " token will be used."
                 )
 
             result.embedding = next(iter(state_dict["string_to_param"].values()))
@@ -470,7 +474,9 @@ class ONNXModelPatcher:
 
             if embeddings[token_id].shape != embedding.shape:
                 raise ValueError(
-                    f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
+                    f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
+                    f" {embedding.shape[0]}, but the current model has token dimension"
+                    f" {embeddings[token_id].shape[0]}."
                 )
 
             embeddings[token_id] = embedding
@@ -440,33 +440,19 @@ class IA3Layer(LoRALayerBase):
 class LoRAModelRaw:  # (torch.nn.Module):
     _name: str
     layers: Dict[str, LoRALayer]
-    _device: torch.device
-    _dtype: torch.dtype
 
     def __init__(
         self,
         name: str,
         layers: Dict[str, LoRALayer],
-        device: torch.device,
-        dtype: torch.dtype,
     ):
         self._name = name
-        self._device = device or torch.cpu
-        self._dtype = dtype or torch.float32
         self.layers = layers
 
     @property
     def name(self):
         return self._name
 
-    @property
-    def device(self):
-        return self._device
-
-    @property
-    def dtype(self):
-        return self._dtype
-
     def to(
         self,
         device: Optional[torch.device] = None,
@@ -475,8 +461,6 @@ class LoRAModelRaw:  # (torch.nn.Module):
         # TODO: try revert if exception?
         for key, layer in self.layers.items():
             layer.to(device=device, dtype=dtype)
-        self._device = device
-        self._dtype = dtype
 
     def calc_size(self) -> int:
         model_size = 0
@@ -557,8 +541,6 @@ class LoRAModelRaw:  # (torch.nn.Module):
         file_path = Path(file_path)
 
         model = cls(
-            device=device,
-            dtype=dtype,
             name=file_path.stem,  # TODO:
             layers=dict(),
         )
tests/backend/model_management/test_lora.py (new file, 102 lines)
@@ -0,0 +1,102 @@
# test that if the model's device changes while the lora is applied, the weights can still be restored

# test that LoRA patching works on both CPU and CUDA

import pytest
import torch

from invokeai.backend.model_management.lora import ModelPatcher
from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
@torch.no_grad()
def test_apply_lora(device):
    """Test the basic behavior of ModelPatcher.apply_lora(...). Check that patching and unpatching produce the correct
    result, and that model/LoRA tensors are moved between devices as expected.
    """

    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    model = torch.nn.ModuleDict(
        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device=device, dtype=torch.float16)}
    )

    lora_layers = {
        "linear_layer_1": LoRALayer(
            layer_key="linear_layer_1",
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = LoRAModelRaw("lora_name", lora_layers)

    lora_weight = 0.5
    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
    expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)

    with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on its original device.
        assert model["linear_layer_1"].weight.data.device.type == device

        torch.testing.assert_close(model["linear_layer_1"].weight.data, expected_patched_linear_weight)

    # After unpatching, the original model weights should have been restored on the original device.
    assert model["linear_layer_1"].weight.data.device.type == device
    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_lora_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    # Initialize the model on the CPU.
    model = torch.nn.ModuleDict(
        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)}
    )

    lora_layers = {
        "linear_layer_1": LoRALayer(
            layer_key="linear_layer_1",
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = LoRAModelRaw("lora_name", lora_layers)

    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()

    with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model["linear_layer_1"].weight.data.device.type == "cpu"

        # Move the model to the GPU.
        assert model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model["linear_layer_1"].weight.data.device.type == "cuda"
    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight, check_device=False)