mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add util function for loading state_dicts from disk.
This commit is contained in:
parent
6d9fb207f0
commit
8db4ba252a
37
invokeai/backend/util/serialization.py
Normal file
37
invokeai/backend/util/serialization.py
Normal file
@ -0,0 +1,37 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def state_dict_to(
|
||||
state_dict: dict[str, torch.Tensor], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
new_state_dict: dict[str, torch.Tensor] = {}
|
||||
for k, v in state_dict.items():
|
||||
new_state_dict[k] = v.to(device=device, dtype=dtype, non_blocking=True)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path: Union[str, Path], device: str = "cpu") -> Any:
|
||||
"""Load a state_dict from a file that may be in either PyTorch or safetensors format. The file format is inferred
|
||||
from the file extension.
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(
|
||||
file_path,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
# weights_only=True is used to address a security vulnerability that allows arbitrary code execution.
|
||||
# This option was first introduced in https://github.com/pytorch/pytorch/pull/86812.
|
||||
#
|
||||
# mmap=True is used to both reduce memory usage and speed up loading. This setting causes torch.load() to more
|
||||
# closely mirror the behaviour of safetensors.torch.load_file(). This option was first introduced in
|
||||
# https://github.com/pytorch/pytorch/pull/102549. The discussion on that PR provides helpful context.
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True, mmap=True)
|
||||
|
||||
return state_dict
|
Loading…
Reference in New Issue
Block a user