mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
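Checkpoint conversion no longer creates StableDiffusionGeneratorPipeline objects unless explicitly asked to (pass return_generator_pipeline=True); by default it now builds a plain StableDiffusionPipeline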
This commit is contained in:
parent
ca0f3ec0e4
commit
9e46badc40
@ -20,6 +20,7 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from ldm.invoke.globals import Globals, global_cache_dir
|
from ldm.invoke.globals import Globals, global_cache_dir
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
@ -48,6 +49,7 @@ from diffusers import (
|
|||||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
|
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
|
||||||
from diffusers.utils import is_safetensors_available
|
from diffusers.utils import is_safetensors_available
|
||||||
|
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
||||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
|
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
|
||||||
|
|
||||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
@ -795,8 +797,9 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
prediction_type:str=None,
|
prediction_type:str=None,
|
||||||
extract_ema:bool=True,
|
extract_ema:bool=True,
|
||||||
upcast_attn:bool=False,
|
upcast_attn:bool=False,
|
||||||
vae:AutoencoderKL=None
|
vae:AutoencoderKL=None,
|
||||||
)->StableDiffusionGeneratorPipeline:
|
return_generator_pipeline:bool=False,
|
||||||
|
)->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
|
||||||
'''
|
'''
|
||||||
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
||||||
config file.
|
config file.
|
||||||
@ -824,8 +827,14 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
running stable diffusion 2.1.
|
running stable diffusion 2.1.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter('ignore')
|
||||||
|
verbosity = dlogging.get_verbosity()
|
||||||
|
dlogging.set_verbosity_error()
|
||||||
|
|
||||||
checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
|
checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
|
||||||
cache_dir = global_cache_dir('hub')
|
cache_dir = global_cache_dir('hub')
|
||||||
|
pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline
|
||||||
|
|
||||||
# Sometimes models don't have the global_step item
|
# Sometimes models don't have the global_step item
|
||||||
if "global_step" in checkpoint:
|
if "global_step" in checkpoint:
|
||||||
@ -923,14 +932,14 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
|
|
||||||
# Convert the VAE model, or use the one passed
|
# Convert the VAE model, or use the one passed
|
||||||
if not vae:
|
if not vae:
|
||||||
print(f' | Using checkpoint model\'s original VAE')
|
print(' | Using checkpoint model\'s original VAE')
|
||||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
else:
|
else:
|
||||||
print(f' | Using external VAE specified in config')
|
print(' | Using external VAE specified in config')
|
||||||
|
|
||||||
# Convert the text model.
|
# Convert the text model.
|
||||||
model_type = pipeline_type
|
model_type = pipeline_type
|
||||||
@ -943,7 +952,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
subfolder="tokenizer",
|
subfolder="tokenizer",
|
||||||
cache_dir=global_cache_dir('diffusers')
|
cache_dir=global_cache_dir('diffusers')
|
||||||
)
|
)
|
||||||
pipe = StableDiffusionGeneratorPipeline(
|
pipe = pipeline_class(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
text_encoder=text_model,
|
text_encoder=text_model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -969,7 +978,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
|
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
|
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
|
||||||
pipe = StableDiffusionGeneratorPipeline(
|
pipe = pipeline_class(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
text_encoder=text_model,
|
text_encoder=text_model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -983,6 +992,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
|
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
|
||||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||||
|
dlogging.set_verbosity(verbosity)
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
@ -1000,6 +1010,7 @@ def convert_ckpt_to_diffuser(
|
|||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
pipe.save_pretrained(
|
pipe.save_pretrained(
|
||||||
dump_path,
|
dump_path,
|
||||||
safe_serialization=is_safetensors_available(),
|
safe_serialization=is_safetensors_available(),
|
||||||
|
@ -356,6 +356,7 @@ class ModelManager(object):
|
|||||||
checkpoint_path = weights,
|
checkpoint_path = weights,
|
||||||
original_config_file = config,
|
original_config_file = config,
|
||||||
vae = vae,
|
vae = vae,
|
||||||
|
return_generator_pipeline=True,
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
pipeline.to(self.device).to(torch.float16 if self.precision == 'float16' else torch.float32),
|
pipeline.to(self.device).to(torch.float16 if self.precision == 'float16' else torch.float32),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user