mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix issue #2932
This commit is contained in:
parent
2eef6df66a
commit
c367b21c71
@ -378,18 +378,24 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
for key in keys:
|
||||
if key.startswith("model.diffusion_model"):
|
||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||
flat_ema_key
|
||||
)
|
||||
flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:])
|
||||
if flat_ema_key in checkpoint:
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||
flat_ema_key
|
||||
)
|
||||
elif flat_ema_key_alt in checkpoint:
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||
flat_ema_key_alt
|
||||
)
|
||||
else:
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||
key
|
||||
)
|
||||
else:
|
||||
print(
|
||||
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
if key.startswith(unet_key):
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
||||
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[
|
||||
|
Loading…
Reference in New Issue
Block a user