Add util function for loading state_dicts from disk.

This commit is contained in:
Ryan Dick 2024-04-04 13:56:32 -04:00
parent 6d9fb207f0
commit 8db4ba252a

View File

@ -0,0 +1,37 @@
from pathlib import Path
from typing import Any, Optional, Union
import torch
from safetensors.torch import load_file
def state_dict_to(
state_dict: dict[str, torch.Tensor], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> dict[str, torch.Tensor]:
new_state_dict: dict[str, torch.Tensor] = {}
for k, v in state_dict.items():
new_state_dict[k] = v.to(device=device, dtype=dtype, non_blocking=True)
return new_state_dict
def load_state_dict(file_path: Union[str, Path], device: str = "cpu") -> Any:
"""Load a state_dict from a file that may be in either PyTorch or safetensors format. The file format is inferred
from the file extension.
"""
file_path = Path(file_path)
if file_path.suffix == ".safetensors":
state_dict = load_file(
file_path,
device=device,
)
else:
# weights_only=True is used to address a security vulnerability that allows arbitrary code execution.
# This option was first introduced in https://github.com/pytorch/pytorch/pull/86812.
#
# mmap=True is used to both reduce memory usage and speed up loading. This setting causes torch.load() to more
# closely mirror the behaviour of safetensors.torch.load_file(). This option was first introduced in
# https://github.com/pytorch/pytorch/pull/102549. The discussion on that PR provides helpful context.
state_dict = torch.load(file_path, map_location=device, weights_only=True, mmap=True)
return state_dict