mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use load_state_dict() util in LoRAModelRaw.
This commit is contained in:
parent
61a672cd81
commit
9e68a5c851
@ -2,8 +2,6 @@ from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.backend.lora.full_layer import FullLayer
|
||||
from invokeai.backend.lora.ia3_layer import IA3Layer
|
||||
@ -12,6 +10,7 @@ from invokeai.backend.lora.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.lora.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.sdxl_state_dict_utils import convert_sdxl_keys_to_diffusers_format
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.util.serialization import load_state_dict
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
|
||||
@ -52,7 +51,7 @@ class LoRAModelRaw(torch.nn.Module):
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
) -> Self:
|
||||
):
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
@ -64,11 +63,7 @@ class LoRAModelRaw(torch.nn.Module):
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
sd = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
sd = torch.load(file_path, map_location="cpu")
|
||||
|
||||
sd = load_state_dict(file_path, device=str(device))
|
||||
state_dict = cls._group_state(sd)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
|
Loading…
Reference in New Issue
Block a user