From e0e01f6c505bafb75e10895193a35d147eaa7282 Mon Sep 17 00:00:00 2001 From: jeremy Date: Mon, 13 Mar 2023 16:16:30 -0400 Subject: [PATCH] Reduced Pickle ACE attack surface Prior to this commit, all models would be loaded with the extremely unsafe `torch.load` method, except those with the exact extension `.safetensors`. Even a change in casing (e.g. `saFetensors`, `Safetensors`, etc.) would cause the file to be loaded with torch.load instead of the much safer `safetensors.torch.load_file`. If a malicious actor renamed an infected `.ckpt` to something like `.SafeTensors` or `.SAFETENSORS` an unsuspecting user would think they are loading a safe .safetensors, but would in fact be parsing an unsafe pickle file, and executing an attacker's payload. This commit fixes this vulnerability by reversing the loading-method decision logic to only use the unsafe `torch.load` when the file extension is exactly `.ckpt`. --- .../backend/model_management/convert_ckpt_to_diffusers.py | 7 ++++--- invokeai/backend/model_management/model_manager.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index ae5550880a..793ba024cf 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -1075,9 +1075,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt( dlogging.set_verbosity_error() checkpoint = ( - load_file(checkpoint_path) - if Path(checkpoint_path).suffix == ".safetensors" - else torch.load(checkpoint_path) + torch.load(checkpoint_path) + if Path(checkpoint_path).suffix == ".ckpt" + else load_file(checkpoint_path) + ) cache_dir = global_cache_dir("hub") pipeline_class = ( diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 84e2ab378b..9464057f71 100644 ---
a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -732,9 +732,9 @@ class ModelManager(object): # another round of heuristics to guess the correct config file. checkpoint = ( - safetensors.torch.load_file(model_path) - if model_path.suffix == ".safetensors" - else torch.load(model_path) + torch.load(model_path) + if model_path.suffix == ".ckpt" + else safetensors.torch.load_file(model_path) ) # additional probing needed if no config file provided