diff --git a/invokeai/backend/lora/lora_model.py b/invokeai/backend/lora/lora_model.py index a7ab274b40..30a22ce6e3 100644 --- a/invokeai/backend/lora/lora_model.py +++ b/invokeai/backend/lora/lora_model.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional, Union import torch @@ -19,7 +19,7 @@ class LoRAModelRaw(torch.nn.Module): def __init__( self, name: str, - layers: Dict[str, AnyLoRALayer], + layers: dict[str, AnyLoRALayer], ): super().__init__() self._name = name @@ -55,13 +55,9 @@ class LoRAModelRaw(torch.nn.Module): device = device or torch.device("cpu") dtype = dtype or torch.float32 - if isinstance(file_path, str): - file_path = Path(file_path) + file_path = Path(file_path) - model = cls( - name=file_path.stem, - layers={}, - ) + model_name = file_path.stem sd = load_state_dict(file_path, device=str(device)) state_dict = cls._group_state(sd) @@ -69,6 +65,7 @@ class LoRAModelRaw(torch.nn.Module): if base_model == BaseModelType.StableDiffusionXL: state_dict = convert_sdxl_keys_to_diffusers_format(state_dict) + layers: dict[str, AnyLoRALayer] = {} for layer_key, values in state_dict.items(): # lora and locon if "lora_down.weight" in values: @@ -91,20 +88,19 @@ class LoRAModelRaw(torch.nn.Module): layer = IA3Layer(layer_key, values) else: - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") - raise Exception("Unknown lora format!") + raise ValueError(f"Unknown lora layer module in {model_name}: {layer_key}: {list(values.keys())}") # lower memory consumption by removing already parsed layer values state_dict[layer_key].clear() layer.to(device=device, dtype=dtype) - model.layers[layer_key] = layer + layers[layer_key] = layer - return model + return cls(name=model_name, layers=layers) @staticmethod - def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: - state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {} + def _group_state(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: + state_dict_groupped: dict[str, dict[str, torch.Tensor]] = {} for key, value in state_dict.items(): stem, leaf = key.split(".", 1) diff --git a/invokeai/backend/lora/sdxl_state_dict_utils.py b/invokeai/backend/lora/sdxl_state_dict_utils.py index 643c7d353b..15eede468d 100644 --- a/invokeai/backend/lora/sdxl_state_dict_utils.py +++ b/invokeai/backend/lora/sdxl_state_dict_utils.py @@ -1,6 +1,5 @@ import bisect - -import torch +from typing import TypeVar def make_sdxl_unet_conversion_map() -> list[tuple[str, str]]: @@ -97,7 +96,10 @@ SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = { } -def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: +T = TypeVar("T") + + +def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, T]) -> dict[str, T]: """Convert the keys of an SDXL LoRA state_dict to diffusers format. The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in @@ -124,7 +126,7 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) - stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) stability_unet_keys.sort() - new_state_dict: dict[str, torch.Tensor] = {} + new_state_dict: dict[str, T] = {} for full_key, value in state_dict.items(): if full_key.startswith("lora_unet_"): search_key = full_key.replace("lora_unet_", "")