Add keys when non EMA

This commit is contained in:
Fabio 'MrWHO' Torchetti 2023-03-12 16:22:22 -05:00
parent c367b21c71
commit 5c5106c14a

View File

@ -396,6 +396,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
)
for key in keys:
if key.startswith("model.diffusion_model") and key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[