diffusers: support loading an alternate VAE

Kevin Turner, 2022-12-14 09:05:45 -08:00
parent 1f86e527aa, commit 3607042c9d
2 changed files with 59 additions and 6 deletions


@@ -6,14 +6,24 @@
 # and the width and height of the images it
 # was trained on.
 diffusers-1.4:
-    description: Diffusers version of Stable Diffusion version 1.4
+    description: 🤗🧨 Stable Diffusion v1.4
     format: diffusers
     repo_name: CompVis/stable-diffusion-v1-4
-    default: true
 diffusers-1.5:
-    description: Diffusers version of Stable Diffusion version 1.5
+    description: 🤗🧨 Stable Diffusion v1.5
     format: diffusers
     repo_name: runwayml/stable-diffusion-v1-5
+    default: true
+diffusers-1.5+mse:
+    description: 🤗🧨 Stable Diffusion v1.5 + MSE-finetuned VAE
+    format: diffusers
+    repo_name: runwayml/stable-diffusion-v1-5
+    vae:
+        repo_name: stabilityai/sd-vae-ft-mse
+diffusers-inpainting-1.5:
+    description: 🤗🧨 inpainting for Stable Diffusion v1.5
+    format: diffusers
+    repo_name: runwayml/stable-diffusion-inpainting
 stable-diffusion-1.5:
     description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
     weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
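
Besides a hub repo_name, the new vae: key also accepts a local path (resolved against Globals.root when relative) and an optional subfolder, both handled by _load_vae below. A hypothetical stanza, not part of this commit; the model name and path are illustrative:

    diffusers-1.5+local-vae:
        description: 🤗🧨 Stable Diffusion v1.5 + VAE from a local directory
        format: diffusers
        repo_name: runwayml/stable-diffusion-v1-5
        vae:
            path: models/vae/sd-vae-ft-mse   # hypothetical checkout, relative to Globals.root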


@@ -21,6 +21,7 @@ from typing import Union

 import torch
 import transformers
+from diffusers import AutoencoderKL
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import RevisionNotFoundError
 from omegaconf import OmegaConf
@@ -337,14 +338,20 @@ class ModelCache(object):
         # TODO: scan weights maybe?
+        if 'vae' in mconfig:
+            vae = self._load_vae(mconfig['vae'])
+            pipeline_args.update(vae=vae)
+
         if self.precision == 'float16':
             print(' | Using faster float16 precision')
             if not isinstance(name_or_path, Path):
                 # hub has no explicit API for different data types, but the main Stable Diffusion
                 # releases set a precedent for putting float16 weights in a fp16 branch.
                 try:
                     hf_hub_download(name_or_path, "model_index.json", revision="fp16")
-                except RevisionNotFoundError as e:
-                    pass
+                except RevisionNotFoundError:
+                    pass  # no such branch, assume we should use the default.
                 else:
                     pipeline_args.update(revision="fp16")
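
As the comment in the hunk notes, the hub has no per-dtype API, so the code probes instead: it asks hf_hub_download for one small file (model_index.json) on the fp16 revision and treats RevisionNotFoundError as "no such branch". A standalone sketch of the same pattern; has_fp16_branch is a hypothetical helper, not code from this commit:

    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import RevisionNotFoundError

    def has_fp16_branch(repo_id: str) -> bool:
        # Probe the fp16 revision with a small file; success means the
        # branch exists, RevisionNotFoundError means it does not.
        try:
            hf_hub_download(repo_id, "model_index.json", revision="fp16")
        except RevisionNotFoundError:
            return False  # fall back to the default revision
        return True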
@@ -362,7 +369,6 @@ class ModelCache(object):
             # want to leave it as a separate processing node. It ends up using the same diffusers
             # code either way, so we can table it for now.
             safety_checker=None,
-            # TODO: alternate VAE
             # TODO: local_files_only=True
             **pipeline_args
         )
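
Putting the two hunks together: with a vae: entry configured and a successful fp16 probe, the keyword arguments accumulated in pipeline_args make the call above behave roughly like the sketch below. This is illustrative only; StableDiffusionPipeline stands in for whatever pipeline class the cache actually constructs:

    pipeline = StableDiffusionPipeline.from_pretrained(
        'runwayml/stable-diffusion-v1-5',
        safety_checker=None,   # safety checking is left to a separate processing node
        vae=vae,               # from _load_vae(); replaces the repo's bundled VAE
        revision='fp16',       # only when the fp16 branch probe succeeded
    )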
@@ -535,3 +541,40 @@ class ModelCache(object):
         with open(hashpath,'w') as f:
             f.write(hash)
         return hash
+
+    def _load_vae(self, vae_config):
+        vae_args = {}
+
+        if 'repo_name' in vae_config:
+            name_or_path = vae_config['repo_name']
+        elif 'path' in vae_config:
+            name_or_path = Path(vae_config['path'])
+            if not name_or_path.is_absolute():
+                name_or_path = Path(Globals.root, name_or_path).resolve()
+        else:
+            raise ValueError("VAE config must specify either repo_name or path.")
+
+        print(f'>> Loading diffusers VAE from {name_or_path}')
+        if self.precision == 'float16':
+            print(' | Using faster float16 precision')
+
+            if not isinstance(name_or_path, Path):
+                try:
+                    hf_hub_download(name_or_path, "model_index.json", revision="fp16")
+                except RevisionNotFoundError:
+                    pass
+                else:
+                    vae_args.update(revision="fp16")
+
+            vae_args.update(torch_dtype=torch.float16)
+        else:
+            # TODO: more accurately, "using the model's default precision."
+            # How do we find out what that is?
+            print(' | Using more accurate float32 precision')
+
+        if 'subfolder' in vae_config:
+            vae_args['subfolder'] = vae_config['subfolder']
+
+        # At some point we might need to be able to use different classes here? But for now I think
+        # all Stable Diffusion VAE are AutoencoderKL.
+        return AutoencoderKL.from_pretrained(name_or_path, **vae_args)
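
Outside the cache, the repo_name case of _load_vae reduces to a plain diffusers call. A minimal sketch at float16 precision; note that stabilityai/sd-vae-ft-mse publishes no fp16 branch, so the probe falls through and only the dtype argument applies:

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained(
        'stabilityai/sd-vae-ft-mse',   # repo_name from the models config
        torch_dtype=torch.float16,     # vae_args when self.precision == 'float16'
    )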