mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into bugfix/use-cu117-wheel
This commit is contained in:
commit
61c3886843
@ -47,11 +47,11 @@ if [ "$0" != "bash" ]; then
|
||||
;;
|
||||
3)
|
||||
echo "Starting Textual Inversion:"
|
||||
exec textual_inversion --gui $@
|
||||
exec invokeai-ti --gui $@
|
||||
;;
|
||||
4)
|
||||
echo "Merging Models:"
|
||||
exec merge_models --gui $@
|
||||
exec invokeai-merge --gui $@
|
||||
;;
|
||||
5)
|
||||
echo "Developer Console:"
|
||||
|
@ -20,6 +20,7 @@
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from ldm.invoke.globals import Globals, global_cache_dir
|
||||
from safetensors.torch import load_file
|
||||
@ -44,6 +45,7 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
logging as dlogging,
|
||||
)
|
||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
|
||||
@ -795,8 +797,9 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
prediction_type:str=None,
|
||||
extract_ema:bool=True,
|
||||
upcast_attn:bool=False,
|
||||
vae:AutoencoderKL=None
|
||||
)->StableDiffusionGeneratorPipeline:
|
||||
vae:AutoencoderKL=None,
|
||||
return_generator_pipeline:bool=False,
|
||||
)->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
|
||||
'''
|
||||
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
||||
config file.
|
||||
@ -824,8 +827,14 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
running stable diffusion 2.1.
|
||||
'''
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
|
||||
cache_dir = global_cache_dir('hub')
|
||||
pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline
|
||||
|
||||
# Sometimes models don't have the global_step item
|
||||
if "global_step" in checkpoint:
|
||||
@ -923,14 +932,14 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
|
||||
# Convert the VAE model, or use the one passed
|
||||
if not vae:
|
||||
print(f' | Using checkpoint model\'s original VAE')
|
||||
print(' | Using checkpoint model\'s original VAE')
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
else:
|
||||
print(f' | Using external VAE specified in config')
|
||||
print(' | Using external VAE specified in config')
|
||||
|
||||
# Convert the text model.
|
||||
model_type = pipeline_type
|
||||
@ -943,7 +952,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
subfolder="tokenizer",
|
||||
cache_dir=global_cache_dir('diffusers')
|
||||
)
|
||||
pipe = StableDiffusionGeneratorPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
@ -969,7 +978,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
|
||||
pipe = StableDiffusionGeneratorPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
@ -983,6 +992,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
|
||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
|
||||
return pipe
|
||||
|
||||
@ -1000,6 +1010,7 @@ def convert_ckpt_to_diffuser(
|
||||
checkpoint_path,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
safe_serialization=is_safetensors_available(),
|
||||
|
@ -8,13 +8,14 @@ import argparse
|
||||
import curses
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
import npyscreen
|
||||
import warnings
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import DiffusionPipeline, logging as dlogging
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ldm.invoke.globals import (
|
||||
@ -46,6 +47,11 @@ def merge_diffusion_models(
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_ids_or_paths[0],
|
||||
cache_dir=kwargs.get("cache_dir", global_cache_dir()),
|
||||
@ -58,6 +64,7 @@ def merge_diffusion_models(
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
|
||||
@ -443,22 +450,5 @@ def main():
|
||||
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
|
||||
args.cache_dir = cache_dir
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
try:
|
||||
if args.front_end:
|
||||
run_gui(args)
|
||||
else:
|
||||
run_cli(args)
|
||||
print(f'>> Conversion successful.')
|
||||
except Exception as e:
|
||||
if str(e).startswith('Not enough space'):
|
||||
print('** Not enough horizontal space! Try making the window wider, or relaunch with a smaller starting size.')
|
||||
else:
|
||||
print(f"** An error occurred while merging the pipelines: {str(e)}")
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -356,6 +356,7 @@ class ModelManager(object):
|
||||
checkpoint_path = weights,
|
||||
original_config_file = config,
|
||||
vae = vae,
|
||||
return_generator_pipeline=True,
|
||||
)
|
||||
return (
|
||||
pipeline.to(self.device).to(torch.float16 if self.precision == 'float16' else torch.float32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user