Use load_state_dict() util in LoRAModelRaw.

This commit is contained in:
Ryan Dick 2024-04-05 15:20:07 -04:00
parent 61a672cd81
commit 9e68a5c851

View File

@ -2,8 +2,6 @@ from pathlib import Path
from typing import Dict, Optional, Union
import torch
from safetensors.torch import load_file
from typing_extensions import Self
from invokeai.backend.lora.full_layer import FullLayer
from invokeai.backend.lora.ia3_layer import IA3Layer
@ -12,6 +10,7 @@ from invokeai.backend.lora.lokr_layer import LoKRLayer
from invokeai.backend.lora.lora_layer import LoRALayer
from invokeai.backend.lora.sdxl_state_dict_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.util.serialization import load_state_dict
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@ -52,7 +51,7 @@ class LoRAModelRaw(torch.nn.Module):
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
) -> Self:
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
@ -64,11 +63,7 @@ class LoRAModelRaw(torch.nn.Module):
layers={},
)
if file_path.suffix == ".safetensors":
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
sd = torch.load(file_path, map_location="cpu")
sd = load_state_dict(file_path, device=str(device))
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL: