handle VAEs that do not have a "state_dict" key

This commit is contained in:
Lincoln Stein 2023-03-23 15:11:29 -04:00
parent 4e0b5d85ba
commit a97107bd90

View File

@ -1031,9 +1031,10 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
vae_ckpt = load_file(vae_path)
else:
vae_ckpt = torch.load(vae_path, map_location="cpu")
for vae_key in vae_ckpt['state_dict']:
state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt
for vae_key in state_dict:
new_key = f'first_stage_model.{vae_key}'
checkpoint[new_key] = vae_ckpt['state_dict'][vae_key]
checkpoint[new_key] = state_dict[vae_key]
def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path: str,