convert no longer creates StableDiffusionGenerator pipelines unless asked to

Lincoln Stein 2023-02-03 10:04:32 -05:00
parent ca0f3ec0e4
commit 9e46badc40
2 changed files with 160 additions and 148 deletions
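With this change, load_pipeline_from_original_stable_diffusion_ckpt returns a stock diffusers StableDiffusionPipeline by default; callers that want InvokeAI's StableDiffusionGeneratorPipeline must now ask for it with the new return_generator_pipeline flag. A minimal sketch of the two call patterns, assuming the function is importable from ldm.invoke.ckpt_to_diffuser (module path and file paths below are placeholders, not part of this commit):

# Sketch only: module path and checkpoint/config paths are assumptions.
from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt

# Default behaviour after this commit: a plain diffusers StableDiffusionPipeline.
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
    checkpoint_path='/path/to/model.ckpt',
    original_config_file='/path/to/v1-inference.yaml',
)

# Opt in when InvokeAI-specific generation features are needed.
gen_pipe = load_pipeline_from_original_stable_diffusion_ckpt(
    checkpoint_path='/path/to/model.ckpt',
    original_config_file='/path/to/v1-inference.yaml',
    return_generator_pipeline=True,
)
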

View File

@@ -20,6 +20,7 @@
 import os
 import re
 import torch
+import warnings
 from pathlib import Path
 from ldm.invoke.globals import Globals, global_cache_dir
 from safetensors.torch import load_file
@@ -48,6 +49,7 @@ from diffusers import (
 from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
 from diffusers.utils import is_safetensors_available
+from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
 from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
 from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
@@ -795,8 +797,9 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     prediction_type:str=None,
     extract_ema:bool=True,
     upcast_attn:bool=False,
-    vae:AutoencoderKL=None
-)->StableDiffusionGeneratorPipeline:
+    vae:AutoencoderKL=None,
+    return_generator_pipeline:bool=False,
+)->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
     '''
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
     config file.
@@ -824,165 +827,172 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     running stable diffusion 2.1.
     '''
-    checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
-    cache_dir = global_cache_dir('hub')
-    # Sometimes models don't have the global_step item
-    if "global_step" in checkpoint:
-        global_step = checkpoint["global_step"]
-    else:
-        print(" | global_step key not found in model")
-        global_step = None
-    # sometimes there is a state_dict key and sometimes not
-    if 'state_dict' in checkpoint:
-        checkpoint = checkpoint["state_dict"]
-    upcast_attention = False
-    if original_config_file is None:
-        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
-        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
-            original_config_file = os.path.join(Globals.root,'configs','stable-diffusion','v2-inference-v.yaml')
-            if global_step == 110000:
-                # v2.1 needs to upcast attention
-                upcast_attention = True
-        else:
-            original_config_file = os.path.join(Globals.root,'configs','stable-diffusion','v1-inference.yaml')
-    original_config = OmegaConf.load(original_config_file)
-    if num_in_channels is not None:
-        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
-    if (
-        "parameterization" in original_config["model"]["params"]
-        and original_config["model"]["params"]["parameterization"] == "v"
-    ):
-        if prediction_type is None:
-            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
-            # as it relies on a brittle global step parameter here
-            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
-        if image_size is None:
-            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
-            # as it relies on a brittle global step parameter here
-            image_size = 512 if global_step == 875000 else 768
-    else:
-        if prediction_type is None:
-            prediction_type = "epsilon"
-        if image_size is None:
-            image_size = 512
-    num_train_timesteps = original_config.model.params.timesteps
-    beta_start = original_config.model.params.linear_start
-    beta_end = original_config.model.params.linear_end
-    scheduler = DDIMScheduler(
-        beta_end=beta_end,
-        beta_schedule="scaled_linear",
-        beta_start=beta_start,
-        num_train_timesteps=num_train_timesteps,
-        steps_offset=1,
-        clip_sample=False,
-        set_alpha_to_one=False,
-        prediction_type=prediction_type,
-    )
-    # make sure scheduler works correctly with DDIM
-    scheduler.register_to_config(clip_sample=False)
-    if scheduler_type == "pndm":
-        config = dict(scheduler.config)
-        config["skip_prk_steps"] = True
-        scheduler = PNDMScheduler.from_config(config)
-    elif scheduler_type == "lms":
-        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
-    elif scheduler_type == "heun":
-        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
-    elif scheduler_type == "euler":
-        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
-    elif scheduler_type == "euler-ancestral":
-        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
-    elif scheduler_type == "dpm":
-        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
-    elif scheduler_type == "ddim":
-        scheduler = scheduler
-    else:
-        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
-    # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
-    unet_config["upcast_attention"] = upcast_attention
-    unet = UNet2DConditionModel(**unet_config)
-    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
-    )
-    unet.load_state_dict(converted_unet_checkpoint)
-    # Convert the VAE model, or use the one passed
-    if not vae:
-        print(f' | Using checkpoint model\'s original VAE')
-        vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
-        converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(converted_vae_checkpoint)
-    else:
-        print(f' | Using external VAE specified in config')
-    # Convert the text model.
-    model_type = pipeline_type
-    if model_type is None:
-        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if model_type == "FrozenOpenCLIPEmbedder":
-        text_model = convert_open_clip_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2",
-                                                  subfolder="tokenizer",
-                                                  cache_dir=global_cache_dir('diffusers')
-                                                  )
-        pipe = StableDiffusionGeneratorPipeline(
-            vae=vae,
-            text_encoder=text_model,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=None,
-            feature_extractor=None,
-            requires_safety_checker=False,
-        )
-    elif model_type == "PaintByExample":
-        vision_model = convert_paint_by_example_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
-        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
-        pipe = PaintByExamplePipeline(
-            vae=vae,
-            image_encoder=vision_model,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=None,
-            feature_extractor=feature_extractor,
-        )
-    elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
-        text_model = convert_ldm_clip_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
-        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
-        pipe = StableDiffusionGeneratorPipeline(
-            vae=vae,
-            text_encoder=text_model,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=None,
-            feature_extractor=feature_extractor,
-        )
-    else:
-        text_config = create_ldm_bert_config(original_config)
-        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
-        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
-        pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        verbosity = dlogging.get_verbosity()
+        dlogging.set_verbosity_error()
+        checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
+        cache_dir = global_cache_dir('hub')
+        pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline
+        # Sometimes models don't have the global_step item
+        if "global_step" in checkpoint:
+            global_step = checkpoint["global_step"]
+        else:
+            print(" | global_step key not found in model")
+            global_step = None
+        # sometimes there is a state_dict key and sometimes not
+        if 'state_dict' in checkpoint:
+            checkpoint = checkpoint["state_dict"]
+        upcast_attention = False
+        if original_config_file is None:
+            key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+            if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+                original_config_file = os.path.join(Globals.root,'configs','stable-diffusion','v2-inference-v.yaml')
+                if global_step == 110000:
+                    # v2.1 needs to upcast attention
+                    upcast_attention = True
+            else:
+                original_config_file = os.path.join(Globals.root,'configs','stable-diffusion','v1-inference.yaml')
+        original_config = OmegaConf.load(original_config_file)
+        if num_in_channels is not None:
+            original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+        if (
+            "parameterization" in original_config["model"]["params"]
+            and original_config["model"]["params"]["parameterization"] == "v"
+        ):
+            if prediction_type is None:
+                # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
+                # as it relies on a brittle global step parameter here
+                prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+            if image_size is None:
+                # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
+                # as it relies on a brittle global step parameter here
+                image_size = 512 if global_step == 875000 else 768
+        else:
+            if prediction_type is None:
+                prediction_type = "epsilon"
+            if image_size is None:
+                image_size = 512
+        num_train_timesteps = original_config.model.params.timesteps
+        beta_start = original_config.model.params.linear_start
+        beta_end = original_config.model.params.linear_end
+        scheduler = DDIMScheduler(
+            beta_end=beta_end,
+            beta_schedule="scaled_linear",
+            beta_start=beta_start,
+            num_train_timesteps=num_train_timesteps,
+            steps_offset=1,
+            clip_sample=False,
+            set_alpha_to_one=False,
+            prediction_type=prediction_type,
+        )
+        # make sure scheduler works correctly with DDIM
+        scheduler.register_to_config(clip_sample=False)
+        if scheduler_type == "pndm":
+            config = dict(scheduler.config)
+            config["skip_prk_steps"] = True
+            scheduler = PNDMScheduler.from_config(config)
+        elif scheduler_type == "lms":
+            scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+        elif scheduler_type == "heun":
+            scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
+        elif scheduler_type == "euler":
+            scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
+        elif scheduler_type == "euler-ancestral":
+            scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
+        elif scheduler_type == "dpm":
+            scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        elif scheduler_type == "ddim":
+            scheduler = scheduler
+        else:
+            raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
+        # Convert the UNet2DConditionModel model.
+        unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+        unet_config["upcast_attention"] = upcast_attention
+        unet = UNet2DConditionModel(**unet_config)
+        converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+            checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
+        )
+        unet.load_state_dict(converted_unet_checkpoint)
+        # Convert the VAE model, or use the one passed
+        if not vae:
+            print(' | Using checkpoint model\'s original VAE')
+            vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+            converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+            vae = AutoencoderKL(**vae_config)
+            vae.load_state_dict(converted_vae_checkpoint)
+        else:
+            print(' | Using external VAE specified in config')
+        # Convert the text model.
+        model_type = pipeline_type
+        if model_type is None:
+            model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+        if model_type == "FrozenOpenCLIPEmbedder":
+            text_model = convert_open_clip_checkpoint(checkpoint)
+            tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2",
+                                                      subfolder="tokenizer",
+                                                      cache_dir=global_cache_dir('diffusers')
+                                                      )
+            pipe = pipeline_class(
+                vae=vae,
+                text_encoder=text_model,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=None,
+                feature_extractor=None,
+                requires_safety_checker=False,
+            )
+        elif model_type == "PaintByExample":
+            vision_model = convert_paint_by_example_checkpoint(checkpoint)
+            tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
+            feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
+            pipe = PaintByExamplePipeline(
+                vae=vae,
+                image_encoder=vision_model,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=None,
+                feature_extractor=feature_extractor,
+            )
+        elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
+            text_model = convert_ldm_clip_checkpoint(checkpoint)
+            tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
+            feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
+            pipe = pipeline_class(
+                vae=vae,
+                text_encoder=text_model,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=None,
+                feature_extractor=feature_extractor,
+            )
+        else:
+            text_config = create_ldm_bert_config(original_config)
+            text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
+            tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
+            pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+        dlogging.set_verbosity(verbosity)
     return pipe
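
The conversion body is now wrapped in a warning and verbosity guard so checkpoint loading does not flood the console, and the previous diffusers verbosity is restored afterwards. Below is a standalone sketch of that pattern using diffusers' logging helpers; the wrapper name quiet_conversion and the try/finally restore are illustrative additions, not code from this commit.

import warnings
from diffusers.utils import logging as dlogging  # diffusers' logging helper module

def quiet_conversion(work):
    # Silence Python warnings and diffusers log output while `work` runs,
    # then restore the prior verbosity so later calls log normally again.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        verbosity = dlogging.get_verbosity()
        dlogging.set_verbosity_error()
        try:
            return work()
        finally:
            dlogging.set_verbosity(verbosity)
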
@@ -1000,6 +1010,7 @@ def convert_ckpt_to_diffuser(
         checkpoint_path,
         **kwargs
     )
     pipe.save_pretrained(
         dump_path,
         safe_serialization=is_safetensors_available(),
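
convert_ckpt_to_diffuser loads the checkpoint through the function above and then writes it out in diffusers folder format, preferring safetensors serialization when the library is available. A minimal sketch of just the save step; the helper name save_converted is made up for illustration.

from diffusers import StableDiffusionPipeline
from diffusers.utils import is_safetensors_available

def save_converted(pipe: StableDiffusionPipeline, dump_path: str) -> None:
    # Write the converted pipeline to disk, using safetensors when supported.
    pipe.save_pretrained(
        dump_path,
        safe_serialization=is_safetensors_available(),
    )
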

View File

@@ -356,6 +356,7 @@ class ModelManager(object):
             checkpoint_path = weights,
             original_config_file = config,
             vae = vae,
+            return_generator_pipeline=True,
         )
         return (
             pipeline.to(self.device).to(torch.float16 if self.precision == 'float16' else torch.float32),
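
Because the converter can now return either pipeline class, code that relies on InvokeAI-specific generation methods should pass return_generator_pipeline=True (as the model manager does above) or verify what it received. A small defensive sketch; the helper name ensure_generator_pipeline is hypothetical.

from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline

def ensure_generator_pipeline(pipe):
    # Guard for call sites that expect the InvokeAI pipeline subclass.
    if isinstance(pipe, StableDiffusionGeneratorPipeline):
        return pipe
    raise TypeError(
        f"expected StableDiffusionGeneratorPipeline, got {type(pipe).__name__}; "
        "call the converter with return_generator_pipeline=True"
    )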