Tidy types in LoraModelPatcher.

This commit is contained in:
Ryan Dick 2024-04-05 16:05:31 -04:00
parent e1aa1ed6af
commit c15e9e23ca

View File

@ -1,5 +1,5 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import Iterator, List, Tuple from typing import Iterator, Tuple
import torch import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
@ -42,7 +42,7 @@ class LoraModelPatcher:
cls, cls,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None: ):
with cls.apply_lora(unet, loras, "lora_unet_"): with cls.apply_lora(unet, loras, "lora_unet_"):
yield yield
@ -52,7 +52,7 @@ class LoraModelPatcher:
cls, cls,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None: ):
with cls.apply_lora(text_encoder, loras, "lora_te_"): with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield yield
@ -61,8 +61,8 @@ class LoraModelPatcher:
def apply_sdxl_lora_text_encoder( def apply_sdxl_lora_text_encoder(
cls, cls,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None: ):
with cls.apply_lora(text_encoder, loras, "lora_te1_"): with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield yield
@ -71,8 +71,8 @@ class LoraModelPatcher:
def apply_sdxl_lora_text_encoder2( def apply_sdxl_lora_text_encoder2(
cls, cls,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None: ):
with cls.apply_lora(text_encoder, loras, "lora_te2_"): with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield yield
@ -83,7 +83,7 @@ class LoraModelPatcher:
model: AnyModel, model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str, prefix: str,
) -> None: ):
original_weights = {} original_weights = {}
try: try:
with torch.no_grad(): with torch.no_grad():
@ -123,13 +123,9 @@ class LoraModelPatcher:
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu")) layer.to(device=torch.device("cpu"))
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape: if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
assert hasattr(layer_weight, "reshape")
layer_weight = layer_weight.reshape(module.weight.shape) layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype) module.weight += layer_weight.to(dtype=dtype)
yield # wait for context manager exit yield # wait for context manager exit