Fix attempt to deserialize on CUDA on Mac

Without specifying "cpu", attempts to use non-existent CUDA to deserialize embeddings on macOS, resulting in a warning / failure to load.
This commit is contained in:
Steven Frank 2023-11-25 10:38:27 -08:00 committed by psychedelicious
parent 1d8f44d356
commit e509d719ee

View File

@ -225,7 +225,7 @@ class ModelProbe(object):
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path)
return torch.load(model_path)
return torch.load(model_path, map_location="cpu")
else:
return safetensors.torch.load_file(model_path)