diff --git a/configs/models.yaml b/configs/models.yaml
index f3fde45d8f..67183bdd1f 100644
--- a/configs/models.yaml
+++ b/configs/models.yaml
@@ -9,6 +9,7 @@ stable-diffusion-1.4:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/model.ckpt
+    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
     description: Stable Diffusion inference model version 1.4
     width: 512
     height: 512
 
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 5e9e53cfb7..f580dfba25 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import os
 from sys import getrefcount
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
@@ -193,6 +194,7 @@ class ModelCache(object):
         mconfig = self.config[model_name]
         config = mconfig.config
         weights = mconfig.weights
+        vae = mconfig.get('vae',None)
         width = mconfig.width
         height = mconfig.height
 
@@ -222,9 +224,17 @@ class ModelCache(object):
         else:
             print(' | Using more accurate float32 precision')
 
+        # look for and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
+        if vae and os.path.exists(vae):
+            print(f' | Loading VAE weights from: {vae}')
+            vae_ckpt = torch.load(vae, map_location="cpu")
+            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+            model.first_stage_model.load_state_dict(vae_dict, strict=False)
+
         model.to(self.device)
         # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
         model.cond_stage_model.device = self.device
+
         model.eval()
 
         for m in model.modules():
diff --git a/scripts/invoke.py b/scripts/invoke.py
index b7af4d6469..f4d4f3c4c0 100644
--- a/scripts/invoke.py
+++ b/scripts/invoke.py
@@ -493,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
         new_config['config'] = input('Configuration file for this model: ')
         done = os.path.exists(new_config['config'])
 
+    done = False
+    completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
+    while not done:
+        vae = input('VAE autoencoder file for this model [None]: ')
+        if os.path.exists(vae):
+            new_config['vae'] = vae
+            done = True
+        else:
+            done = len(vae)==0
+
     completer.complete_extensions(None)
 
     for field in ('width','height'):
@@ -537,8 +547,8 @@ def edit_config(model_name:str, gen, opt, completer):
 
     conf = config[model_name]
     new_config = {}
-    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
-    for field in ('description', 'weights', 'config', 'width','height'):
+    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
+    for field in ('description', 'weights', 'vae', 'config', 'width','height'):
         completer.linebuffer = str(conf[field]) if field in conf else ''
         new_value = input(f'{field}: ')
         new_config[field] = int(new_value) if field in ('width','height') else new_value
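
For context, a minimal standalone sketch of what the VAE-override block added to ldm/invoke/model_cache.py does. The helper name load_vae_override is hypothetical and not part of the patch; it assumes, as in the LDM codebase, that model.first_stage_model is the autoencoder:

import os
import torch

def load_vae_override(model, vae_path):
    """Hypothetical helper: swap the weights of model.first_stage_model (the VAE)
    for those in a standalone checkpoint, mirroring the patched model_cache.py."""
    if not (vae_path and os.path.exists(vae_path)):
        return model
    vae_ckpt = torch.load(vae_path, map_location="cpu")
    # Drop training-only entries (keys starting with "loss"), which some
    # unpruned VAE checkpoints carry alongside the encoder/decoder weights.
    vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if not k.startswith("loss")}
    # strict=False tolerates key mismatches between the checkpoint and the
    # inference-time autoencoder.
    model.first_stage_model.load_state_dict(vae_dict, strict=False)
    return model

In the patch this step runs once per model load, just before model.to(self.device), so a vae: entry in models.yaml overrides the autoencoder weights for every generation with that model.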