re-implement model scanning when loading legacy checkpoint files

- This PR turns on pickle scanning before a legacy checkpoint file
  is loaded from disk within the checkpoint_to_diffusers module.

- Also miscellaneous diagnostic message cleanup.
This commit is contained in:
Lincoln Stein 2023-03-23 15:03:30 -04:00
parent 485f6e5954
commit b2ce45a417
2 changed files with 53 additions and 46 deletions

View File

@ -372,9 +372,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
if extract_ema:
print(" | Extracting EMA weights (usually better for inference)")
print(" | Extracting EMA weights (usually better for inference)")
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
@ -383,7 +383,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
)
else:
print(
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
)
for key in keys:
@ -1040,6 +1040,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
vae: AutoencoderKL = None,
precision: torch.dtype = torch.float32,
return_generator_pipeline: bool = False,
scan_needed:bool=True,
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@ -1074,12 +1075,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
checkpoint = (
torch.load(checkpoint_path)
if Path(checkpoint_path).suffix == ".ckpt"
else load_file(checkpoint_path)
)
if Path(checkpoint_path).suffix == '.ckpt':
if scan_needed:
ModelManager.scan_model(checkpoint_path,checkpoint_path)
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = load_file(checkpoint_path)
cache_dir = global_cache_dir("hub")
pipeline_class = (
StableDiffusionGeneratorPipeline
@ -1091,7 +1093,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
print(" | global_step key not found in model")
print(" | global_step key not found in model")
global_step = None
# sometimes there is a state_dict key and sometimes not
@ -1204,7 +1206,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
# Convert the VAE model, or use the one passed
if not vae:
print(" | Using checkpoint model's original VAE")
print(" | Using checkpoint model's original VAE")
vae_config = create_vae_diffusers_config(
original_config, image_size=image_size
)
@ -1215,7 +1217,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
else:
print(" | Using external VAE specified in config")
print(" | Using external VAE specified in config")
# Convert the text model.
model_type = pipeline_type

View File

@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = 1
@ -285,13 +285,13 @@ class ModelManager(object):
self.stack.remove(model_name)
if delete_files:
if weights:
print(f"** deleting file {weights}")
print(f"** Deleting file {weights}")
Path(weights).unlink(missing_ok=True)
elif path:
print(f"** deleting directory {path}")
print(f"** Deleting directory {path}")
rmtree(path, ignore_errors=True)
elif repo_id:
print(f"** deleting the cached model directory for {repo_id}")
print(f"** Deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id)
def add_model(
@ -381,9 +381,9 @@ class ModelManager(object):
print(f">> Loading diffusers model from {name_or_path}")
if using_fp16:
print(" | Using faster float16 precision")
print(" | Using faster float16 precision")
else:
print(" | Using more accurate float32 precision")
print(" | Using more accurate float32 precision")
# TODO: scan weights maybe?
pipeline_args: dict[str, Any] = dict(
@ -435,7 +435,7 @@ class ModelManager(object):
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
height = width
print(f" | Default image dimensions = {width} x {height}")
print(f" | Default image dimensions = {width} x {height}")
self._add_embeddings_to_model(pipeline)
return pipeline, width, height, model_hash
@ -517,13 +517,14 @@ class ModelManager(object):
if self._has_cuda():
torch.cuda.empty_cache()
@classmethod
def scan_model(self, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
print(f">> Scanning Model: {model_name}")
print(f" | Scanning Model: {model_name}")
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files == 1:
@ -546,7 +547,7 @@ class ModelManager(object):
print("### Exiting InvokeAI")
sys.exit()
else:
print(">> Model scanned ok")
print(" | Model scanned ok")
def import_diffuser_model(
self,
@ -665,7 +666,7 @@ class ModelManager(object):
print(f">> Probing {thing} for import")
if thing.startswith(("http:", "https:", "ftp:")):
print(f" | {thing} appears to be a URL")
print(f" | {thing} appears to be a URL")
model_path = self._resolve_path(
thing, "models/ldm/stable-diffusion-v1"
) # _resolve_path does a download if needed
@ -673,15 +674,15 @@ class ModelManager(object):
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
print(
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
)
return
else:
print(f" | {thing} appears to be a checkpoint file on disk")
print(f" | {thing} appears to be a checkpoint file on disk")
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
print(f" | {thing} appears to be a diffusers file on disk")
print(f" | {thing} appears to be a diffusers file on disk")
model_name = self.import_diffuser_model(
thing,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
@ -692,13 +693,13 @@ class ModelManager(object):
elif Path(thing).is_dir():
if (Path(thing) / "model_index.json").exists():
print(f" | {thing} appears to be a diffusers model.")
print(f" | {thing} appears to be a diffusers model.")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
else:
print(
f" |{thing} appears to be a directory. Will scan for models to import"
f" |{thing} appears to be a directory. Will scan for models to import"
)
for m in list(Path(thing).rglob("*.ckpt")) + list(
Path(thing).rglob("*.safetensors")
@ -710,7 +711,7 @@ class ModelManager(object):
return model_name
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
@ -727,32 +728,33 @@ class ModelManager(object):
return
if model_path.stem in self.config: # already imported
print(" | Already imported. Skipping")
print(" | Already imported. Skipping")
return model_path.stem
# another round of heuristics to guess the correct config file.
checkpoint = (
torch.load(model_path)
if model_path.suffix == ".ckpt"
else safetensors.torch.load_file(model_path)
)
checkpoint = None
if model_path.suffix.endswith((".ckpt",".pt")):
self.scan_model(model_path,model_path)
checkpoint = torch.load(model_path)
else:
checkpoint = safetensors.torch.load_file(model_path)
# additional probing needed if no config file provided
if model_config_file is None:
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
print(" | SD-v1 model detected")
print(" | SD-v1 model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
print(" | SD-v1 inpainting model detected")
print(" | SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V2_v:
print(
" | SD-v2-v model detected; model will be converted to diffusers format"
" | SD-v2-v model detected; model will be converted to diffusers format"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
@ -760,7 +762,7 @@ class ModelManager(object):
convert = True
elif model_type == SDLegacyType.V2_e:
print(
" | SD-v2-e model detected; model will be converted to diffusers format"
" | SD-v2-e model detected; model will be converted to diffusers format"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
@ -788,6 +790,7 @@ class ModelManager(object):
model_description=description,
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
scan_needed=False,
)
return model_name
@ -800,6 +803,7 @@ class ModelManager(object):
vae=None,
original_config_file: Path = None,
commit_to_conf: Path = None,
scan_needed: bool=True,
) -> str:
"""
Convert a legacy ckpt weights file to diffuser model and import
@ -834,11 +838,12 @@ class ModelManager(object):
extract_ema=True,
original_config_file=original_config_file,
vae=vae_model,
scan_needed=scan_needed,
)
print(
f" | Success. Optimized model is now located at {str(diffusers_path)}"
f" | Success. Optimized model is now located at {str(diffusers_path)}"
)
print(f" | Writing new config file entry for {model_name}")
print(f" | Writing new config file entry for {model_name}")
new_config = dict(
path=str(diffusers_path),
description=model_description,
@ -849,7 +854,7 @@ class ModelManager(object):
self.add_model(model_name, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
print(">> Conversion succeeded")
print(" | Conversion succeeded")
except Exception as e:
print(f"** Conversion failed: {str(e)}")
print(
@ -1105,7 +1110,7 @@ class ModelManager(object):
with open(hashpath) as f:
hash = f.read()
return hash
print(" | Calculating sha256 hash of model files")
print(" | Calculating sha256 hash of model files")
tic = time.time()
sha = hashlib.sha256()
count = 0
@ -1117,7 +1122,7 @@ class ModelManager(object):
sha.update(chunk)
hash = sha.hexdigest()
toc = time.time()
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
with open(hashpath, "w") as f:
f.write(hash)
return hash
@ -1162,12 +1167,12 @@ class ModelManager(object):
local_files_only=not Globals.internet_available,
)
print(f" | Loading diffusers VAE from {name_or_path}")
print(f" | Loading diffusers VAE from {name_or_path}")
if using_fp16:
vae_args.update(torch_dtype=torch.float16)
fp_args_list = [{"revision": "fp16"}, {}]
else:
print(" | Using more accurate float32 precision")
print(" | Using more accurate float32 precision")
fp_args_list = [{}]
vae = None
@ -1208,7 +1213,7 @@ class ModelManager(object):
hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete)
print(
f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
)
strategy.execute()