mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
scan legacy checkpoint models in converter script prior to unpickling
Two related security fixes: 1. Port #2946 from main to 2.3.2 branch - this closes a hole that allows a pickle checkpoint file to masquerade as a safetensors file. 2. Add pickle scanning to the checkpoint to diffusers conversion script. This will be ported to main in a separate PR.
This commit is contained in:
parent
a044403ac3
commit
ba89444e36
@ -843,7 +843,11 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
verbosity = dlogging.get_verbosity()
|
verbosity = dlogging.get_verbosity()
|
||||||
dlogging.set_verbosity_error()
|
dlogging.set_verbosity_error()
|
||||||
|
|
||||||
checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
|
if Path(checkpoint_path).suffix == '.ckpt':
|
||||||
|
ModelManager.scan_model(checkpoint_path,checkpoint_path)
|
||||||
|
checkpoint = torch.load(checkpoint_path)
|
||||||
|
else:
|
||||||
|
checkpoint = load_file(checkpoint_path)
|
||||||
cache_dir = global_cache_dir('hub')
|
cache_dir = global_cache_dir('hub')
|
||||||
pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline
|
pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline
|
||||||
|
|
||||||
@ -960,7 +964,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
else:
|
else:
|
||||||
print(' | Using external VAE specified in config')
|
print(' | Using VAE specified in config')
|
||||||
|
|
||||||
# Convert the text model.
|
# Convert the text model.
|
||||||
model_type = pipeline_type
|
model_type = pipeline_type
|
||||||
|
@ -438,10 +438,12 @@ class ModelManager(object):
|
|||||||
weight_bytes = f.read()
|
weight_bytes = f.read()
|
||||||
model_hash = self._cached_sha256(weights, weight_bytes)
|
model_hash = self._cached_sha256(weights, weight_bytes)
|
||||||
sd = None
|
sd = None
|
||||||
if weights.endswith(".safetensors"):
|
|
||||||
sd = safetensors.torch.load(weight_bytes)
|
if weights.endswith(".ckpt"):
|
||||||
else:
|
|
||||||
sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
|
sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
|
||||||
|
else:
|
||||||
|
sd = safetensors.torch.load(weight_bytes)
|
||||||
|
|
||||||
del weight_bytes
|
del weight_bytes
|
||||||
# merged models from auto11 merge board are flat for some reason
|
# merged models from auto11 merge board are flat for some reason
|
||||||
if "state_dict" in sd:
|
if "state_dict" in sd:
|
||||||
@ -591,6 +593,7 @@ class ModelManager(object):
|
|||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def scan_model(self, model_name, checkpoint):
|
def scan_model(self, model_name, checkpoint):
|
||||||
"""
|
"""
|
||||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||||
|
Loading…
x
Reference in New Issue
Block a user