load safetensors vaes

This commit is contained in:
Lincoln Stein 2023-01-17 22:51:57 -05:00
parent fc2098834d
commit 0b5c0c374e

View File

@ -359,7 +359,9 @@ class ModelManager(object):
vae = os.path.normpath(os.path.join(Globals.root,vae)) vae = os.path.normpath(os.path.join(Globals.root,vae))
if os.path.exists(vae): if os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}') print(f' | Loading VAE weights from: {vae}')
vae_ckpt = torch.load(vae, map_location="cpu") vae_ckpt = safetensors.torch.load_file(vae) \
if vae.endswith('.safetensors') \
else torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False) model.first_stage_model.load_state_dict(vae_dict, strict=False)
else: else: