LoRA patching optimization (#6439)

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* do not save original weights if there is a CPU copy of state dict

* Update invokeai/backend/model_manager/load/load_base.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* documentation fixes added during penultimate review

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
Lincoln Stein
2024-06-06 09:53:35 -04:00
committed by GitHub
parent 1c5c3cdbd6
commit 2871676f79
7 changed files with 146 additions and 48 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
@ -66,8 +66,14 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
with cls.apply_lora(
unet,
loras=loras,
prefix="lora_unet_",
model_state_dict=model_state_dict,
):
yield
@classmethod
@ -76,28 +82,9 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
yield
@classmethod
@ -107,7 +94,16 @@ class ModelPatcher:
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> None:
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[Any, None, None]:
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
original_weights = {}
try:
with torch.no_grad():
@ -133,7 +129,10 @@ class ModelPatcher:
dtype = module.weight.dtype
if module_key not in original_weights:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
original_weights[module_key] = model_state_dict[module_key + ".weight"]
else:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0