mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
diffusers: support loading an alternate VAE
This commit is contained in:
parent
1f86e527aa
commit
3607042c9d
@ -6,14 +6,24 @@
|
|||||||
# and the width and height of the images it
|
# and the width and height of the images it
|
||||||
# was trained on.
|
# was trained on.
|
||||||
diffusers-1.4:
|
diffusers-1.4:
|
||||||
description: Diffusers version of Stable Diffusion version 1.4
|
description: 🤗🧨 Stable Diffusion v1.4
|
||||||
format: diffusers
|
format: diffusers
|
||||||
repo_name: CompVis/stable-diffusion-v1-4
|
repo_name: CompVis/stable-diffusion-v1-4
|
||||||
default: true
|
|
||||||
diffusers-1.5:
|
diffusers-1.5:
|
||||||
description: Diffusers version of Stable Diffusion version 1.5
|
description: 🤗🧨 Stable Diffusion v1.5
|
||||||
format: diffusers
|
format: diffusers
|
||||||
repo_name: runwayml/stable-diffusion-v1-5
|
repo_name: runwayml/stable-diffusion-v1-5
|
||||||
|
default: true
|
||||||
|
diffusers-1.5+mse:
|
||||||
|
description: 🤗🧨 Stable Diffusion v1.5 + MSE-finetuned VAE
|
||||||
|
format: diffusers
|
||||||
|
repo_name: runwayml/stable-diffusion-v1-5
|
||||||
|
vae:
|
||||||
|
repo_name: stabilityai/sd-vae-ft-mse
|
||||||
|
diffusers-inpainting-1.5:
|
||||||
|
description: 🤗🧨 inpainting for Stable Diffusion v1.5
|
||||||
|
format: diffusers
|
||||||
|
repo_name: runwayml/stable-diffusion-inpainting
|
||||||
stable-diffusion-1.5:
|
stable-diffusion-1.5:
|
||||||
description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
|
description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
|
||||||
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||||
|
@ -21,6 +21,7 @@ from typing import Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
from diffusers import AutoencoderKL
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from huggingface_hub.utils import RevisionNotFoundError
|
from huggingface_hub.utils import RevisionNotFoundError
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -337,14 +338,20 @@ class ModelCache(object):
|
|||||||
|
|
||||||
# TODO: scan weights maybe?
|
# TODO: scan weights maybe?
|
||||||
|
|
||||||
|
if 'vae' in mconfig:
|
||||||
|
vae = self._load_vae(mconfig['vae'])
|
||||||
|
pipeline_args.update(vae=vae)
|
||||||
|
|
||||||
if self.precision == 'float16':
|
if self.precision == 'float16':
|
||||||
print(' | Using faster float16 precision')
|
print(' | Using faster float16 precision')
|
||||||
|
|
||||||
if not isinstance(name_or_path, Path):
|
if not isinstance(name_or_path, Path):
|
||||||
|
# hub has no explicit API for different data types, but the main Stable Diffusion
|
||||||
|
# releases set a precedent for putting float16 weights in a fp16 branch.
|
||||||
try:
|
try:
|
||||||
hf_hub_download(name_or_path, "model_index.json", revision="fp16")
|
hf_hub_download(name_or_path, "model_index.json", revision="fp16")
|
||||||
except RevisionNotFoundError as e:
|
except RevisionNotFoundError:
|
||||||
pass
|
pass # no such branch, assume we should use the default.
|
||||||
else:
|
else:
|
||||||
pipeline_args.update(revision="fp16")
|
pipeline_args.update(revision="fp16")
|
||||||
|
|
||||||
@ -362,7 +369,6 @@ class ModelCache(object):
|
|||||||
# want to leave it as a separate processing node. It ends up using the same diffusers
|
# want to leave it as a separate processing node. It ends up using the same diffusers
|
||||||
# code either way, so we can table it for now.
|
# code either way, so we can table it for now.
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
# TODO: alternate VAE
|
|
||||||
# TODO: local_files_only=True
|
# TODO: local_files_only=True
|
||||||
**pipeline_args
|
**pipeline_args
|
||||||
)
|
)
|
||||||
@ -535,3 +541,40 @@ class ModelCache(object):
|
|||||||
with open(hashpath,'w') as f:
|
with open(hashpath,'w') as f:
|
||||||
f.write(hash)
|
f.write(hash)
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
|
def _load_vae(self, vae_config):
|
||||||
|
vae_args = {}
|
||||||
|
|
||||||
|
if 'repo_name' in vae_config:
|
||||||
|
name_or_path = vae_config['repo_name']
|
||||||
|
elif 'path' in vae_config:
|
||||||
|
name_or_path = Path(vae_config['path'])
|
||||||
|
if not name_or_path.is_absolute():
|
||||||
|
name_or_path = Path(Globals.root, name_or_path).resolve()
|
||||||
|
else:
|
||||||
|
raise ValueError("VAE config must specify either repo_name or path.")
|
||||||
|
|
||||||
|
print(f'>> Loading diffusers VAE from {name_or_path}')
|
||||||
|
if self.precision == 'float16':
|
||||||
|
print(' | Using faster float16 precision')
|
||||||
|
|
||||||
|
if not isinstance(name_or_path, Path):
|
||||||
|
try:
|
||||||
|
hf_hub_download(name_or_path, "model_index.json", revision="fp16")
|
||||||
|
except RevisionNotFoundError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
vae_args.update(revision="fp16")
|
||||||
|
|
||||||
|
vae_args.update(torch_dtype=torch.float16)
|
||||||
|
else:
|
||||||
|
# TODO: more accurately, "using the model's default precision."
|
||||||
|
# How do we find out what that is?
|
||||||
|
print(' | Using more accurate float32 precision')
|
||||||
|
|
||||||
|
if 'subfolder' in vae_config:
|
||||||
|
vae_args['subfolder'] = vae_config['subfolder']
|
||||||
|
|
||||||
|
# At some point we might need to be able to use different classes here? But for now I think
|
||||||
|
# all Stable Diffusion VAE are AutoencoderKL.
|
||||||
|
return AutoencoderKL.from_pretrained(name_or_path, **vae_args)
|
||||||
|
Loading…
Reference in New Issue
Block a user