mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
diffusers: support loading an alternate VAE
This commit is contained in:
parent
1f86e527aa
commit
3607042c9d
@ -6,14 +6,24 @@
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
diffusers-1.4:
|
||||
description: Diffusers version of Stable Diffusion version 1.4
|
||||
description: 🤗🧨 Stable Diffusion v1.4
|
||||
format: diffusers
|
||||
repo_name: CompVis/stable-diffusion-v1-4
|
||||
default: true
|
||||
diffusers-1.5:
|
||||
description: Diffusers version of Stable Diffusion version 1.5
|
||||
description: 🤗🧨 Stable Diffusion v1.5
|
||||
format: diffusers
|
||||
repo_name: runwayml/stable-diffusion-v1-5
|
||||
default: true
|
||||
diffusers-1.5+mse:
|
||||
description: 🤗🧨 Stable Diffusion v1.5 + MSE-finetuned VAE
|
||||
format: diffusers
|
||||
repo_name: runwayml/stable-diffusion-v1-5
|
||||
vae:
|
||||
repo_name: stabilityai/sd-vae-ft-mse
|
||||
diffusers-inpainting-1.5:
|
||||
description: 🤗🧨 inpainting for Stable Diffusion v1.5
|
||||
format: diffusers
|
||||
repo_name: runwayml/stable-diffusion-inpainting
|
||||
stable-diffusion-1.5:
|
||||
description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
|
||||
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||
|
@ -21,6 +21,7 @@ from typing import Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import RevisionNotFoundError
|
||||
from omegaconf import OmegaConf
|
||||
@ -337,14 +338,20 @@ class ModelCache(object):
|
||||
|
||||
# TODO: scan weights maybe?
|
||||
|
||||
if 'vae' in mconfig:
|
||||
vae = self._load_vae(mconfig['vae'])
|
||||
pipeline_args.update(vae=vae)
|
||||
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
|
||||
if not isinstance(name_or_path, Path):
|
||||
# hub has no explicit API for different data types, but the main Stable Diffusion
|
||||
# releases set a precedent for putting float16 weights in a fp16 branch.
|
||||
try:
|
||||
hf_hub_download(name_or_path, "model_index.json", revision="fp16")
|
||||
except RevisionNotFoundError as e:
|
||||
pass
|
||||
except RevisionNotFoundError:
|
||||
pass # no such branch, assume we should use the default.
|
||||
else:
|
||||
pipeline_args.update(revision="fp16")
|
||||
|
||||
@ -362,7 +369,6 @@ class ModelCache(object):
|
||||
# want to leave it as a separate processing node. It ends up using the same diffusers
|
||||
# code either way, so we can table it for now.
|
||||
safety_checker=None,
|
||||
# TODO: alternate VAE
|
||||
# TODO: local_files_only=True
|
||||
**pipeline_args
|
||||
)
|
||||
@ -535,3 +541,40 @@ class ModelCache(object):
|
||||
with open(hashpath,'w') as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
|
||||
def _load_vae(self, vae_config):
|
||||
vae_args = {}
|
||||
|
||||
if 'repo_name' in vae_config:
|
||||
name_or_path = vae_config['repo_name']
|
||||
elif 'path' in vae_config:
|
||||
name_or_path = Path(vae_config['path'])
|
||||
if not name_or_path.is_absolute():
|
||||
name_or_path = Path(Globals.root, name_or_path).resolve()
|
||||
else:
|
||||
raise ValueError("VAE config must specify either repo_name or path.")
|
||||
|
||||
print(f'>> Loading diffusers VAE from {name_or_path}')
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
|
||||
if not isinstance(name_or_path, Path):
|
||||
try:
|
||||
hf_hub_download(name_or_path, "model_index.json", revision="fp16")
|
||||
except RevisionNotFoundError:
|
||||
pass
|
||||
else:
|
||||
vae_args.update(revision="fp16")
|
||||
|
||||
vae_args.update(torch_dtype=torch.float16)
|
||||
else:
|
||||
# TODO: more accurately, "using the model's default precision."
|
||||
# How do we find out what that is?
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
if 'subfolder' in vae_config:
|
||||
vae_args['subfolder'] = vae_config['subfolder']
|
||||
|
||||
# At some point we might need to be able to use different classes here? But for now I think
|
||||
# all Stable Diffusion VAE are AutoencoderKL.
|
||||
return AutoencoderKL.from_pretrained(name_or_path, **vae_args)
|
||||
|
Loading…
Reference in New Issue
Block a user