scan legacy checkpoint models in converter script prior to unpickling

Two related security fixes:

1. Port #2946 from main to 2.3.2 branch - this closes a hole that
   allows a pickle checkpoint file to masquerade as a safetensors
   file.

2. Add pickle scanning to the checkpoint-to-diffusers conversion
   script. This will be ported to main in a separate PR.
This commit is contained in:
Lincoln Stein 2023-03-23 13:44:08 -04:00
parent a044403ac3
commit ba89444e36
2 changed files with 12 additions and 5 deletions

View File

@@ -843,7 +843,11 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     verbosity = dlogging.get_verbosity()
     dlogging.set_verbosity_error()
-    checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
+    if Path(checkpoint_path).suffix == '.ckpt':
+        ModelManager.scan_model(checkpoint_path,checkpoint_path)
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = load_file(checkpoint_path)
     cache_dir = global_cache_dir('hub')
     pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline
@@ -960,7 +964,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
     else:
-        print(' | Using external VAE specified in config')
+        print(' | Using VAE specified in config')
     # Convert the text model.
     model_type = pipeline_type

View File

@@ -438,10 +438,12 @@ class ModelManager(object):
             weight_bytes = f.read()
             model_hash = self._cached_sha256(weights, weight_bytes)
             sd = None
-            if weights.endswith(".safetensors"):
-                sd = safetensors.torch.load(weight_bytes)
+            if weights.endswith(".ckpt"):
+                sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
             else:
-                sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
+                sd = safetensors.torch.load(weight_bytes)
             del weight_bytes
             # merged models from auto11 merge board are flat for some reason
             if "state_dict" in sd:
@@ -591,6 +593,7 @@ class ModelManager(object):
         if self._has_cuda():
             torch.cuda.empty_cache()
+    @classmethod
     def scan_model(self, model_name, checkpoint):
         """
         Apply picklescanner to the indicated checkpoint and issue a warning