mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Minor cleanup of LoRAModelRaw.
This commit is contained in:
parent
9e68a5c851
commit
4b68050c9b
@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -19,7 +19,7 @@ class LoRAModelRaw(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, AnyLoRALayer],
|
||||
layers: dict[str, AnyLoRALayer],
|
||||
):
|
||||
super().__init__()
|
||||
self._name = name
|
||||
@ -55,13 +55,9 @@ class LoRAModelRaw(torch.nn.Module):
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem,
|
||||
layers={},
|
||||
)
|
||||
model_name = file_path.stem
|
||||
|
||||
sd = load_state_dict(file_path, device=str(device))
|
||||
state_dict = cls._group_state(sd)
|
||||
@ -69,6 +65,7 @@ class LoRAModelRaw(torch.nn.Module):
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, values in state_dict.items():
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
@ -91,20 +88,19 @@ class LoRAModelRaw(torch.nn.Module):
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
else:
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
raise ValueError(f"Unknown lora layer module in {model_name}: {layer_key}: {list(values.keys())}")
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
return cls(name=model_name, layers=layers)
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
def _group_state(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
|
@ -1,6 +1,5 @@
|
||||
import bisect
|
||||
|
||||
import torch
|
||||
from typing import TypeVar
|
||||
|
||||
|
||||
def make_sdxl_unet_conversion_map() -> list[tuple[str, str]]:
|
||||
@ -97,7 +96,10 @@ SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, T]) -> dict[str, T]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
@ -124,7 +126,7 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict: dict[str, torch.Tensor] = {}
|
||||
new_state_dict: dict[str, T] = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
|
Loading…
Reference in New Issue
Block a user