convert no longer creates StableDiffusionGenerator pipelines unless asked to

This commit is contained in:
Lincoln Stein 2023-02-03 10:04:32 -05:00
parent ca0f3ec0e4
commit 9e46badc40
2 changed files with 160 additions and 148 deletions

View File

@@ -20,6 +20,7 @@
 import os
 import re
 import torch
+import warnings
 from pathlib import Path
 from ldm.invoke.globals import Globals, global_cache_dir
 from safetensors.torch import load_file
@@ -48,6 +49,7 @@ from diffusers import (
 from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
 from diffusers.utils import is_safetensors_available
+from diffusers.utils import logging as dlogging
 from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
 from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
@@ -795,8 +797,9 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     prediction_type:str=None,
     extract_ema:bool=True,
     upcast_attn:bool=False,
-    vae:AutoencoderKL=None
-)->StableDiffusionGeneratorPipeline:
+    vae:AutoencoderKL=None,
+    return_generator_pipeline:bool=False,
+)->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
     '''
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
     config file.
@@ -824,8 +827,14 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     running stable diffusion 2.1.
     '''
 
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        verbosity = dlogging.get_verbosity()
+        dlogging.set_verbosity_error()
+
     checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
     cache_dir = global_cache_dir('hub')
+    pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline
 
     # Sometimes models don't have the global_step item
     if "global_step" in checkpoint:
@@ -923,14 +932,14 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     # Convert the VAE model, or use the one passed
     if not vae:
-        print(f' | Using checkpoint model\'s original VAE')
+        print(' | Using checkpoint model\'s original VAE')
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
     else:
-        print(f' | Using external VAE specified in config')
+        print(' | Using external VAE specified in config')
 
     # Convert the text model.
     model_type = pipeline_type
@@ -943,7 +952,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
             subfolder="tokenizer",
             cache_dir=global_cache_dir('diffusers')
         )
-        pipe = StableDiffusionGeneratorPipeline(
+        pipe = pipeline_class(
            vae=vae,
            text_encoder=text_model,
            tokenizer=tokenizer,
@@ -969,7 +978,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
         feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
-        pipe = StableDiffusionGeneratorPipeline(
+        pipe = pipeline_class(
            vae=vae,
            text_encoder=text_model,
            tokenizer=tokenizer,
@@ -983,6 +992,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
         tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
         pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+    dlogging.set_verbosity(verbosity)
 
     return pipe
@@ -1000,6 +1010,7 @@ def convert_ckpt_to_diffuser(
         checkpoint_path,
         **kwargs
     )
+
     pipe.save_pretrained(
         dump_path,
         safe_serialization=is_safetensors_available(),
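
Alongside the pipeline-class change, the loader now mutes Python warnings and diffusers log output while the checkpoint is loaded and converted, then restores the previous verbosity. A minimal standalone sketch of that pattern (the actual conversion code is elided):

    import warnings

    from diffusers.utils import logging as dlogging

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')        # hide Python warnings raised during the load
        verbosity = dlogging.get_verbosity()   # remember the current diffusers log level
        dlogging.set_verbosity_error()         # report only errors while converting
        # ... load the .ckpt/.safetensors file and build the diffusers pipeline here ...
    dlogging.set_verbosity(verbosity)          # restore the caller's log level afterwards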

View File

@@ -356,6 +356,7 @@ class ModelManager(object):
             checkpoint_path = weights,
             original_config_file = config,
             vae = vae,
+            return_generator_pipeline=True,
         )
         return (
             pipeline.to(self.device).to(torch.float16 if self.precision == 'float16' else torch.float32),
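
For illustration, a hedged sketch of how a caller picks between the two return types after this commit; the import path and file paths are assumptions, and only the parameters visible in this diff are used:

    # Sketch only: paths are placeholders, other arguments keep their defaults.
    from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt  # import path assumed

    # Default behaviour: a plain diffusers StableDiffusionPipeline, which is all
    # that checkpoint conversion/export needs.
    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
        checkpoint_path='models/ldm/stable-diffusion-v1/model.ckpt',          # hypothetical path
        original_config_file='configs/stable-diffusion/v1-inference.yaml',    # hypothetical path
    )

    # Opt in explicitly when the InvokeAI generator pipeline is required,
    # as ModelManager now does when loading weights for image generation.
    generator_pipe = load_pipeline_from_original_stable_diffusion_ckpt(
        checkpoint_path='models/ldm/stable-diffusion-v1/model.ckpt',
        original_config_file='configs/stable-diffusion/v1-inference.yaml',
        return_generator_pipeline=True,
    )

The flag defaults to False, so convert_ckpt_to_diffuser keeps producing vanilla diffusers pipelines for save_pretrained, while the model manager opts in to the StableDiffusionGeneratorPipeline it needs at runtime.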