mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Tidy types in LoraModelPatcher.
This commit is contained in:
parent
e1aa1ed6af
commit
c15e9e23ca
@ -1,5 +1,5 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, List, Tuple
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
@ -42,7 +42,7 @@ class LoraModelPatcher:
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
):
|
||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||
yield
|
||||
|
||||
@ -52,7 +52,7 @@ class LoraModelPatcher:
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
@ -61,8 +61,8 @@ class LoraModelPatcher:
|
||||
def apply_sdxl_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
|
||||
yield
|
||||
|
||||
@ -71,8 +71,8 @@ class LoraModelPatcher:
|
||||
def apply_sdxl_lora_text_encoder2(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
|
||||
yield
|
||||
|
||||
@ -83,7 +83,7 @@ class LoraModelPatcher:
|
||||
model: AnyModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
) -> None:
|
||||
):
|
||||
original_weights = {}
|
||||
try:
|
||||
with torch.no_grad():
|
||||
@ -123,13 +123,9 @@ class LoraModelPatcher:
|
||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||
layer.to(device=torch.device("cpu"))
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
assert hasattr(layer_weight, "reshape")
|
||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
module.weight += layer_weight.to(dtype=dtype)
|
||||
|
||||
yield # wait for context manager exit
|
||||
|
Loading…
Reference in New Issue
Block a user