add support for loading VAE autoencoders

To add a VAE autoencoder to an existing model:

1. Download the appropriate autoencoder and put it into
   models/ldm/stable-diffusion

   Note that you MUST use a VAE that was written for the
   original CompViz Stable Diffusion codebase. For v1.4,
   that would be the file named vae-ft-mse-840000-ema-pruned.ckpt
   that you can download from https://huggingface.co/stabilityai/sd-vae-ft-mse-original

2. Edit config/models.yaml to contain the following stanza, modifying `weights`
   and `vae` as required to match the weights and vae model file names. There is
   no requirement to rename the VAE file.

~~~
stable-diffusion-1.4:
  weights: models/ldm/stable-diffusion-v1/sd-v1-4.ckpt
  description: Stable Diffusion v1.4
  config: configs/stable-diffusion/v1-inference.yaml
  vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
  width: 512
  height: 512
~~~

3. Alternatively from within the `invoke.py` CLI, you may use the command
   `!editmodel stable-diffusion-1.4` to bring up a simple editor that will
   allow you to add the path to the VAE.

4. If you are just installing InvokeAI for the first time, you can also
   use `!import_model models/ldm/stable-diffusion/sd-v1.4.ckpt` instead
   to create the configuration from scratch.

5. That's it!
This commit is contained in:
Lincoln Stein 2022-10-22 13:29:45 -04:00
parent 7308022bc7
commit 51fdbe22d2
3 changed files with 23 additions and 2 deletions

View File

@ -9,6 +9,7 @@
stable-diffusion-1.4: stable-diffusion-1.4:
config: configs/stable-diffusion/v1-inference.yaml config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/model.ckpt weights: models/ldm/stable-diffusion-v1/model.ckpt
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
description: Stable Diffusion inference model version 1.4 description: Stable Diffusion inference model version 1.4
width: 512 width: 512
height: 512 height: 512

View File

@ -13,6 +13,7 @@ import gc
import hashlib import hashlib
import psutil import psutil
import transformers import transformers
import os
from sys import getrefcount from sys import getrefcount
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError from omegaconf.errors import ConfigAttributeError
@ -193,6 +194,7 @@ class ModelCache(object):
mconfig = self.config[model_name] mconfig = self.config[model_name]
config = mconfig.config config = mconfig.config
weights = mconfig.weights weights = mconfig.weights
vae = mconfig.get('vae',None)
width = mconfig.width width = mconfig.width
height = mconfig.height height = mconfig.height
@ -222,9 +224,17 @@ class ModelCache(object):
else: else:
print(' | Using more accurate float32 precision') print(' | Using more accurate float32 precision')
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
if vae and os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}')
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
model.to(self.device) model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
model.cond_stage_model.device = self.device model.cond_stage_model.device = self.device
model.eval() model.eval()
for m in model.modules(): for m in model.modules():

View File

@ -493,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
new_config['config'] = input('Configuration file for this model: ') new_config['config'] = input('Configuration file for this model: ')
done = os.path.exists(new_config['config']) done = os.path.exists(new_config['config'])
done = False
completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
while not done:
vae = input('VAE autoencoder file for this model [None]: ')
if os.path.exists(vae):
new_config['vae'] = vae
done = True
else:
done = len(vae)==0
completer.complete_extensions(None) completer.complete_extensions(None)
for field in ('width','height'): for field in ('width','height'):
@ -537,8 +547,8 @@ def edit_config(model_name:str, gen, opt, completer):
conf = config[model_name] conf = config[model_name]
new_config = {} new_config = {}
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae')) completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
for field in ('description', 'weights', 'config', 'width','height'): for field in ('description', 'weights', 'vae', 'config', 'width','height'):
completer.linebuffer = str(conf[field]) if field in conf else '' completer.linebuffer = str(conf[field]) if field in conf else ''
new_value = input(f'{field}: ') new_value = input(f'{field}: ')
new_config[field] = int(new_value) if field in ('width','height') else new_value new_config[field] = int(new_value) if field in ('width','height') else new_value