diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index d6d61ee71d..b371fc96e8 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -422,11 +422,8 @@ def convert_ldm_unet_checkpoint( ) for key in keys: if key.startswith("model.diffusion_model"): - for delimiter in ['','.']: - flat_ema_key = "model_ema." + delimiter.join(key.split(".")[1:]) - if checkpoint.get(flat_ema_key) is not None: - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - break + flat_ema_key = "model_ema." + "".join(key.split(".")[2:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) else: if sum(k.startswith("model_ema") for k in keys) > 100: logger.warning( @@ -1114,7 +1111,6 @@ def convert_controlnet_checkpoint( return controlnet.to(precision) -# TO DO - PASS PRECISION def download_from_original_stable_diffusion_ckpt( checkpoint_path: str, model_version: BaseModelType, @@ -1288,6 +1284,9 @@ def download_from_original_stable_diffusion_ckpt( original_config_file = BytesIO(requests.get(config_url).content) original_config = OmegaConf.load(original_config_file) + if original_config['model']['params'].get('use_ema') is not None: + extract_ema = original_config['model']['params']['use_ema'] + if ( model_version == BaseModelType.StableDiffusion2 and original_config["model"]["params"].get("parameterization") == "v"