fix mistake in indexing flat_ema_key

This commit is contained in:
Lincoln Stein 2023-07-29 17:20:26 -04:00
parent 3f9105be50
commit 1de783b1ce

View File

@ -422,7 +422,7 @@ def convert_ldm_unet_checkpoint(
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[2:])
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100: