mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use load_state_dict() util in LoRAModelRaw.
This commit is contained in:
parent
61a672cd81
commit
9e68a5c851
@ -2,8 +2,6 @@ from pathlib import Path
|
|||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file
|
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
from invokeai.backend.lora.full_layer import FullLayer
|
from invokeai.backend.lora.full_layer import FullLayer
|
||||||
from invokeai.backend.lora.ia3_layer import IA3Layer
|
from invokeai.backend.lora.ia3_layer import IA3Layer
|
||||||
@ -12,6 +10,7 @@ from invokeai.backend.lora.lokr_layer import LoKRLayer
|
|||||||
from invokeai.backend.lora.lora_layer import LoRALayer
|
from invokeai.backend.lora.lora_layer import LoRALayer
|
||||||
from invokeai.backend.lora.sdxl_state_dict_utils import convert_sdxl_keys_to_diffusers_format
|
from invokeai.backend.lora.sdxl_state_dict_utils import convert_sdxl_keys_to_diffusers_format
|
||||||
from invokeai.backend.model_manager import BaseModelType
|
from invokeai.backend.model_manager import BaseModelType
|
||||||
|
from invokeai.backend.util.serialization import load_state_dict
|
||||||
|
|
||||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||||
|
|
||||||
@ -52,7 +51,7 @@ class LoRAModelRaw(torch.nn.Module):
|
|||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
) -> Self:
|
):
|
||||||
device = device or torch.device("cpu")
|
device = device or torch.device("cpu")
|
||||||
dtype = dtype or torch.float32
|
dtype = dtype or torch.float32
|
||||||
|
|
||||||
@ -64,11 +63,7 @@ class LoRAModelRaw(torch.nn.Module):
|
|||||||
layers={},
|
layers={},
|
||||||
)
|
)
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
sd = load_state_dict(file_path, device=str(device))
|
||||||
sd = load_file(file_path.absolute().as_posix(), device="cpu")
|
|
||||||
else:
|
|
||||||
sd = torch.load(file_path, map_location="cpu")
|
|
||||||
|
|
||||||
state_dict = cls._group_state(sd)
|
state_dict = cls._group_state(sd)
|
||||||
|
|
||||||
if base_model == BaseModelType.StableDiffusionXL:
|
if base_model == BaseModelType.StableDiffusionXL:
|
||||||
|
Loading…
Reference in New Issue
Block a user