Security patch: Scan all pickle files, including VAEs; default to safetensor loading (#3011)
Several related security fixes:

1. Port #2946 from main to the 2.3.2 branch; this closes a hole that allows a pickle checkpoint file to masquerade as a safetensors file.
2. Add pickle scanning to the checkpoint-to-diffusers conversion script.
3. Pickle-scan non-safetensors VAE files.
4. Avoid running the scanner twice on the same file during the probing and conversion process.
5. Clean up diagnostic messages.
Commit b792b7d68c in invoke-ai/InvokeAI (https://github.com/invoke-ai/InvokeAI).
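The first three hunks below change the conversion entry point (load_pipeline_from_original_stable_diffusion_ckpt); the remaining hunks change ModelManager. For orientation, here is a minimal standalone sketch, not InvokeAI's exact code, of the loading policy the patch enforces: safetensors files go through the safe loader, and pickle-based checkpoints are scanned with picklescan before torch.load is allowed to touch them. The helper name load_checkpoint_safely is hypothetical.

from pathlib import Path

import torch
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file


def load_checkpoint_safely(path: str):
    # safetensors files hold raw tensors only, so no malware scan is needed
    if Path(path).suffix == ".safetensors":
        return load_file(path)
    # .ckpt/.pt files are pickle archives and can execute code on load:
    # scan first, and refuse to load anything flagged as infected
    result = scan_file_path(path)
    if result.infected_files != 0:
        raise RuntimeError(f"{path} failed the pickle scan; refusing to load")
    return torch.load(path, map_location="cpu")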
@@ -809,6 +809,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     vae:AutoencoderKL=None,
     precision:torch.dtype=torch.float32,
     return_generator_pipeline:bool=False,
+    scan_needed:bool=True,
 )->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
     '''
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -843,7 +844,12 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     verbosity = dlogging.get_verbosity()
     dlogging.set_verbosity_error()

-    checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
+    if Path(checkpoint_path).suffix == '.ckpt':
+        if scan_needed:
+            ModelManager.scan_model(checkpoint_path,checkpoint_path)
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = load_file(checkpoint_path)
     cache_dir = global_cache_dir('hub')
     pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline
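The converter now branches on the file suffix instead of reaching for torch.load by default, and the new scan_needed flag lets a caller that has already scanned the file skip a redundant pass (fix 4 in the commit message). A hedged usage sketch; the checkpoint path is hypothetical and the commented imports assume InvokeAI's 2.3 module layout:

# from ldm.invoke.model_manager import ModelManager  (assumed import path)
# from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
checkpoint_path = "models/ldm/stable-diffusion-v1/custom.ckpt"  # hypothetical

ModelManager.scan_model(checkpoint_path, checkpoint_path)  # scan exactly once
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
    checkpoint_path,
    scan_needed=False,  # already scanned above; skip the second pass
)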
@@ -960,7 +966,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
     else:
-        print(' | Using external VAE specified in config')
+        print(' | Using VAE specified in config')

     # Convert the text model.
     model_type = pipeline_type
@@ -282,13 +282,13 @@ class ModelManager(object):
             self.stack.remove(model_name)
         if delete_files:
             if weights:
-                print(f"** deleting file {weights}")
+                print(f"** Deleting file {weights}")
                 Path(weights).unlink(missing_ok=True)
             elif path:
-                print(f"** deleting directory {path}")
+                print(f"** Deleting directory {path}")
                 rmtree(path, ignore_errors=True)
             elif repo_id:
-                print(f"** deleting the cached model directory for {repo_id}")
+                print(f"** Deleting the cached model directory for {repo_id}")
                 self._delete_model_from_cache(repo_id)

     def add_model(
@@ -420,11 +420,6 @@ class ModelManager(object):
             "NOHASH",
         )

-        # scan model
-        self.scan_model(model_name, weights)
-
-        print(f">> Loading {model_name} from {weights}")
-
         # for usage statistics
         if self._has_cuda():
             torch.cuda.reset_peak_memory_stats()
@@ -438,10 +433,13 @@ class ModelManager(object):
             weight_bytes = f.read()
         model_hash = self._cached_sha256(weights, weight_bytes)
         sd = None
-        if weights.endswith(".safetensors"):
-            sd = safetensors.torch.load(weight_bytes)
-        else:
+        if weights.endswith(".ckpt"):
+            self.scan_model(model_name, weights)
             sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
+        else:
+            sd = safetensors.torch.load(weight_bytes)

         del weight_bytes
         # merged models from auto11 merge board are flat for some reason
         if "state_dict" in sd:
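This hunk inverts the old test: .ckpt now triggers a scan plus torch.load, and everything else falls through to safetensors, which is the "default to safetensor loading" promised in the title. Since the bytes were already read for hashing, both loaders run from memory. A standalone sketch of the same logic, with the helper name state_dict_from_bytes as an assumption:

import io

import safetensors.torch
import torch
from picklescan.scanner import scan_file_path


def state_dict_from_bytes(path: str, weight_bytes: bytes):
    if path.endswith(".ckpt"):
        # pickle archives can execute code on load, so scan the file first
        if scan_file_path(path).infected_files != 0:
            raise RuntimeError(f"{path} failed the pickle scan")
        return torch.load(io.BytesIO(weight_bytes), map_location="cpu")
    # anything else is treated as safetensors, the safe default
    return safetensors.torch.load(weight_bytes)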
@@ -464,18 +462,12 @@ class ModelManager(object):
             vae = os.path.normpath(os.path.join(Globals.root, vae))
             if os.path.exists(vae):
                 print(f" | Loading VAE weights from: {vae}")
-                vae_ckpt = None
-                vae_dict = None
-                if vae.endswith(".safetensors"):
-                    vae_ckpt = safetensors.torch.load_file(vae)
-                    vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
-                else:
-                    vae_ckpt = torch.load(vae, map_location="cpu")
-                    vae_dict = {
-                        k: v
-                        for k, v in vae_ckpt["state_dict"].items()
-                        if k[0:4] != "loss"
-                    }
+                if vae.endswith((".ckpt",".pt")):
+                    self.scan_model(vae,vae)
+                    vae_ckpt = torch.load(vae, map_location="cpu")
+                else:
+                    vae_ckpt = safetensors.torch.load_file(vae)
+                vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
                 model.first_stage_model.load_state_dict(vae_dict, strict=False)
             else:
                 print(f" | VAE file {vae} not found. Skipping.")
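This is fix 3: external VAE files get the same treatment as main checkpoints. A standalone sketch of the branch above, with the helper name load_vae_weights as an assumption; the key filter drops training-only "loss" tensors exactly as the diff does:

import torch
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file


def load_vae_weights(vae_path: str) -> dict:
    if vae_path.endswith((".ckpt", ".pt")):
        # pickle-based VAE: scan before torch.load
        if scan_file_path(vae_path).infected_files != 0:
            raise RuntimeError(f"{vae_path} failed the pickle scan")
        vae_ckpt = torch.load(vae_path, map_location="cpu")
    else:
        vae_ckpt = load_file(vae_path)
    # drop training-only tensors whose keys begin with "loss"
    return {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}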
@@ -591,13 +583,14 @@ class ModelManager(object):
         if self._has_cuda():
             torch.cuda.empty_cache()

+    @classmethod
     def scan_model(self, model_name, checkpoint):
         """
         Apply picklescanner to the indicated checkpoint and issue a warning
         and option to exit if an infected file is identified.
         """
         # scan model
-        print(f">> Scanning Model: {model_name}")
+        print(f" | Scanning Model: {model_name}")
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             if scan_result.infected_files == 1:
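Promoting scan_model to a classmethod is what lets the conversion script call it without constructing a ModelManager, as the ModelManager.scan_model(checkpoint_path, checkpoint_path) call in the first hunk group shows. Note that the first parameter keeps the name self but, under @classmethod, receives the class itself. A one-line sketch with hypothetical names, assuming ModelManager has been imported from InvokeAI's model manager module:

# `self` inside scan_model is bound to the ModelManager class, not an instance
ModelManager.scan_model("custom-model", "weights/custom-model.ckpt")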
@@ -620,7 +613,7 @@ class ModelManager(object):
                 print("### Exiting InvokeAI")
                 sys.exit()
         else:
-            print(">> Model scanned ok")
+            print(" | Model scanned ok")

     def import_diffuser_model(
         self,
@@ -805,6 +798,7 @@ class ModelManager(object):
                 thing, "models/ldm/stable-diffusion-v1"
             )  # _resolve_path does a download if needed
             is_temporary = True

         elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
             if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
                 print(
@@ -869,11 +863,12 @@ class ModelManager(object):
             return model_path.stem

         # another round of heuristics to guess the correct config file.
-        checkpoint = (
-            safetensors.torch.load_file(model_path)
-            if model_path.suffix == ".safetensors"
-            else torch.load(model_path)
-        )
+        checkpoint = None
+        if model_path.suffix.endswith((".ckpt",".pt")):
+            self.scan_model(model_path,model_path)
+            checkpoint = torch.load(model_path)
+        else:
+            checkpoint = safetensors.torch.load_file(model_path)

         # additional probing needed if no config file provided
         if model_config_file is None:
             model_type = self.probe_model_type(checkpoint)
@@ -933,6 +928,7 @@ class ModelManager(object):
             model_description=description,
             original_config_file=model_config_file,
             commit_to_conf=commit_to_conf,
+            scan_needed=False,
         )
         # in the event that this file was downloaded automatically prior to conversion
         # we do not keep the original .ckpt/.safetensors around
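The scan_needed=False above is fix 4 end to end: heuristic_import scans the checkpoint while probing it (the @@ -869 hunk), then tells the conversion step not to scan the same file again. A generic sketch of that hand-off, with probe_then_convert and convert_fn as hypothetical stand-ins for the real methods:

from picklescan.scanner import scan_file_path


def probe_then_convert(model_path: str, convert_fn):
    # probe: scan pickle-based files exactly once
    if model_path.endswith((".ckpt", ".pt")):
        if scan_file_path(model_path).infected_files != 0:
            raise RuntimeError(f"{model_path} failed the pickle scan")
    # convert: tell downstream code the scan has already happened
    return convert_fn(model_path, scan_needed=False)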
@@ -965,6 +961,7 @@ class ModelManager(object):
         vae=None,
         original_config_file: Path = None,
         commit_to_conf: Path = None,
+        scan_needed: bool=True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
@@ -999,6 +996,7 @@ class ModelManager(object):
             extract_ema=True,
             original_config_file=original_config_file,
             vae=vae_model,
+            scan_needed=scan_needed,
         )
         print(
             f" | Success. Optimized model is now located at {str(diffusers_path)}"
@@ -1396,7 +1394,7 @@ class ModelManager(object):
             hashes_to_delete.add(revision.commit_hash)
         strategy = cache_info.delete_revisions(*hashes_to_delete)
         print(
-            f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
+            f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
         )
         strategy.execute()