Tidy types in LoraModelPatcher.

This commit is contained in:
Ryan Dick 2024-04-05 16:05:31 -04:00
parent e1aa1ed6af
commit c15e9e23ca

View File

@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Iterator, List, Tuple
from typing import Iterator, Tuple
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
@ -42,7 +42,7 @@ class LoraModelPatcher:
cls,
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
):
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@ -52,7 +52,7 @@ class LoraModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@ -61,8 +61,8 @@ class LoraModelPatcher:
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
loras: Iterator[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@ -71,8 +71,8 @@ class LoraModelPatcher:
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
loras: Iterator[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@ -83,7 +83,7 @@ class LoraModelPatcher:
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> None:
):
original_weights = {}
try:
with torch.no_grad():
@ -123,13 +123,9 @@ class LoraModelPatcher:
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu"))
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
assert hasattr(layer_weight, "reshape")
layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype)
yield # wait for context manager exit