mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Tidy types in LoraModelPatcher.
parent e1aa1ed6af
commit c15e9e23ca
@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Iterator, List, Tuple
+from typing import Iterator, Tuple
 
 import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
@@ -42,7 +42,7 @@ class LoraModelPatcher:
         cls,
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-    ) -> None:
+    ):
         with cls.apply_lora(unet, loras, "lora_unet_"):
             yield
 
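The hunk above shows only the parameters of the UNet patching context manager; the def line itself is outside the diff context. Below is a rough usage sketch: the method name apply_lora_unet and every variable are assumptions, and only the (unet, loras, "lora_unet_") call shape comes from the diff.

# Hypothetical usage sketch -- apply_lora_unet, unet, and lora_model are placeholders.
loras = [(lora_model, 0.75)]  # (LoRAModelRaw, weight) pairs

with LoraModelPatcher.apply_lora_unet(unet, iter(loras)):
    ...  # run denoising here; the LoRA deltas are patched into the UNet weights
# On exit, the underlying apply_lora context manager restores the original weights.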
@@ -52,7 +52,7 @@ class LoraModelPatcher:
         cls,
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-    ) -> None:
+    ):
         with cls.apply_lora(text_encoder, loras, "lora_te_"):
             yield
 
@@ -61,8 +61,8 @@ class LoraModelPatcher:
     def apply_sdxl_lora_text_encoder(
         cls,
         text_encoder: CLIPTextModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
+        loras: Iterator[Tuple[LoRAModelRaw, float]],
+    ):
         with cls.apply_lora(text_encoder, loras, "lora_te1_"):
             yield
 
@@ -71,8 +71,8 @@ class LoraModelPatcher:
     def apply_sdxl_lora_text_encoder2(
         cls,
         text_encoder: CLIPTextModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
+        loras: Iterator[Tuple[LoRAModelRaw, float]],
+    ):
         with cls.apply_lora(text_encoder, loras, "lora_te2_"):
             yield
 
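For SDXL, the two text encoders are patched through separate context managers with distinct key prefixes (lora_te1_ and lora_te2_). One caller-side nuance of the Iterator annotation: an iterator is exhausted after a single pass, so a call site that patches both encoders presumably builds a fresh iterator per call. In the sketch below, only the two method names come from the diff; everything else is an assumption.

# Hypothetical call site for SDXL, patching both text encoders.
from contextlib import ExitStack

loras = [(lora_model, 0.75)]  # hypothetical (LoRAModelRaw, weight) pairs

with ExitStack() as stack:
    # A single iterator would be drained by the first call, so each patch
    # method gets its own iter() over the same list.
    stack.enter_context(LoraModelPatcher.apply_sdxl_lora_text_encoder(text_encoder_1, iter(loras)))
    stack.enter_context(LoraModelPatcher.apply_sdxl_lora_text_encoder2(text_encoder_2, iter(loras)))
    ...  # encode prompts while both encoders carry the LoRA deltas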
@@ -83,7 +83,7 @@ class LoraModelPatcher:
         model: AnyModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
-    ) -> None:
+    ):
         original_weights = {}
         try:
             with torch.no_grad():
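The hunk above is the top of the shared apply_lora context manager: it keeps a dictionary of original weights, patches modules under torch.no_grad(), and, judging from the original_weights bookkeeping and the "yield  # wait for context manager exit" line in the next hunk, restores the saved weights when the caller leaves the context. Below is a minimal self-contained sketch of that save/patch/restore pattern, not the InvokeAI implementation itself.

from contextlib import contextmanager
from typing import Dict, Iterator

import torch


@contextmanager
def patch_weights(model: torch.nn.Module, deltas: Dict[str, torch.Tensor]) -> Iterator[None]:
    # Simplified sketch: `deltas` maps submodule names (modules that expose a
    # .weight parameter) to additive weight updates, e.g. precomputed LoRA deltas.
    original: Dict[str, torch.Tensor] = {}
    try:
        with torch.no_grad():
            for name, delta in deltas.items():
                module = model.get_submodule(name)
                # Keep a CPU copy of the unpatched weight so it can be restored.
                original[name] = module.weight.detach().to(device="cpu", copy=True)
                module.weight += delta.to(device=module.weight.device, dtype=module.weight.dtype)
        yield  # caller runs inference while the patched weights are in place
    finally:
        with torch.no_grad():
            for name, weight in original.items():
                model.get_submodule(name).weight.copy_(weight)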
@@ -123,13 +123,9 @@ class LoraModelPatcher:
                         layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
                         layer.to(device=torch.device("cpu"))
 
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                         if module.weight.shape != layer_weight.shape:
-                            # TODO: debug on lycoris
-                            assert hasattr(layer_weight, "reshape")
                             layer_weight = layer_weight.reshape(module.weight.shape)
 
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                         module.weight += layer_weight.to(dtype=dtype)
 
             yield  # wait for context manager exit
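The dropped assert isinstance(layer_weight, torch.Tensor) lines existed only to narrow the type for mypy (see the "float|Any" complaint in their comments), and hasattr(layer_weight, "reshape") is trivially true for a tensor. They become redundant once the LoRA layer's get_weight carries a precise return annotation, which is presumably what this tidy-up relies on. A toy illustration of that typing effect follows; the class below is a stand-in, not an InvokeAI class.

import torch


class ToyLoRALayer:
    # Stand-in for a LoRA layer; only the return annotation matters here.
    def __init__(self, up: torch.Tensor, down: torch.Tensor) -> None:
        self.up = up
        self.down = down

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # A plain low-rank delta with the same shape as orig_weight.
        return self.up @ self.down


layer = ToyLoRALayer(torch.randn(16, 4), torch.randn(4, 16))
# Because get_weight is annotated to return torch.Tensor, mypy infers this
# product as torch.Tensor as well -- no runtime isinstance assert is needed.
layer_weight = layer.get_weight(torch.randn(16, 16)) * 0.75
assert layer_weight.shape == (16, 16)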