diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 793ba024cf..7c24eb8da8 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -378,16 +378,26 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( - flat_ema_key - ) + flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:]) + if flat_ema_key in checkpoint: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key + ) + elif flat_ema_key_alt in checkpoint: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key_alt + ) + else: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + key + ) else: print( " | Extracting only the non-EMA weights (usually better for fine-tuning)" ) for key in keys: - if key.startswith(unet_key): + if key.startswith("model.diffusion_model") and key in checkpoint: unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {}