mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix vae safetensor loading
This commit is contained in:
parent
4a9e93463d
commit
1c62ae461e
@ -359,10 +359,14 @@ class ModelManager(object):
|
|||||||
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
||||||
if os.path.exists(vae):
|
if os.path.exists(vae):
|
||||||
print(f' | Loading VAE weights from: {vae}')
|
print(f' | Loading VAE weights from: {vae}')
|
||||||
vae_ckpt = safetensors.torch.load_file(vae) \
|
vae_ckpt = None
|
||||||
if vae.endswith('.safetensors') \
|
vae_dict = None
|
||||||
else torch.load(vae, map_location="cpu")
|
if vae.endswith('.safetensors'):
|
||||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
vae_ckpt = safetensors.torch.load_file(vae)
|
||||||
|
vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
|
||||||
|
else:
|
||||||
|
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||||
|
vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"}
|
||||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||||
else:
|
else:
|
||||||
print(f' | VAE file {vae} not found. Skipping.')
|
print(f' | VAE file {vae} not found. Skipping.')
|
||||||
|
Loading…
Reference in New Issue
Block a user