diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index f7b78dba1c..fa82ce31d6 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -772,11 +772,11 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer): original_config_file = Path(model_info["config"]) model_name = model_name_or_path model_description = model_info["description"] - vae = model_info["vae"] + vae = model_info.get("vae") else: print(f"** {model_name_or_path} is not a legacy .ckpt weights file") return - if vae_repo := ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem): + if vae and (vae_repo := ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem)): vae_repo = dict(repo_id=vae_repo) else: vae_repo = None diff --git a/ldm/invoke/ckpt_to_diffuser.py b/ldm/invoke/ckpt_to_diffuser.py index 44f48a77cd..0d82e1c413 100644 --- a/ldm/invoke/ckpt_to_diffuser.py +++ b/ldm/invoke/ckpt_to_diffuser.py @@ -1264,10 +1264,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt( cache_dir=cache_dir, ) pipe = pipeline_class( - vae=vae, - text_encoder=text_model, + vae=vae.to(precision), + text_encoder=text_model.to(precision), tokenizer=tokenizer, - unet=unet, + unet=unet.to(precision), scheduler=scheduler, safety_checker=None, feature_extractor=None, diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index ff88ae8c1c..1af1872289 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -172,9 +172,9 @@ class ModelManager(object): """ # if we are converting legacy files automatically, then # there are no legacy ckpts! - if Globals.ckpt_convert: - return False info = self.model_info(model_name) + if Globals.ckpt_convert or info.format=='diffusers' or self.is_v2_config(info.config): + return False if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")): return True return False @@ -544,6 +544,8 @@ class ModelManager(object): return pipeline, width, height, model_hash def is_v2_config(self, config: Path) -> bool: + if not os.path.isabs(config): + config = os.path.join(Globals.root, config) try: mconfig = OmegaConf.load(config) return (