Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Modular backend - LoRA/LyCORIS (#6667)
## Summary

Code for LoRA patching from #6577. Additionally, LoRA patching can now patch not only `weight` but also `bias`, since some LoRAs modify the bias as well.

## Related Issues / Discussions

#6606
https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without the `USE_MODULAR_DENOISE` environment variable set.

## Merge Plan

Replace the old LoRA patcher with the new one after review is done. If you think there should be some kind of tests, feel free to add them.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
This commit is contained in: 4ce64b69cb
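For orientation before reading the diff, here is a minimal standalone sketch of the idea being modularized: each LoRA contributes a scaled delta to a module's parameters (now `weight` and, when present, `bias`), the original values are cached on CPU, and everything is restored when patching ends. This is illustrative only; the function and variable names below are not InvokeAI APIs.

```python
# Illustrative sketch only (not InvokeAI code): patch weight/bias with scaled deltas,
# cache originals on CPU, and restore them when the context exits.
from contextlib import contextmanager
from typing import Dict, Iterator

import torch


@contextmanager
def apply_lora_deltas(module: torch.nn.Module, deltas: Dict[str, torch.Tensor], scale: float) -> Iterator[None]:
    originals: Dict[str, torch.Tensor] = {}
    try:
        with torch.no_grad():
            for name, delta in deltas.items():  # e.g. {"weight": dW, "bias": db}
                param = module.get_parameter(name)
                originals[name] = param.detach().to("cpu", copy=True)  # save before modifying
                param += (delta * scale).to(dtype=param.dtype, device=param.device)
        yield
    finally:
        with torch.no_grad():
            for name, saved in originals.items():  # restore only what was changed
                module.get_parameter(name).copy_(saved)


layer = torch.nn.Linear(4, 4)
before = layer.weight.detach().clone()
with apply_lora_deltas(layer, {"weight": torch.randn(4, 4), "bias": torch.randn(4)}, scale=0.75):
    pass  # inference would happen here
assert torch.allclose(layer.weight, before)
```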
@@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora_text_encoder(
                 text_encoder,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora(
                 text_encoder,
                 loras=_lora_loader(),
                 prefix=lora_prefix,
-                model_state_dict=state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
@@ -62,6 +62,7 @@ from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetEx
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
 from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
@@ -845,6 +846,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

+        ### lora
+        if self.unet.loras:
+            for lora_field in self.unet.loras:
+                ext_manager.add_extension(
+                    LoRAExt(
+                        node_context=context,
+                        model_id=lora_field.lora,
+                        weight=lora_field.weight,
+                    )
+                )
         ### seamless
         if self.unet.seamless_axes:
             ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
@@ -964,14 +975,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
             assert isinstance(unet_info.model, UNet2DConditionModel)
             with (
                 ExitStack() as exit_stack,
-                unet_info.model_on_device() as (model_state_dict, unet),
+                unet_info.model_on_device() as (cached_weights, unet),
                 ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
                 SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
                 # Apply the LoRA after unet has been moved to its target device for faster patching.
                 ModelPatcher.apply_lora_unet(
                     unet,
                     loras=_lora_loader(),
-                    model_state_dict=model_state_dict,
+                    cached_weights=cached_weights,
                 ),
             ):
                 assert isinstance(unet, UNet2DConditionModel)
@@ -3,12 +3,13 @@

 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union

 import torch
 from safetensors.torch import load_file
 from typing_extensions import Self

+import invokeai.backend.util.logging as logger
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.raw_model import RawModel

@@ -46,9 +47,19 @@ class LoRALayerBase:
         self.rank = None  # set in layer implementation
         self.layer_key = layer_key

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()

+    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+        return self.bias
+
+    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        params = {"weight": self.get_weight(orig_module.weight)}
+        bias = self.get_bias(orig_module.bias)
+        if bias is not None:
+            params["bias"] = bias
+        return params
+
     def calc_size(self) -> int:
         model_size = 0
         for val in [self.bias]:
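The new `get_bias()`/`get_parameters()` hooks are what let a single patching loop handle both `weight` and `bias`. The toy classes below illustrate the contract; they are simplified stand-ins, not the real `invokeai.backend.lora` types.

```python
# Toy illustration of the get_parameters() contract (simplified, not the real classes).
from typing import Dict, Optional

import torch


class ToyLayerBase:
    bias: Optional[torch.Tensor] = None

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

    def get_bias(self, orig_bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        return self.bias  # default: the layer's stored bias delta, which may be None

    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
        # One dict drives the patch loop: "weight" always, "bias" only when the LoRA provides one.
        params = {"weight": self.get_weight(orig_module.weight)}
        bias = self.get_bias(orig_module.bias)
        if bias is not None:
            params["bias"] = bias
        return params


class ToyLoRALayer(ToyLayerBase):
    def __init__(self, up: torch.Tensor, down: torch.Tensor, bias: Optional[torch.Tensor] = None):
        self.up, self.down, self.bias = up, down, bias

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.up @ self.down  # low-rank weight delta


deltas = ToyLoRALayer(torch.randn(8, 2), torch.randn(2, 8), bias=torch.randn(8)).get_parameters(torch.nn.Linear(8, 8))
assert set(deltas) == {"weight", "bias"}
```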
@@ -60,6 +71,17 @@ class LoRALayerBase:
         if self.bias is not None:
             self.bias = self.bias.to(device=device, dtype=dtype)

+    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+        """Log a warning if values contains unhandled keys."""
+        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
+        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
+        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
+        unknown_keys = set(values.keys()) - all_known_keys
+        if unknown_keys:
+            logger.warning(
+                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+            )
+

 # TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
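The intent of `check_keys()` is easy to see in isolation: every layer type declares which state-dict keys it consumes, and anything left over is logged rather than silently ignored. A rough standalone equivalent (hypothetical names, plain `print` instead of the InvokeAI logger):

```python
# Standalone sketch of the check_keys() idea: warn about state-dict keys no layer handled.
from typing import Dict, Set

import torch

# Keys handled by the base class itself (alpha scaling and sparse-bias storage).
BASE_KEYS = {"alpha", "bias_indices", "bias_values", "bias_size"}


def check_keys(values: Dict[str, torch.Tensor], known_keys: Set[str]) -> None:
    unknown = set(values.keys()) - (known_keys | BASE_KEYS)
    if unknown:
        print(f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown}")


state = {"lora_up.weight": torch.zeros(1), "lora_down.weight": torch.zeros(1), "lora_mystery": torch.zeros(1)}
check_keys(state, {"lora_up.weight", "lora_down.weight", "lora_mid.weight"})  # warns about "lora_mystery"
```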
@@ -76,14 +98,19 @@ class LoRALayer(LoRALayerBase):

         self.up = values["lora_up.weight"]
         self.down = values["lora_down.weight"]
-        if "lora_mid.weight" in values:
-            self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
-        else:
-            self.mid = None
+        self.mid = values.get("lora_mid.weight", None)

         self.rank = self.down.shape[0]
+        self.check_keys(
+            values,
+            {
+                "lora_up.weight",
+                "lora_down.weight",
+                "lora_mid.weight",
+            },
+        )

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -125,20 +152,23 @@ class LoHALayer(LoRALayerBase):
         self.w1_b = values["hada_w1_b"]
         self.w2_a = values["hada_w2_a"]
         self.w2_b = values["hada_w2_b"]
-
-        if "hada_t1" in values:
-            self.t1: Optional[torch.Tensor] = values["hada_t1"]
-        else:
-            self.t1 = None
-
-        if "hada_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["hada_t2"]
-        else:
-            self.t2 = None
+        self.t1 = values.get("hada_t1", None)
+        self.t2 = values.get("hada_t2", None)

         self.rank = self.w1_b.shape[0]
+        self.check_keys(
+            values,
+            {
+                "hada_w1_a",
+                "hada_w1_b",
+                "hada_w2_a",
+                "hada_w2_b",
+                "hada_t1",
+                "hada_t2",
+            },
+        )

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
             weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@@ -186,37 +216,39 @@ class LoKRLayer(LoRALayerBase):
     ):
         super().__init__(layer_key, values)

-        if "lokr_w1" in values:
-            self.w1: Optional[torch.Tensor] = values["lokr_w1"]
-            self.w1_a = None
-            self.w1_b = None
-        else:
-            self.w1 = None
+        self.w1 = values.get("lokr_w1", None)
+        if self.w1 is None:
             self.w1_a = values["lokr_w1_a"]
             self.w1_b = values["lokr_w1_b"]

-        if "lokr_w2" in values:
-            self.w2: Optional[torch.Tensor] = values["lokr_w2"]
-            self.w2_a = None
-            self.w2_b = None
-        else:
-            self.w2 = None
+        self.w2 = values.get("lokr_w2", None)
+        if self.w2 is None:
             self.w2_a = values["lokr_w2_a"]
             self.w2_b = values["lokr_w2_b"]

-        if "lokr_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["lokr_t2"]
-        else:
-            self.t2 = None
+        self.t2 = values.get("lokr_t2", None)

-        if "lokr_w1_b" in values:
-            self.rank = values["lokr_w1_b"].shape[0]
-        elif "lokr_w2_b" in values:
-            self.rank = values["lokr_w2_b"].shape[0]
+        if self.w1_b is not None:
+            self.rank = self.w1_b.shape[0]
+        elif self.w2_b is not None:
+            self.rank = self.w2_b.shape[0]
         else:
             self.rank = None  # unscaled

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+        self.check_keys(
+            values,
+            {
+                "lokr_w1",
+                "lokr_w1_a",
+                "lokr_w1_b",
+                "lokr_w2",
+                "lokr_w2_a",
+                "lokr_w2_b",
+                "lokr_t2",
+            },
+        )
+
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
             assert self.w1_a is not None
@@ -272,7 +304,9 @@ class LoKRLayer(LoRALayerBase):


 class FullLayer(LoRALayerBase):
+    # bias handled in LoRALayerBase(calc_size, to)
     # weight: torch.Tensor
+    # bias: Optional[torch.Tensor]

     def __init__(
         self,
@@ -282,15 +316,12 @@ class FullLayer(LoRALayerBase):
         super().__init__(layer_key, values)

         self.weight = values["diff"]
-
-        if len(values.keys()) > 1:
-            _keys = list(values.keys())
-            _keys.remove("diff")
-            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
+        self.bias = values.get("diff_b", None)

         self.rank = None  # unscaled
+        self.check_keys(values, {"diff", "diff_b"})

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight

     def calc_size(self) -> int:
@@ -319,8 +350,9 @@ class IA3Layer(LoRALayerBase):
         self.on_input = values["on_input"]

         self.rank = None  # unscaled
+        self.check_keys(values, {"weight", "on_input"})

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
         if not self.on_input:
             weight = weight.reshape(-1, 1)
@@ -458,16 +490,19 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

         for layer_key, values in state_dict.items():
+            # Detect layers according to LyCORIS detection logic(`weight_list_det`)
+            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
+
             # lora and locon
-            if "lora_down.weight" in values:
+            if "lora_up.weight" in values:
                 layer: AnyLoRALayer = LoRALayer(layer_key, values)

             # loha
-            elif "hada_w1_b" in values:
+            elif "hada_w1_a" in values:
                 layer = LoHALayer(layer_key, values)

             # lokr
-            elif "lokr_w1_b" in values or "lokr_w1" in values:
+            elif "lokr_w1" in values or "lokr_w1_a" in values:
                 layer = LoKRLayer(layer_key, values)

             # diff
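The detection change above switches to the same marker keys LyCORIS itself uses. A condensed sketch of the dispatch order, returning a label instead of constructing the real layer classes (the condition for the `diff` branch is not shown in this hunk, so `"diff" in values` is an assumption here):

```python
# Sketch of the layer-type detection order after this change; the "diff" condition is assumed.
from typing import Dict

import torch


def detect_layer_type(values: Dict[str, torch.Tensor]) -> str:
    if "lora_up.weight" in values:  # lora and locon
        return "lora"
    elif "hada_w1_a" in values:  # loha
        return "loha"
    elif "lokr_w1" in values or "lokr_w1_a" in values:  # lokr
        return "lokr"
    elif "diff" in values:  # diff / full layer (assumed marker key)
        return "full"
    elif "on_input" in values:  # ia3
        return "ia3"
    raise ValueError(f"Unsupported LyCORIS layer keys: {sorted(values)}")


assert detect_layer_type({"lokr_w1_a": torch.zeros(1), "lokr_w1_b": torch.zeros(1)}) == "lokr"
```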
@@ -475,7 +510,7 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
                 layer = FullLayer(layer_key, values)

             # ia3
-            elif "weight" in values and "on_input" in values:
+            elif "on_input" in values:
                 layer = IA3Layer(layer_key, values)

             else:
@@ -17,8 +17,9 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
-from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

 """
 loras = [
@@ -85,13 +86,13 @@ class ModelPatcher:
         cls,
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         with cls.apply_lora(
             unet,
             loras=loras,
             prefix="lora_unet_",
-            model_state_dict=model_state_dict,
+            cached_weights=cached_weights,
         ):
             yield

@@ -101,9 +102,9 @@ class ModelPatcher:
         cls,
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
-        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
+        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
             yield

     @classmethod
@@ -113,7 +114,7 @@ class ModelPatcher:
         model: AnyModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         """
         Apply one or more LoRAs to a model.
@@ -121,66 +122,26 @@ class ModelPatcher:
         :param model: The model to patch.
         :param loras: An iterator that returns the LoRA to patch in and its patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        original_weights = {}
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
-            with torch.no_grad():
-                for lora, lora_weight in loras:
-                    # assert lora.device.type == "cpu"
-                    for layer_key, layer in lora.layers.items():
-                        if not layer_key.startswith(prefix):
-                            continue
-
-                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                        # should be improved in the following ways:
-                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                        #    LoRA model is applied.
-                        # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                        #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                        #    weights to have valid keys.
-                        assert isinstance(model, torch.nn.Module)
-                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-                        # All of the LoRA weight calculations will be done on the same device as the module weight.
-                        # (Performance will be best if this is a CUDA device.)
-                        device = module.weight.device
-                        dtype = module.weight.dtype
-
-                        if module_key not in original_weights:
-                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                            else:
-                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
-
-                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                        # same thing in a single call to '.to(...)'.
-                        layer.to(device=device)
-                        layer.to(dtype=torch.float32)
-                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=TorchDevice.CPU_DEVICE)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        if module.weight.shape != layer_weight.shape:
-                            # TODO: debug on lycoris
-                            assert hasattr(layer_weight, "reshape")
-                            layer_weight = layer_weight.reshape(module.weight.shape)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype)
-
-            yield  # wait for context manager exit
+            for lora_model, lora_weight in loras:
+                LoRAExt.patch_model(
+                    model=model,
+                    prefix=prefix,
+                    lora=lora_model,
+                    lora_weight=lora_weight,
+                    original_weights=original_weights,
+                )
+                del lora_model
+
+            yield

         finally:
-            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
-                for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                for param_key, weight in original_weights.get_changed_weights():
+                    model.get_parameter(param_key).copy_(weight)

     @classmethod
     @contextmanager
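One detail worth noting in the new `finally` block: restoration now keys on full parameter paths (e.g. `<module>.weight`, `<module>.bias`) and uses `torch.nn.Module.get_parameter` instead of `get_submodule(...).weight`, so biases are restored through the same loop as weights. A small self-contained demonstration of that pattern:

```python
# get_parameter() accepts dotted parameter paths, so one restore loop covers both weights and biases.
import torch

model = torch.nn.Sequential(torch.nn.Linear(3, 3))
saved = {key: model.get_parameter(key).detach().clone() for key in ("0.weight", "0.bias")}

with torch.no_grad():
    model.get_parameter("0.weight").add_(1.0)  # simulate a patch
    for key, original in saved.items():        # uniform restore, independent of parameter name
        model.get_parameter(key).copy_(original)

assert torch.equal(model.get_parameter("0.weight"), saved["0.weight"])
```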
@@ -2,14 +2,14 @@ from __future__ import annotations

 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List

-import torch
 from diffusers import UNet2DConditionModel

 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 @dataclass
@@ -56,5 +56,17 @@ class ExtensionBase:
         yield None

     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
-        yield None
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
+        """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
+        diffusion process. Weight unpatching is handled upstream, and is achieved by saving unchanged weights by
+        `original_weights.save` function. Note that this enables some performance optimization by avoiding redundant
+        operations. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched
+        by this context manager.
+
+        Args:
+            unet (UNet2DConditionModel): The UNet model on execution device to patch.
+            original_weights (OriginalWeightsStorage): A storage with copy of the model's original weights in CPU, for
+                unpatching purposes. Extension should save tensor which being modified in this storage, also extensions
+                can access original weights values.
+        """
+        yield
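As a concrete illustration of the contract described in that docstring, a hypothetical extension that scales a single UNet parameter might look like the sketch below. The class and the chosen parameter key are made up for illustration; the imports match the modules added or changed in this PR.

```python
# Hypothetical extension following the new patch_unet contract: save the original tensor
# into `original_weights` before modifying it; restoration is handled upstream.
from __future__ import annotations

from contextlib import contextmanager

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class ScaleConvInExt(ExtensionBase):  # hypothetical example extension
    def __init__(self, scale: float):
        super().__init__()
        self._scale = scale

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        param_key = "conv_in.weight"  # assumed parameter name, for illustration only
        param = unet.get_parameter(param_key)
        original_weights.save(param_key, param)  # cache the unmodified weight on CPU
        with torch.no_grad():
            param *= self._scale
        yield  # weight restoration happens upstream via get_changed_weights()
```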
@@ -1,15 +1,15 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING

-import torch
 from diffusers import UNet2DConditionModel

 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

 if TYPE_CHECKING:
     from invokeai.app.shared.models import FreeUConfig
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 class FreeUExt(ExtensionBase):
@@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config

     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
new file: invokeai/backend/stable_diffusion/extensions/lora.py (137 lines)
@@ -0,0 +1,137 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Tuple
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+from invokeai.backend.util.devices import TorchDevice
+
+if TYPE_CHECKING:
+    from invokeai.app.invocations.model import ModelIdentifierField
+    from invokeai.app.services.shared.invocation_context import InvocationContext
+    from invokeai.backend.lora import LoRAModelRaw
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
+
+
+class LoRAExt(ExtensionBase):
+    def __init__(
+        self,
+        node_context: InvocationContext,
+        model_id: ModelIdentifierField,
+        weight: float,
+    ):
+        super().__init__()
+        self._node_context = node_context
+        self._model_id = model_id
+        self._weight = weight
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
+        lora_model = self._node_context.models.load(self._model_id).model
+        self.patch_model(
+            model=unet,
+            prefix="lora_unet_",
+            lora=lora_model,
+            lora_weight=self._weight,
+            original_weights=original_weights,
+        )
+        del lora_model
+
+        yield
+
+    @classmethod
+    @torch.no_grad()
+    def patch_model(
+        cls,
+        model: torch.nn.Module,
+        prefix: str,
+        lora: LoRAModelRaw,
+        lora_weight: float,
+        original_weights: OriginalWeightsStorage,
+    ):
+        """
+        Apply one or more LoRAs to a model.
+        :param model: The model to patch.
+        :param lora: LoRA model to patch in.
+        :param lora_weight: LoRA patch weight.
+        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
+        :param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
+        """
+
+        if lora_weight == 0:
+            return
+
+        # assert lora.device.type == "cpu"
+        for layer_key, layer in lora.layers.items():
+            if not layer_key.startswith(prefix):
+                continue
+
+            # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+            # should be improved in the following ways:
+            # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+            #    LoRA model is applied.
+            # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+            #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+            #    weights to have valid keys.
+            assert isinstance(model, torch.nn.Module)
+            module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+            # All of the LoRA weight calculations will be done on the same device as the module weight.
+            # (Performance will be best if this is a CUDA device.)
+            device = module.weight.device
+            dtype = module.weight.dtype
+
+            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+
+            # We intentionally move to the target device first, then cast. Experimentally, this was found to
+            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+            # same thing in a single call to '.to(...)'.
+            layer.to(device=device)
+            layer.to(dtype=torch.float32)
+
+            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+            for param_name, lora_param_weight in layer.get_parameters(module).items():
+                param_key = module_key + "." + param_name
+                module_param = module.get_parameter(param_name)
+
+                # save original weight
+                original_weights.save(param_key, module_param)
+
+                if module_param.shape != lora_param_weight.shape:
+                    # TODO: debug on lycoris
+                    lora_param_weight = lora_param_weight.reshape(module_param.shape)
+
+                lora_param_weight *= lora_weight * layer_scale
+                module_param += lora_param_weight.to(dtype=dtype)
+
+            layer.to(device=TorchDevice.CPU_DEVICE)
+
+    @staticmethod
+    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
+        assert "." not in lora_key
+
+        if not lora_key.startswith(prefix):
+            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
+
+        module = model
+        module_key = ""
+        key_parts = lora_key[len(prefix) :].split("_")
+
+        submodule_name = key_parts.pop(0)
+
+        while len(key_parts) > 0:
+            try:
+                module = module.get_submodule(submodule_name)
+                module_key += "." + submodule_name
+                submodule_name = key_parts.pop(0)
+            except Exception:
+                submodule_name += "_" + key_parts.pop(0)
+
+        module = module.get_submodule(submodule_name)
+        module_key = (module_key + "." + submodule_name).lstrip(".")
+
+        return (module_key, module)
@@ -7,6 +7,7 @@ import torch
 from diffusers import UNet2DConditionModel

 from invokeai.app.services.session_processor.session_processor_common import CanceledException
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@@ -67,9 +68,15 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException

-        # TODO: create weight patch logic in PR with extension which uses it
-        with ExitStack() as exit_stack:
-            for ext in self._extensions:
-                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+        original_weights = OriginalWeightsStorage(cached_weights)
+        try:
+            with ExitStack() as exit_stack:
+                for ext in self._extensions:
+                    exit_stack.enter_context(ext.patch_unet(unet, original_weights))

-            yield None
+                yield None
+
+        finally:
+            with torch.no_grad():
+                for param_key, weight in original_weights.get_changed_weights():
+                    unet.get_parameter(param_key).copy_(weight)
new file: invokeai/backend/util/original_weights_storage.py (39 lines)
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from typing import Dict, Iterator, Optional, Tuple
+
+import torch
+
+from invokeai.backend.util.devices import TorchDevice
+
+
+class OriginalWeightsStorage:
+    """A class for tracking the original weights of a model for patch/unpatch operations."""
+
+    def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        # The original weights of the model.
+        self._weights: dict[str, torch.Tensor] = {}
+        # The keys of the weights that have been changed (via `save()`) during the lifetime of this instance.
+        self._changed_weights: set[str] = set()
+        if cached_weights:
+            self._weights.update(cached_weights)
+
+    def save(self, key: str, weight: torch.Tensor, copy: bool = True):
+        self._changed_weights.add(key)
+        if key in self._weights:
+            return
+
+        self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy)
+
+    def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]:
+        weight = self._weights.get(key, None)
+        if weight is not None and copy:
+            weight = weight.clone()
+        return weight
+
+    def contains(self, key: str) -> bool:
+        return key in self._weights
+
+    def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
+        for key in self._changed_weights:
+            yield key, self._weights[key]
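Usage of the new storage class is straightforward; a small hedged example with toy tensors (no InvokeAI models involved):

```python
# Minimal usage sketch of OriginalWeightsStorage: save-before-modify, then restore changed keys.
import torch

from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

layer = torch.nn.Linear(4, 4)
storage = OriginalWeightsStorage()  # could also be seeded with a cached CPU state dict

with torch.no_grad():
    for key in ("weight", "bias"):
        param = layer.get_parameter(key)
        storage.save(key, param)         # stores a CPU copy the first time only
        param += torch.ones_like(param)  # simulate a patch

    # Later, restore exactly the parameters that were touched.
    for key, original in storage.get_changed_weights():
        layer.get_parameter(key).copy_(original)
```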