documentation and usability fixes

This commit is contained in:
Lincoln Stein
2022-10-29 10:37:38 -04:00
parent 3caa95ced9
commit 13f26a99b8
7 changed files with 290 additions and 44 deletions

View File

@ -227,11 +227,14 @@ class ModelCache(object):
print(' | Using more accurate float32 precision')
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
if vae and os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}')
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
if vae:
if os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}')
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
else:
print(f' | VAE file {vae} not found. Skipping.')
model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here