This commit is contained in:
Fabio 'MrWHO' Torchetti 2023-03-12 15:40:33 -05:00
parent 2eef6df66a
commit c367b21c71

View File

@ -378,18 +378,24 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:])
if flat_ema_key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
elif flat_ema_key_alt in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key_alt
)
else:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
key
)
else:
print(
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[