Allow loading all types of dreambooth models - Fix issue #2932 (#2933)

Allows to load models with EMA using `model_ema.diffusion_model.xxxx` or
`model_ema.xxxx` weights.

Fixes #2932
This commit is contained in:
Lincoln Stein 2023-03-23 23:40:04 -04:00 committed by GitHub
commit 873597cb84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -378,16 +378,26 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:])
if flat_ema_key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
elif flat_ema_key_alt in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key_alt
)
else:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
key
)
else:
print(
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
)
for key in keys:
if key.startswith(unet_key):
if key.startswith("model.diffusion_model") and key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}