mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
load safetensors vaes
This commit is contained in:
parent
fc2098834d
commit
0b5c0c374e
@ -359,7 +359,9 @@ class ModelManager(object):
|
|||||||
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
||||||
if os.path.exists(vae):
|
if os.path.exists(vae):
|
||||||
print(f' | Loading VAE weights from: {vae}')
|
print(f' | Loading VAE weights from: {vae}')
|
||||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
vae_ckpt = safetensors.torch.load_file(vae) \
|
||||||
|
if vae.endswith('.safetensors') \
|
||||||
|
else torch.load(vae, map_location="cpu")
|
||||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user