Minor cleanup of LoRAModelRaw.

Ryan Dick 2024-04-05 15:30:01 -04:00
parent 9e68a5c851
commit 4b68050c9b
2 changed files with 16 additions and 18 deletions

View File

@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Optional, Union
 import torch
@@ -19,7 +19,7 @@ class LoRAModelRaw(torch.nn.Module):
     def __init__(
         self,
         name: str,
-        layers: Dict[str, AnyLoRALayer],
+        layers: dict[str, AnyLoRALayer],
     ):
         super().__init__()
         self._name = name
@@ -55,13 +55,9 @@ class LoRAModelRaw(torch.nn.Module):
         device = device or torch.device("cpu")
         dtype = dtype or torch.float32
         if isinstance(file_path, str):
             file_path = Path(file_path)
-        model = cls(
-            name=file_path.stem,
-            layers={},
-        )
+        model_name = file_path.stem
         sd = load_state_dict(file_path, device=str(device))
         state_dict = cls._group_state(sd)
@@ -69,6 +65,7 @@ class LoRAModelRaw(torch.nn.Module):
         if base_model == BaseModelType.StableDiffusionXL:
             state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
+        layers: dict[str, AnyLoRALayer] = {}
         for layer_key, values in state_dict.items():
             # lora and locon
             if "lora_down.weight" in values:
@@ -91,20 +88,19 @@ class LoRAModelRaw(torch.nn.Module):
                 layer = IA3Layer(layer_key, values)
             else:
-                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
-                raise Exception("Unknown lora format!")
+                raise ValueError(f"Unknown lora layer module in {model_name}: {layer_key}: {list(values.keys())}")
             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()
             layer.to(device=device, dtype=dtype)
-            model.layers[layer_key] = layer
+            layers[layer_key] = layer
-        return model
+        return cls(name=model_name, layers=layers)
     @staticmethod
-    def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
-        state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
+    def _group_state(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
+        state_dict_groupped: dict[str, dict[str, torch.Tensor]] = {}
         for key, value in state_dict.items():
             stem, leaf = key.split(".", 1)
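The loader above parses every layer into a local layers dict and only constructs the model at the very end; the per-layer dicts it iterates over come from cls._group_state(sd), whose visible lines split each flat checkpoint key on its first dot. Below is a self-contained sketch of that grouping behaviour, assuming a plain setdefault accumulation (written as a standalone function rather than the staticmethod, with made-up keys for illustration):

import torch


def group_state(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
    """Group a flat checkpoint dict by the key prefix before the first '.'."""
    grouped: dict[str, dict[str, torch.Tensor]] = {}
    for key, value in state_dict.items():
        stem, leaf = key.split(".", 1)  # only the first dot splits, so "lora_down.weight" stays intact as the leaf
        grouped.setdefault(stem, {})[leaf] = value
    return grouped


# Two flat keys for the same module collapse into one per-layer dict, which is
# why a check like `"lora_down.weight" in values` works in the parsing loop above.
flat = {
    "lora_unet_down_blocks_0_proj_in.lora_down.weight": torch.zeros(4, 8),
    "lora_unet_down_blocks_0_proj_in.lora_up.weight": torch.zeros(8, 4),
}
grouped = group_state(flat)
assert list(grouped) == ["lora_unet_down_blocks_0_proj_in"]
assert set(grouped["lora_unet_down_blocks_0_proj_in"]) == {"lora_down.weight", "lora_up.weight"}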

View File

@@ -1,6 +1,5 @@
 import bisect
 import torch
+from typing import TypeVar
def make_sdxl_unet_conversion_map() -> list[tuple[str, str]]:
@@ -97,7 +96,10 @@ SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
 }
-def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+T = TypeVar("T")
+
+def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, T]) -> dict[str, T]:
     """Convert the keys of an SDXL LoRA state_dict to diffusers format.
     The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
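The TypeVar matters because the caller in the first file passes the already-grouped state dict (a dict[str, dict[str, torch.Tensor]], not a flat dict[str, torch.Tensor]) through this function; with a generic value type, the annotation fits both call shapes and the return type follows the input. A minimal self-contained sketch of that typing pattern; the helper name and mapping below are invented for illustration, not taken from the file:

from typing import TypeVar

T = TypeVar("T")


def remap_keys(key_map: dict[str, str], state_dict: dict[str, T]) -> dict[str, T]:
    """Rename keys via key_map; values (tensors, nested dicts, ...) pass through untouched."""
    return {key_map.get(key, key): value for key, value in state_dict.items()}


# Works for flat tensor dicts and for grouped dict-of-dict state dicts alike,
# and a type checker preserves the value type in both cases.
renamed = remap_keys({"old_prefix": "new_prefix"}, {"old_prefix": {"lora_down.weight": 1.0}})
assert renamed == {"new_prefix": {"lora_down.weight": 1.0}}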
@@ -124,7 +126,7 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
     stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
     stability_unet_keys.sort()
-    new_state_dict: dict[str, torch.Tensor] = {}
+    new_state_dict: dict[str, T] = {}
     for full_key, value in state_dict.items():
         if full_key.startswith("lora_unet_"):
             search_key = full_key.replace("lora_unet_", "")
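The sorted stability_unet_keys list together with the bisect import points at a prefix lookup: strip the "lora_unet_" prefix, find the nearest candidate key with bisect, then confirm it really is a prefix before applying SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP. A self-contained sketch of that lookup technique under those assumptions; the function name, arguments, and sample prefixes are invented for illustration and this is not the repository's exact code:

import bisect


def find_mapped_prefix(search_key: str, sorted_prefixes: list[str]) -> str | None:
    """Return the entry of sorted_prefixes that search_key starts with, if any.

    bisect_right locates the candidate that sorts immediately at or before search_key;
    a startswith check then confirms it is actually a prefix.
    """
    position = bisect.bisect_right(sorted_prefixes, search_key)
    if position == 0:
        return None
    candidate = sorted_prefixes[position - 1]
    return candidate if search_key.startswith(candidate) else None


prefixes = sorted(["input_blocks_4_1", "middle_block_1", "output_blocks_2_1"])
assert find_mapped_prefix("input_blocks_4_1_proj_in", prefixes) == "input_blocks_4_1"
assert find_mapped_prefix("time_embed_0", prefixes) is None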