prevent double-scanning during convert

- Avoid running the scanner twice on the same file during the probing
  and conversion process (see the sketch below).

- Clean up diagnostic messages.
Lincoln Stein 2023-03-23 14:24:10 -04:00
parent ba89444e36
commit 4a3951681c
2 changed files with 52 additions and 55 deletions
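
The first bullet works by threading a scan_needed flag from the probing code into the conversion code: probing scans the checkpoint once, then tells the converter that no further scan is required. Below is a minimal sketch of that hand-off; the helper names (scan, convert, probe_and_import) and the file name are illustrative placeholders, not InvokeAI's actual API.

# Minimal sketch of the scan_needed hand-off (illustrative names only).
def scan(path: str) -> None:
    # stand-in for ModelManager.scan_model(): run the safety scan exactly once
    print(f" | Scanning Model: {path}")

def convert(path: str, scan_needed: bool = True) -> None:
    # mirrors the new scan_needed parameter: scan only if the caller has not already
    if scan_needed:
        scan(path)
    print(f" | Converting {path} to diffusers format")

def probe_and_import(path: str) -> None:
    scan(path)                         # probing scans the file once...
    convert(path, scan_needed=False)   # ...so conversion skips the redundant rescan

probe_and_import("some-model.ckpt")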

View File

@@ -327,10 +327,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     unet_key = "model.diffusion_model."
     # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
     if sum(k.startswith("model_ema") for k in keys) > 100:
         print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
         if extract_ema:
             print(
                 ' | Extracting EMA weights (usually better for inference)'
             )
             for key in keys:
                 if key.startswith("model.diffusion_model"):
@@ -338,7 +338,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
                     unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
         else:
             print(
                 ' | Extracting only the non-EMA weights (usually better for fine-tuning)'
             )
     for key in keys:
@@ -809,6 +809,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     vae:AutoencoderKL=None,
     precision:torch.dtype=torch.float32,
     return_generator_pipeline:bool=False,
+    scan_needed:bool=True,
 )->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
     '''
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -844,7 +845,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     dlogging.set_verbosity_error()
     if Path(checkpoint_path).suffix == '.ckpt':
-        ModelManager.scan_model(checkpoint_path,checkpoint_path)
+        if scan_needed:
+            ModelManager.scan_model(checkpoint_path,checkpoint_path)
         checkpoint = torch.load(checkpoint_path)
     else:
         checkpoint = load_file(checkpoint_path)
@@ -855,7 +857,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
         print(" | global_step key not found in model")
         global_step = None
     # sometimes there is a state_dict key and sometimes not
@@ -957,14 +959,14 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     # Convert the VAE model, or use the one passed
     if not vae:
         print(' | Using checkpoint model\'s original VAE')
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
     else:
         print(' | Using VAE specified in config')

     # Convert the text model.
     model_type = pipeline_type

View File

@@ -282,13 +282,13 @@ class ModelManager(object):
             self.stack.remove(model_name)
         if delete_files:
             if weights:
-                print(f"** deleting file {weights}")
+                print(f"** Deleting file {weights}")
                 Path(weights).unlink(missing_ok=True)
             elif path:
-                print(f"** deleting directory {path}")
+                print(f"** Deleting directory {path}")
                 rmtree(path, ignore_errors=True)
             elif repo_id:
-                print(f"** deleting the cached model directory for {repo_id}")
+                print(f"** Deleting the cached model directory for {repo_id}")
                 self._delete_model_from_cache(repo_id)

     def add_model(
@@ -420,11 +420,6 @@ class ModelManager(object):
             "NOHASH",
         )
-        # scan model
-        self.scan_model(model_name, weights)
-
-        print(f">> Loading {model_name} from {weights}")
-
         # for usage statistics
         if self._has_cuda():
             torch.cuda.reset_peak_memory_stats()
@@ -440,6 +435,7 @@ class ModelManager(object):
         sd = None
         if weights.endswith(".ckpt"):
+            self.scan_model(model_name, weights)
             sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
         else:
             sd = safetensors.torch.load(weight_bytes)
@@ -466,18 +462,12 @@ class ModelManager(object):
                 vae = os.path.normpath(os.path.join(Globals.root, vae))
                 if os.path.exists(vae):
                     print(f" | Loading VAE weights from: {vae}")
-                    vae_ckpt = None
-                    vae_dict = None
-                    if vae.endswith(".safetensors"):
-                        vae_ckpt = safetensors.torch.load_file(vae)
-                        vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
-                    else:
-                        vae_ckpt = torch.load(vae, map_location="cpu")
-                        vae_dict = {
-                            k: v
-                            for k, v in vae_ckpt["state_dict"].items()
-                            if k[0:4] != "loss"
-                        }
+                    if vae.endswith((".ckpt",".pt")):
+                        self.scan_model(vae,vae)
+                        vae_ckpt = torch.load(vae, map_location="cpu")
+                    else:
+                        vae_ckpt = safetensors.torch.load_file(vae)
+                    vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
                     model.first_stage_model.load_state_dict(vae_dict, strict=False)
                 else:
                     print(f" | VAE file {vae} not found. Skipping.")
@@ -499,9 +489,9 @@ class ModelManager(object):
         print(f">> Loading diffusers model from {name_or_path}")
         if using_fp16:
             print(" | Using faster float16 precision")
         else:
             print(" | Using more accurate float32 precision")

         # TODO: scan weights maybe?
         pipeline_args: dict[str, Any] = dict(
@@ -553,7 +543,7 @@ class ModelManager(object):
         width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
         height = width
         print(f" | Default image dimensions = {width} x {height}")

         return pipeline, width, height, model_hash
@@ -600,7 +590,7 @@ class ModelManager(object):
         and option to exit if an infected file is identified.
         """
         # scan model
-        print(f">> Scanning Model: {model_name}")
+        print(f" | Scanning Model: {model_name}")
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             if scan_result.infected_files == 1:
@@ -623,7 +613,7 @@ class ModelManager(object):
                 print("### Exiting InvokeAI")
                 sys.exit()
         else:
-            print(">> Model scanned ok")
+            print(" | Model scanned ok")

     def import_diffuser_model(
         self,
@@ -803,19 +793,20 @@ class ModelManager(object):
         print(f">> Probing {thing} for import")
         if thing.startswith(("http:", "https:", "ftp:")):
             print(f" | {thing} appears to be a URL")
             model_path = self._resolve_path(
                 thing, "models/ldm/stable-diffusion-v1"
             ) # _resolve_path does a download if needed
+            is_temporary = True
         elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
             if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
                 print(
                     f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
                 )
                 return
             else:
                 print(f" | {thing} appears to be a checkpoint file on disk")
                 model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
         elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
@@ -872,11 +863,12 @@ class ModelManager(object):
             return model_path.stem

         # another round of heuristics to guess the correct config file.
-        checkpoint = (
-            safetensors.torch.load_file(model_path)
-            if model_path.suffix == ".safetensors"
-            else torch.load(model_path)
-        )
+        checkpoint = None
+        if model_path.suffix.endswith((".ckpt",".pt")):
+            self.scan_model(model_path,model_path)
+            checkpoint = torch.load(model_path)
+        else:
+            checkpoint = safetensors.torch.load_file(model_path)

         # additional probing needed if no config file provided
         if model_config_file is None:
             model_type = self.probe_model_type(checkpoint)
@@ -921,7 +913,7 @@ class ModelManager(object):
         if model_config_file.name.startswith('v2'):
             convert = True
             print(
                 " | This SD-v2 model will be converted to diffusers format for use"
             )

         if convert:
@@ -936,6 +928,7 @@ class ModelManager(object):
                 model_description=description,
                 original_config_file=model_config_file,
                 commit_to_conf=commit_to_conf,
+                scan_needed=False,
             )
             # in the event that this file was downloaded automatically prior to conversion
             # we do not keep the original .ckpt/.safetensors around
@@ -960,14 +953,15 @@ class ModelManager(object):
         return model_name

     def convert_and_import(
         self,
         ckpt_path: Path,
         diffusers_path: Path,
         model_name=None,
         model_description=None,
         vae=None,
         original_config_file: Path = None,
         commit_to_conf: Path = None,
+        scan_needed: bool=True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
@@ -1002,11 +996,12 @@ class ModelManager(object):
                 extract_ema=True,
                 original_config_file=original_config_file,
                 vae=vae_model,
+                scan_needed=scan_needed,
             )
             print(
                 f" | Success. Optimized model is now located at {str(diffusers_path)}"
             )
             print(f" | Writing new config file entry for {model_name}")
             new_config = dict(
                 path=str(diffusers_path),
                 description=model_description,
@@ -1296,7 +1291,7 @@ class ModelManager(object):
             with open(hashpath) as f:
                 hash = f.read()
             return hash

         print(" | Calculating sha256 hash of model files")
         tic = time.time()
         sha = hashlib.sha256()
         count = 0
@@ -1308,7 +1303,7 @@ class ModelManager(object):
                 sha.update(chunk)
         hash = sha.hexdigest()
         toc = time.time()
         print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
         with open(hashpath, "w") as f:
             f.write(hash)
         return hash
@@ -1353,12 +1348,12 @@ class ModelManager(object):
             local_files_only=not Globals.internet_available,
         )
         print(f" | Loading diffusers VAE from {name_or_path}")
         if using_fp16:
             vae_args.update(torch_dtype=torch.float16)
             fp_args_list = [{"revision": "fp16"}, {}]
         else:
             print(" | Using more accurate float32 precision")
             fp_args_list = [{}]

         vae = None
@@ -1399,7 +1394,7 @@ class ModelManager(object):
                     hashes_to_delete.add(revision.commit_hash)
         strategy = cache_info.delete_revisions(*hashes_to_delete)
         print(
-            f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
+            f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
         )
         strategy.execute()