fix vae safetensor loading

This commit is contained in:
Lincoln Stein 2023-01-18 12:15:57 -05:00
parent 4a9e93463d
commit 1c62ae461e

View File

@ -359,10 +359,14 @@ class ModelManager(object):
vae = os.path.normpath(os.path.join(Globals.root,vae))
if os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}')
vae_ckpt = safetensors.torch.load_file(vae) \
if vae.endswith('.safetensors') \
else torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
vae_ckpt = None
vae_dict = None
if vae.endswith('.safetensors'):
vae_ckpt = safetensors.torch.load_file(vae)
vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
else:
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
else:
print(f' | VAE file {vae} not found. Skipping.')