mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
scan legacy checkpoint models in converter script prior to unpickling
Two related security fixes: 1. Port #2946 from main to 2.3.2 branch - this closes a hole that allows a pickle checkpoint file to masquerade as a safetensors file. 2. Add pickle scanning to the checkpoint to diffusers conversion script. This will be ported to main in a separate PR.
This commit is contained in:
parent
a044403ac3
commit
ba89444e36
@ -843,7 +843,11 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
|
||||
if Path(checkpoint_path).suffix == '.ckpt':
|
||||
ModelManager.scan_model(checkpoint_path,checkpoint_path)
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
else:
|
||||
checkpoint = load_file(checkpoint_path)
|
||||
cache_dir = global_cache_dir('hub')
|
||||
pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline
|
||||
|
||||
@ -960,7 +964,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
else:
|
||||
print(' | Using external VAE specified in config')
|
||||
print(' | Using VAE specified in config')
|
||||
|
||||
# Convert the text model.
|
||||
model_type = pipeline_type
|
||||
|
@ -438,10 +438,12 @@ class ModelManager(object):
|
||||
weight_bytes = f.read()
|
||||
model_hash = self._cached_sha256(weights, weight_bytes)
|
||||
sd = None
|
||||
if weights.endswith(".safetensors"):
|
||||
sd = safetensors.torch.load(weight_bytes)
|
||||
else:
|
||||
|
||||
if weights.endswith(".ckpt"):
|
||||
sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
|
||||
else:
|
||||
sd = safetensors.torch.load(weight_bytes)
|
||||
|
||||
del weight_bytes
|
||||
# merged models from auto11 merge board are flat for some reason
|
||||
if "state_dict" in sd:
|
||||
@ -591,6 +593,7 @@ class ModelManager(object):
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def scan_model(self, model_name, checkpoint):
|
||||
"""
|
||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||
|
Loading…
Reference in New Issue
Block a user