Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Security patch: Scan all pickle files, including VAEs; default to safetensor loading (#3011)
Several related security fixes:

1. Port #2946 from main to the 2.3.2 branch. This closes a hole that allows a pickle checkpoint file to masquerade as a safetensors file.
2. Add pickle scanning to the checkpoint-to-diffusers conversion script.
3. Pickle-scan non-safetensors VAE files.
4. Avoid running the scanner twice on the same file during the probing and conversion process.
5. Clean up diagnostic messages.
This commit is contained in: b792b7d68c
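At its core, the patch replaces a trusting one-liner (see the hunk at line 843 below) with an extension dispatch: `.safetensors` files go to the safetensors loader, which never unpickles anything, while `.ckpt`/`.pt` files are run through picklescan before `torch.load()` is allowed to deserialize them. A minimal sketch of the pattern follows; the helper name and the exception are illustrative, not the project's actual code, though `scan_file_path` and `infected_files` are the same picklescan API the diff itself uses:

    from pathlib import Path

    import torch
    from picklescan.scanner import scan_file_path      # same scanner scan_model() calls below
    from safetensors.torch import load_file


    def load_checkpoint_safely(checkpoint_path: str, scan_needed: bool = True) -> dict:
        # Hypothetical helper illustrating the patch's dispatch logic.
        if Path(checkpoint_path).suffix == ".safetensors":
            # safetensors is a plain tensor container; no pickle code is ever executed.
            return load_file(checkpoint_path)
        if scan_needed:
            # torch.load() unpickles, so scan first and refuse infected files.
            scan_result = scan_file_path(checkpoint_path)
            if scan_result.infected_files != 0:
                raise RuntimeError(f"{checkpoint_path} failed the pickle scan")
        return torch.load(checkpoint_path, map_location="cpu")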
@@ -327,10 +327,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     unet_key = "model.diffusion_model."
     # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
     if sum(k.startswith("model_ema") for k in keys) > 100:
         print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
         if extract_ema:
             print(
                 ' | Extracting EMA weights (usually better for inference)'
             )
             for key in keys:
                 if key.startswith("model.diffusion_model"):
@@ -338,7 +338,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
                     unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
         else:
             print(
                 ' | Extracting only the non-EMA weights (usually better for fine-tuning)'
             )

     for key in keys:
@@ -809,6 +809,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     vae:AutoencoderKL=None,
     precision:torch.dtype=torch.float32,
     return_generator_pipeline:bool=False,
+    scan_needed:bool=True,
 )->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
     '''
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -843,7 +844,12 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     verbosity = dlogging.get_verbosity()
     dlogging.set_verbosity_error()

-    checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
+    if Path(checkpoint_path).suffix == '.ckpt':
+        if scan_needed:
+            ModelManager.scan_model(checkpoint_path,checkpoint_path)
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = load_file(checkpoint_path)
     cache_dir = global_cache_dir('hub')
     pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline

@@ -851,7 +857,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
         print(" | global_step key not found in model")
         global_step = None

     # sometimes there is a state_dict key and sometimes not
@@ -953,14 +959,14 @@ def load_pipeline_from_original_stable_diffusion_ckpt(

     # Convert the VAE model, or use the one passed
     if not vae:
         print(' | Using checkpoint model\'s original VAE')
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
     else:
-        print(' | Using external VAE specified in config')
+        print(' | Using VAE specified in config')

     # Convert the text model.
     model_type = pipeline_type
@@ -282,13 +282,13 @@ class ModelManager(object):
             self.stack.remove(model_name)
         if delete_files:
             if weights:
-                print(f"** deleting file {weights}")
+                print(f"** Deleting file {weights}")
                 Path(weights).unlink(missing_ok=True)
             elif path:
-                print(f"** deleting directory {path}")
+                print(f"** Deleting directory {path}")
                 rmtree(path, ignore_errors=True)
             elif repo_id:
-                print(f"** deleting the cached model directory for {repo_id}")
+                print(f"** Deleting the cached model directory for {repo_id}")
                 self._delete_model_from_cache(repo_id)

     def add_model(
@@ -420,11 +420,6 @@ class ModelManager(object):
             "NOHASH",
         )

-        # scan model
-        self.scan_model(model_name, weights)
-
-        print(f">> Loading {model_name} from {weights}")
-
         # for usage statistics
         if self._has_cuda():
             torch.cuda.reset_peak_memory_stats()
@@ -438,10 +433,13 @@ class ModelManager(object):
             weight_bytes = f.read()
         model_hash = self._cached_sha256(weights, weight_bytes)
         sd = None
-        if weights.endswith(".safetensors"):
-            sd = safetensors.torch.load(weight_bytes)
-        else:
+        if weights.endswith(".ckpt"):
+            self.scan_model(model_name, weights)
             sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
+        else:
+            sd = safetensors.torch.load(weight_bytes)

         del weight_bytes
         # merged models from auto11 merge board are flat for some reason
         if "state_dict" in sd:
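Note how the hunk above preserves the single read of the weights file: the bytes are consumed once for the sha256 hash, the pickle scan runs against the path on disk, and `torch.load()` then deserializes the same bytes from memory. Roughly, under the assumption that `weights` is the checkpoint path and with `hashlib` standing in for the project's `_cached_sha256()`:

    import hashlib
    import io

    import torch

    with open(weights, "rb") as f:
        weight_bytes = f.read()                    # one read, shared by hashing and loading
    model_hash = hashlib.sha256(weight_bytes).hexdigest()   # stand-in for _cached_sha256()
    sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
    del weight_bytes                               # free the in-memory copy promptly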
@@ -464,18 +462,12 @@ class ModelManager(object):
                 vae = os.path.normpath(os.path.join(Globals.root, vae))
                 if os.path.exists(vae):
                     print(f" | Loading VAE weights from: {vae}")
-                    vae_ckpt = None
-                    vae_dict = None
-                    if vae.endswith(".safetensors"):
-                        vae_ckpt = safetensors.torch.load_file(vae)
-                        vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
-                    else:
+                    if vae.endswith((".ckpt",".pt")):
+                        self.scan_model(vae,vae)
                         vae_ckpt = torch.load(vae, map_location="cpu")
-                        vae_dict = {
-                            k: v
-                            for k, v in vae_ckpt["state_dict"].items()
-                            if k[0:4] != "loss"
-                        }
+                    else:
+                        vae_ckpt = safetensors.torch.load_file(vae)
+                    vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
                     model.first_stage_model.load_state_dict(vae_dict, strict=False)
                 else:
                     print(f" | VAE file {vae} not found. Skipping.")
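This hunk implements fix 3: VAE files in `.ckpt`/`.pt` format are now pickle-scanned before loading. The `k[0:4] != "loss"` filter exists because LDM autoencoder checkpoints saved during training also carry the loss-network weights under `loss.*` keys, which the inference-time VAE does not define. A hedged sketch of the filtering step with a hypothetical helper, generalized with `.get()` to cover both the nested `.ckpt` layout and the flat safetensors layout:

    def strip_loss_keys(vae_ckpt: dict) -> dict:
        # .ckpt files usually nest weights under "state_dict"; safetensors files are flat
        # (the .get() fallback is an assumption, not the project's exact code).
        state_dict = vae_ckpt.get("state_dict", vae_ckpt)
        return {k: v for k, v in state_dict.items() if not k.startswith("loss")}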
@@ -497,9 +489,9 @@ class ModelManager(object):

         print(f">> Loading diffusers model from {name_or_path}")
         if using_fp16:
             print(" | Using faster float16 precision")
         else:
             print(" | Using more accurate float32 precision")

         # TODO: scan weights maybe?
         pipeline_args: dict[str, Any] = dict(
@@ -551,7 +543,7 @@ class ModelManager(object):
         width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
         height = width

         print(f" | Default image dimensions = {width} x {height}")

         return pipeline, width, height, model_hash

@@ -591,13 +583,14 @@ class ModelManager(object):
         if self._has_cuda():
             torch.cuda.empty_cache()

+    @classmethod
     def scan_model(self, model_name, checkpoint):
         """
         Apply picklescanner to the indicated checkpoint and issue a warning
         and option to exit if an infected file is identified.
         """
         # scan model
-        print(f">> Scanning Model: {model_name}")
+        print(f" | Scanning Model: {model_name}")
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             if scan_result.infected_files == 1:
@@ -620,7 +613,7 @@ class ModelManager(object):
                 print("### Exiting InvokeAI")
                 sys.exit()
         else:
-            print(" | Model scanned ok")
+            print(" | Model scanned ok")

     def import_diffuser_model(
         self,
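Making `scan_model()` a `@classmethod` (first hunk above) is what lets the conversion code call it without holding a `ModelManager` instance; the first parameter, still named `self`, simply receives the class object. That is why the earlier hunk in `load_pipeline_from_original_stable_diffusion_ckpt()` can write:

    # two positional arguments bind to (model_name, checkpoint); here both are the path
    ModelManager.scan_model(checkpoint_path, checkpoint_path)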
@@ -800,19 +793,20 @@ class ModelManager(object):
         print(f">> Probing {thing} for import")

         if thing.startswith(("http:", "https:", "ftp:")):
             print(f" | {thing} appears to be a URL")
             model_path = self._resolve_path(
                 thing, "models/ldm/stable-diffusion-v1"
             )  # _resolve_path does a download if needed
             is_temporary = True

         elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
             if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
                 print(
                     f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
                 )
                 return
             else:
                 print(f" | {thing} appears to be a checkpoint file on disk")
                 model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")

         elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
@@ -869,11 +863,12 @@ class ModelManager(object):
             return model_path.stem

         # another round of heuristics to guess the correct config file.
-        checkpoint = (
-            safetensors.torch.load_file(model_path)
-            if model_path.suffix == ".safetensors"
-            else torch.load(model_path)
-        )
+        checkpoint = None
+        if model_path.suffix.endswith((".ckpt",".pt")):
+            self.scan_model(model_path,model_path)
+            checkpoint = torch.load(model_path)
+        else:
+            checkpoint = safetensors.torch.load_file(model_path)
         # additional probing needed if no config file provided
         if model_config_file is None:
             model_type = self.probe_model_type(checkpoint)
@@ -918,7 +913,7 @@ class ModelManager(object):
         if model_config_file.name.startswith('v2'):
             convert = True
             print(
                 " | This SD-v2 model will be converted to diffusers format for use"
             )

         if convert:
@@ -933,6 +928,7 @@ class ModelManager(object):
                 model_description=description,
                 original_config_file=model_config_file,
                 commit_to_conf=commit_to_conf,
+                scan_needed=False,
             )
             # in the event that this file was downloaded automatically prior to conversion
             # we do not keep the original .ckpt/.safetensors around
@@ -957,14 +953,15 @@ class ModelManager(object):
         return model_name

     def convert_and_import(
         self,
         ckpt_path: Path,
         diffusers_path: Path,
         model_name=None,
         model_description=None,
         vae=None,
         original_config_file: Path = None,
         commit_to_conf: Path = None,
+        scan_needed: bool=True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
@@ -999,11 +996,12 @@ class ModelManager(object):
             extract_ema=True,
             original_config_file=original_config_file,
             vae=vae_model,
+            scan_needed=scan_needed,
         )
         print(
             f" | Success. Optimized model is now located at {str(diffusers_path)}"
         )
         print(f" | Writing new config file entry for {model_name}")
         new_config = dict(
             path=str(diffusers_path),
             description=model_description,
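Together, the `scan_needed` additions in the three hunks above implement fix 4 from the commit message: the heuristic-import path scans a checkpoint once while probing it, then threads `scan_needed=False` through `convert_and_import()` into `load_pipeline_from_original_stable_diffusion_ckpt()`, so the same file is not scanned a second time during conversion. A hedged sketch of the flow, with names from the hunks but the call structure simplified for illustration:

    def heuristic_import_sketch(manager, model_path, diffusers_path):
        # fixes 1/3: the file is scanned (and loaded) once during probing
        manager.scan_model(model_path, model_path)
        # fix 4: conversion is told the scan already happened
        manager.convert_and_import(
            model_path,
            diffusers_path,
            scan_needed=False,   # load_pipeline_...() will skip scan_model()
        )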
@@ -1293,7 +1291,7 @@ class ModelManager(object):
             with open(hashpath) as f:
                 hash = f.read()
             return hash
         print(" | Calculating sha256 hash of model files")
         tic = time.time()
         sha = hashlib.sha256()
         count = 0
@@ -1305,7 +1303,7 @@ class ModelManager(object):
                 sha.update(chunk)
         hash = sha.hexdigest()
         toc = time.time()
         print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
         with open(hashpath, "w") as f:
             f.write(hash)
         return hash
@@ -1350,12 +1348,12 @@ class ModelManager(object):
             local_files_only=not Globals.internet_available,
         )

         print(f" | Loading diffusers VAE from {name_or_path}")
         if using_fp16:
             vae_args.update(torch_dtype=torch.float16)
             fp_args_list = [{"revision": "fp16"}, {}]
         else:
             print(" | Using more accurate float32 precision")
             fp_args_list = [{}]

         vae = None
@@ -1396,7 +1394,7 @@ class ModelManager(object):
             hashes_to_delete.add(revision.commit_hash)
         strategy = cache_info.delete_revisions(*hashes_to_delete)
         print(
-            f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
+            f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
         )
         strategy.execute()
