diff --git a/ldm/invoke/ckpt_to_diffuser.py b/ldm/invoke/ckpt_to_diffuser.py
index 8db2ec8366..1d41fa5bd1 100644
--- a/ldm/invoke/ckpt_to_diffuser.py
+++ b/ldm/invoke/ckpt_to_diffuser.py
@@ -327,10 +327,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     unet_key = "model.diffusion_model."
     # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
     if sum(k.startswith("model_ema") for k in keys) > 100:
-        print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
+        print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
         if extract_ema:
             print(
-                ' | Extracting EMA weights (usually better for inference)'
+                ' | Extracting EMA weights (usually better for inference)'
             )
             for key in keys:
                 if key.startswith("model.diffusion_model"):
@@ -338,7 +338,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
                     unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
         else:
             print(
-                ' | Extracting only the non-EMA weights (usually better for fine-tuning)'
+                ' | Extracting only the non-EMA weights (usually better for fine-tuning)'
             )
 
     for key in keys:
@@ -809,6 +809,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     vae:AutoencoderKL=None,
     precision:torch.dtype=torch.float32,
     return_generator_pipeline:bool=False,
+    scan_needed:bool=True,
 )->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
     '''
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -844,7 +845,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     dlogging.set_verbosity_error()
 
     if Path(checkpoint_path).suffix == '.ckpt':
-        ModelManager.scan_model(checkpoint_path,checkpoint_path)
+        if scan_needed:
+            ModelManager.scan_model(checkpoint_path,checkpoint_path)
         checkpoint = torch.load(checkpoint_path)
     else:
         checkpoint = load_file(checkpoint_path)
@@ -855,7 +857,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
-        print(" | global_step key not found in model")
+        print(" | global_step key not found in model")
         global_step = None
 
     # sometimes there is a state_dict key and sometimes not
@@ -957,14 +959,14 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
 
     # Convert the VAE model, or use the one passed
     if not vae:
-        print(' | Using checkpoint model\'s original VAE')
+        print(' | Using checkpoint model\'s original VAE')
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
     else:
-        print(' | Using VAE specified in config')
+        print(' | Using VAE specified in config')
 
     # Convert the text model.
     model_type = pipeline_type
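For context, the new `scan_needed` flag exists so that callers which have already vetted a checkpoint (as `ModelManager` now does while probing an import) can skip a redundant malware scan here. Below is a minimal standalone sketch of the gating pattern, assuming picklescan's `scan_file_path` as the scanner; the helper name and the error handling are illustrative, not part of the diff.

    from pathlib import Path

    import torch
    from picklescan.scanner import scan_file_path
    from safetensors.torch import load_file

    def read_checkpoint(checkpoint_path: str, scan_needed: bool = True) -> dict:
        # Illustrative helper: pickle-based .ckpt files can execute code when
        # unpickled, so they are scanned first unless the caller already did so;
        # .safetensors files carry no executable payload and are read directly.
        if Path(checkpoint_path).suffix == ".ckpt":
            if scan_needed and scan_file_path(checkpoint_path).infected_files != 0:
                raise RuntimeError(f"{checkpoint_path} is potentially infected by malware")
            return torch.load(checkpoint_path, map_location="cpu")
        return load_file(checkpoint_path)
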
diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py
index 8bed8a1be2..214ef022bb 100644
--- a/ldm/invoke/model_manager.py
+++ b/ldm/invoke/model_manager.py
@@ -282,13 +282,13 @@ class ModelManager(object):
             self.stack.remove(model_name)
         if delete_files:
             if weights:
-                print(f"** deleting file {weights}")
+                print(f"** Deleting file {weights}")
                 Path(weights).unlink(missing_ok=True)
             elif path:
-                print(f"** deleting directory {path}")
+                print(f"** Deleting directory {path}")
                 rmtree(path, ignore_errors=True)
             elif repo_id:
-                print(f"** deleting the cached model directory for {repo_id}")
+                print(f"** Deleting the cached model directory for {repo_id}")
                 self._delete_model_from_cache(repo_id)
 
     def add_model(
@@ -420,11 +420,6 @@ class ModelManager(object):
                 "NOHASH",
             )
 
-        # scan model
-        self.scan_model(model_name, weights)
-
-        print(f">> Loading {model_name} from {weights}")
-
         # for usage statistics
         if self._has_cuda():
             torch.cuda.reset_peak_memory_stats()
@@ -440,6 +435,7 @@ class ModelManager(object):
 
         sd = None
         if weights.endswith(".ckpt"):
+            self.scan_model(model_name, weights)
             sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
         else:
             sd = safetensors.torch.load(weight_bytes)
@@ -466,18 +462,12 @@ class ModelManager(object):
             vae = os.path.normpath(os.path.join(Globals.root, vae))
             if os.path.exists(vae):
                 print(f" | Loading VAE weights from: {vae}")
-                vae_ckpt = None
-                vae_dict = None
-                if vae.endswith(".safetensors"):
-                    vae_ckpt = safetensors.torch.load_file(vae)
-                    vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
-                else:
+                if vae.endswith((".ckpt",".pt")):
+                    self.scan_model(vae,vae)
                     vae_ckpt = torch.load(vae, map_location="cpu")
-                    vae_dict = {
-                        k: v
-                        for k, v in vae_ckpt["state_dict"].items()
-                        if k[0:4] != "loss"
-                    }
+                else:
+                    vae_ckpt = safetensors.torch.load_file(vae)
+                vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
                 model.first_stage_model.load_state_dict(vae_dict, strict=False)
             else:
                 print(f" | VAE file {vae} not found. Skipping.")
@@ -499,9 +489,9 @@ class ModelManager(object):
         print(f">> Loading diffusers model from {name_or_path}")
 
         if using_fp16:
-            print(" | Using faster float16 precision")
+            print(" | Using faster float16 precision")
         else:
-            print(" | Using more accurate float32 precision")
+            print(" | Using more accurate float32 precision")
 
         # TODO: scan weights maybe?
         pipeline_args: dict[str, Any] = dict(
@@ -553,7 +543,7 @@ class ModelManager(object):
         width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
         height = width
 
-        print(f" | Default image dimensions = {width} x {height}")
+        print(f" | Default image dimensions = {width} x {height}")
 
         return pipeline, width, height, model_hash
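As an aside, the default-dimension message above is derived from the pipeline configuration rather than hard-coded. With typical Stable Diffusion 1.x values (assumed here for illustration, not taken from this diff) the arithmetic works out to 512 x 512:

    # Sketch of the calculation behind the "Default image dimensions" message.
    # The concrete numbers are assumptions for SD 1.x checkpoints: a UNet
    # sample_size of 64 latent pixels and a VAE scale factor of 8.
    sample_size = 64        # pipeline.unet.config.sample_size
    vae_scale_factor = 8    # pipeline.vae_scale_factor
    width = sample_size * vae_scale_factor
    height = width
    print(f" | Default image dimensions = {width} x {height}")  # 512 x 512
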
""" # scan model - print(f">> Scanning Model: {model_name}") + print(f" | Scanning Model: {model_name}") scan_result = scan_file_path(checkpoint) if scan_result.infected_files != 0: if scan_result.infected_files == 1: @@ -623,7 +613,7 @@ class ModelManager(object): print("### Exiting InvokeAI") sys.exit() else: - print(">> Model scanned ok") + print(" | Model scanned ok") def import_diffuser_model( self, @@ -803,19 +793,20 @@ class ModelManager(object): print(f">> Probing {thing} for import") if thing.startswith(("http:", "https:", "ftp:")): - print(f" | {thing} appears to be a URL") + print(f" | {thing} appears to be a URL") model_path = self._resolve_path( thing, "models/ldm/stable-diffusion-v1" ) # _resolve_path does a download if needed is_temporary = True + elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")): if Path(thing).stem in ["model", "diffusion_pytorch_model"]: print( - f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import" + f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import" ) return else: - print(f" | {thing} appears to be a checkpoint file on disk") + print(f" | {thing} appears to be a checkpoint file on disk") model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1") elif Path(thing).is_dir() and Path(thing, "model_index.json").exists(): @@ -872,11 +863,12 @@ class ModelManager(object): return model_path.stem # another round of heuristics to guess the correct config file. - checkpoint = ( - safetensors.torch.load_file(model_path) - if model_path.suffix == ".safetensors" - else torch.load(model_path) - ) + checkpoint = None + if model_path.suffix.endswith((".ckpt",".pt")): + self.scan_model(model_path,model_path) + checkpoint = torch.load(model_path) + else: + checkpoint = safetensors.torch.load_file(model_path) # additional probing needed if no config file provided if model_config_file is None: model_type = self.probe_model_type(checkpoint) @@ -921,7 +913,7 @@ class ModelManager(object): if model_config_file.name.startswith('v2'): convert = True print( - " | This SD-v2 model will be converted to diffusers format for use" + " | This SD-v2 model will be converted to diffusers format for use" ) if convert: @@ -936,6 +928,7 @@ class ModelManager(object): model_description=description, original_config_file=model_config_file, commit_to_conf=commit_to_conf, + scan_needed=False, ) # in the event that this file was downloaded automatically prior to conversion # we do not keep the original .ckpt/.safetensors around @@ -960,14 +953,15 @@ class ModelManager(object): return model_name def convert_and_import( - self, - ckpt_path: Path, - diffusers_path: Path, - model_name=None, - model_description=None, - vae=None, - original_config_file: Path = None, - commit_to_conf: Path = None, + self, + ckpt_path: Path, + diffusers_path: Path, + model_name=None, + model_description=None, + vae=None, + original_config_file: Path = None, + commit_to_conf: Path = None, + scan_needed: bool=True, ) -> str: """ Convert a legacy ckpt weights file to diffuser model and import @@ -1002,11 +996,12 @@ class ModelManager(object): extract_ema=True, original_config_file=original_config_file, vae=vae_model, + scan_needed=scan_needed, ) print( - f" | Success. Optimized model is now located at {str(diffusers_path)}" + f" | Success. 
@@ -1002,11 +996,12 @@ class ModelManager(object):
                 extract_ema=True,
                 original_config_file=original_config_file,
                 vae=vae_model,
+                scan_needed=scan_needed,
             )
             print(
-                f" | Success. Optimized model is now located at {str(diffusers_path)}"
+                f" | Success. Optimized model is now located at {str(diffusers_path)}"
             )
-            print(f" | Writing new config file entry for {model_name}")
+            print(f" | Writing new config file entry for {model_name}")
             new_config = dict(
                 path=str(diffusers_path),
                 description=model_description,
@@ -1296,7 +1291,7 @@ class ModelManager(object):
                 with open(hashpath) as f:
                     hash = f.read()
                 return hash
-        print(" | Calculating sha256 hash of model files")
+        print(" | Calculating sha256 hash of model files")
         tic = time.time()
         sha = hashlib.sha256()
         count = 0
@@ -1308,7 +1303,7 @@ class ModelManager(object):
                     sha.update(chunk)
         hash = sha.hexdigest()
         toc = time.time()
-        print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
+        print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
         with open(hashpath, "w") as f:
             f.write(hash)
         return hash
@@ -1353,12 +1348,12 @@ class ModelManager(object):
             local_files_only=not Globals.internet_available,
         )
 
-        print(f" | Loading diffusers VAE from {name_or_path}")
+        print(f" | Loading diffusers VAE from {name_or_path}")
         if using_fp16:
             vae_args.update(torch_dtype=torch.float16)
             fp_args_list = [{"revision": "fp16"}, {}]
         else:
-            print(" | Using more accurate float32 precision")
+            print(" | Using more accurate float32 precision")
             fp_args_list = [{}]
 
         vae = None
@@ -1399,7 +1394,7 @@ class ModelManager(object):
                 hashes_to_delete.add(revision.commit_hash)
         strategy = cache_info.delete_revisions(*hashes_to_delete)
         print(
-            f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
+            f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
         )
         strategy.execute()
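Finally, for reference, a rough sketch of the directory hashing whose progress messages are touched above: every file in a diffusers model directory is streamed through a single sha256 digest, and the hex digest is cached so later loads can skip the work. How files are enumerated, the chunk size, and the cache path below are assumptions, not copied from the module.

    import hashlib
    import time
    from pathlib import Path

    def hash_model_dir(model_dir: str, chunk_size: int = 2**20) -> str:
        # Illustrative sketch: hash every regular file under the model directory.
        hashpath = Path(model_dir) / "checksum.sha256"   # assumed cache location
        if hashpath.exists():
            return hashpath.read_text()
        print(" | Calculating sha256 hash of model files")
        tic = time.time()
        sha = hashlib.sha256()
        count = 0
        for file in sorted(p for p in Path(model_dir).rglob("*") if p.is_file()):
            count += 1
            with open(file, "rb") as f:
                for chunk in iter(lambda: f.read(chunk_size), b""):
                    sha.update(chunk)
        digest = sha.hexdigest()
        print(f" | sha256 = {digest} ({count} files hashed in", "%4.2fs)" % (time.time() - tic))
        hashpath.write_text(digest)
        return digest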