mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for loading VAE autoencoders
To add a VAE autoencoder to an existing model: 1. Download the appropriate autoencoder and put it into models/ldm/stable-diffusion Note that you MUST use a VAE that was written for the original CompViz Stable Diffusion codebase. For v1.4, that would be the file named vae-ft-mse-840000-ema-pruned.ckpt that you can download from https://huggingface.co/stabilityai/sd-vae-ft-mse-original 2. Edit config/models.yaml to contain the following stanza, modifying `weights` and `vae` as required to match the weights and vae model file names. There is no requirement to rename the VAE file. ~~~ stable-diffusion-1.4: weights: models/ldm/stable-diffusion-v1/sd-v1-4.ckpt description: Stable Diffusion v1.4 config: configs/stable-diffusion/v1-inference.yaml vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt width: 512 height: 512 ~~~ 3. Alternatively from within the `invoke.py` CLI, you may use the command `!editmodel stable-diffusion-1.4` to bring up a simple editor that will allow you to add the path to the VAE. 4. If you are just installing InvokeAI for the first time, you can also use `!import_model models/ldm/stable-diffusion/sd-v1.4.ckpt` instead to create the configuration from scratch. 5. That's it!
This commit is contained in:
parent
7308022bc7
commit
51fdbe22d2
@ -9,6 +9,7 @@
|
|||||||
stable-diffusion-1.4:
|
stable-diffusion-1.4:
|
||||||
config: configs/stable-diffusion/v1-inference.yaml
|
config: configs/stable-diffusion/v1-inference.yaml
|
||||||
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
|
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
description: Stable Diffusion inference model version 1.4
|
description: Stable Diffusion inference model version 1.4
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
|
@ -13,6 +13,7 @@ import gc
|
|||||||
import hashlib
|
import hashlib
|
||||||
import psutil
|
import psutil
|
||||||
import transformers
|
import transformers
|
||||||
|
import os
|
||||||
from sys import getrefcount
|
from sys import getrefcount
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.errors import ConfigAttributeError
|
from omegaconf.errors import ConfigAttributeError
|
||||||
@ -193,6 +194,7 @@ class ModelCache(object):
|
|||||||
mconfig = self.config[model_name]
|
mconfig = self.config[model_name]
|
||||||
config = mconfig.config
|
config = mconfig.config
|
||||||
weights = mconfig.weights
|
weights = mconfig.weights
|
||||||
|
vae = mconfig.get('vae',None)
|
||||||
width = mconfig.width
|
width = mconfig.width
|
||||||
height = mconfig.height
|
height = mconfig.height
|
||||||
|
|
||||||
@ -222,9 +224,17 @@ class ModelCache(object):
|
|||||||
else:
|
else:
|
||||||
print(' | Using more accurate float32 precision')
|
print(' | Using more accurate float32 precision')
|
||||||
|
|
||||||
|
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
|
||||||
|
if vae and os.path.exists(vae):
|
||||||
|
print(f' | Loading VAE weights from: {vae}')
|
||||||
|
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||||
|
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||||
|
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||||
|
|
||||||
model.to(self.device)
|
model.to(self.device)
|
||||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||||
model.cond_stage_model.device = self.device
|
model.cond_stage_model.device = self.device
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
|
@ -493,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
|||||||
new_config['config'] = input('Configuration file for this model: ')
|
new_config['config'] = input('Configuration file for this model: ')
|
||||||
done = os.path.exists(new_config['config'])
|
done = os.path.exists(new_config['config'])
|
||||||
|
|
||||||
|
done = False
|
||||||
|
completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
|
||||||
|
while not done:
|
||||||
|
vae = input('VAE autoencoder file for this model [None]: ')
|
||||||
|
if os.path.exists(vae):
|
||||||
|
new_config['vae'] = vae
|
||||||
|
done = True
|
||||||
|
else:
|
||||||
|
done = len(vae)==0
|
||||||
|
|
||||||
completer.complete_extensions(None)
|
completer.complete_extensions(None)
|
||||||
|
|
||||||
for field in ('width','height'):
|
for field in ('width','height'):
|
||||||
@ -537,8 +547,8 @@ def edit_config(model_name:str, gen, opt, completer):
|
|||||||
|
|
||||||
conf = config[model_name]
|
conf = config[model_name]
|
||||||
new_config = {}
|
new_config = {}
|
||||||
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
|
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
|
||||||
for field in ('description', 'weights', 'config', 'width','height'):
|
for field in ('description', 'weights', 'vae', 'config', 'width','height'):
|
||||||
completer.linebuffer = str(conf[field]) if field in conf else ''
|
completer.linebuffer = str(conf[field]) if field in conf else ''
|
||||||
new_value = input(f'{field}: ')
|
new_value = input(f'{field}: ')
|
||||||
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
||||||
|
Loading…
x
Reference in New Issue
Block a user