Reduced Pickle ACE attack surface

Prior to this commit, all models would be loaded with the extremely unsafe `torch.load` method, except those with the exact extension `.safetensors`. Even a change in casing (eg. `saFetensors`, `Safetensors`, etc) would cause the file to be loaded with torch.load instead of the much safer `safetensors.toch.load_file`.
If a malicious actor renamed an infected `.ckpt` to something like `.SafeTensors` or `.SAFETENSORS` an unsuspecting user would think they are loading a safe .safetensor, but would in fact be parsing an unsafe pickle file, and executing an attacker's payload. This commit fixes this vulnerability by reversing the loading-method decision logic to only use the unsafe `torch.load` when the file extension is exactly `.ckpt`.
This commit is contained in:
jeremy 2023-03-13 16:16:30 -04:00
parent d9dab1b6c7
commit e0e01f6c50
2 changed files with 7 additions and 6 deletions

View File

@ -1075,9 +1075,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
dlogging.set_verbosity_error()
checkpoint = (
load_file(checkpoint_path)
if Path(checkpoint_path).suffix == ".safetensors"
else torch.load(checkpoint_path)
torch.load(checkpoint_path)
if Path(checkpoint_path).suffix == ".ckpt"
else load_file(checkpoint_path)
)
cache_dir = global_cache_dir("hub")
pipeline_class = (

View File

@ -732,9 +732,9 @@ class ModelManager(object):
# another round of heuristics to guess the correct config file.
checkpoint = (
safetensors.torch.load_file(model_path)
if model_path.suffix == ".safetensors"
else torch.load(model_path)
torch.load(model_path)
if model_path.suffix == ".ckpt"
else safetensors.torch.load_file(model_path)
)
# additional probing needed if no config file provided