From 51fdbe22d2b74c0080a4f9eec8ce00d50023cc97 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 22 Oct 2022 13:29:45 -0400 Subject: [PATCH] add support for loading VAE autoencoders To add a VAE autoencoder to an existing model: 1. Download the appropriate autoencoder and put it into models/ldm/stable-diffusion-v1. Note that you MUST use a VAE that was written for the original CompVis Stable Diffusion codebase. For v1.4, that would be the file named vae-ft-mse-840000-ema-pruned.ckpt that you can download from https://huggingface.co/stabilityai/sd-vae-ft-mse-original 2. Edit configs/models.yaml to contain the following stanza, modifying `weights` and `vae` as required to match the weights and VAE model file names. There is no requirement to rename the VAE file. ~~~ stable-diffusion-1.4: weights: models/ldm/stable-diffusion-v1/sd-v1-4.ckpt description: Stable Diffusion v1.4 config: configs/stable-diffusion/v1-inference.yaml vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt width: 512 height: 512 ~~~ 3. Alternatively, from within the `invoke.py` CLI, you may use the command `!editmodel stable-diffusion-1.4` to bring up a simple editor that will allow you to add the path to the VAE. 4. If you are just installing InvokeAI for the first time, you can also use `!import_model models/ldm/stable-diffusion-v1/sd-v1-4.ckpt` instead to create the configuration from scratch. 5. That's it! 
--- configs/models.yaml | 1 + ldm/invoke/model_cache.py | 10 ++++++++++ scripts/invoke.py | 14 ++++++++++++-- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/configs/models.yaml b/configs/models.yaml index f3fde45d8f..67183bdd1f 100644 --- a/configs/models.yaml +++ b/configs/models.yaml @@ -9,6 +9,7 @@ stable-diffusion-1.4: config: configs/stable-diffusion/v1-inference.yaml weights: models/ldm/stable-diffusion-v1/model.ckpt + vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt description: Stable Diffusion inference model version 1.4 width: 512 height: 512 diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index 5e9e53cfb7..f580dfba25 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -13,6 +13,7 @@ import gc import hashlib import psutil import transformers +import os from sys import getrefcount from omegaconf import OmegaConf from omegaconf.errors import ConfigAttributeError @@ -193,6 +194,7 @@ class ModelCache(object): mconfig = self.config[model_name] config = mconfig.config weights = mconfig.weights + vae = mconfig.get('vae',None) width = mconfig.width height = mconfig.height @@ -222,9 +224,17 @@ class ModelCache(object): else: print(' | Using more accurate float32 precision') + # look and load a matching vae file. 
Code borrowed from AUTOMATIC1111 modules/sd_models.py + if vae and os.path.exists(vae): + print(f' | Loading VAE weights from: {vae}') + vae_ckpt = torch.load(vae, map_location="cpu") + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + model.first_stage_model.load_state_dict(vae_dict, strict=False) + model.to(self.device) # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here model.cond_stage_model.device = self.device + model.eval() for m in model.modules(): diff --git a/scripts/invoke.py b/scripts/invoke.py index b7af4d6469..f4d4f3c4c0 100644 --- a/scripts/invoke.py +++ b/scripts/invoke.py @@ -493,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer): new_config['config'] = input('Configuration file for this model: ') done = os.path.exists(new_config['config']) + done = False + completer.complete_extensions(('.vae.pt','.vae','.ckpt')) + while not done: + vae = input('VAE autoencoder file for this model [None]: ') + if os.path.exists(vae): + new_config['vae'] = vae + done = True + else: + done = len(vae)==0 + completer.complete_extensions(None) for field in ('width','height'): @@ -537,8 +547,8 @@ def edit_config(model_name:str, gen, opt, completer): conf = config[model_name] new_config = {} - completer.complete_extensions(('.yaml','.yml','.ckpt','.vae')) - for field in ('description', 'weights', 'config', 'width','height'): + completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt')) + for field in ('description', 'weights', 'vae', 'config', 'width','height'): completer.linebuffer = str(conf[field]) if field in conf else '' new_value = input(f'{field}: ') new_config[field] = int(new_value) if field in ('width','height') else new_value