Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Merge branch 'main' into bugfix/use-cu117-wheel

This commit is contained in commit 61c3886843.
@ -47,11 +47,11 @@ if [ "$0" != "bash" ]; then
         ;;
     3)
         echo "Starting Textual Inversion:"
-        exec textual_inversion --gui $@
+        exec invokeai-ti --gui $@
         ;;
     4)
         echo "Merging Models:"
-        exec merge_models --gui $@
+        exec invokeai-merge --gui $@
         ;;
     5)
         echo "Developer Console:"
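The launcher now calls the renamed console commands invokeai-ti and invokeai-merge. For orientation only, commands like these are normally declared as setuptools console-script entry points; the sketch below is hypothetical and the module paths are placeholders, not InvokeAI's actual packaging configuration.

# setup.py (illustrative only; package and module paths are placeholders)
from setuptools import setup

setup(
    name="example-package",
    entry_points={
        "console_scripts": [
            # "command-name = module.path:function"
            "invokeai-ti = example_package.textual_inversion:main",
            "invokeai-merge = example_package.merge_diffusers:main",
        ],
    },
)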
@ -20,6 +20,7 @@
 import os
 import re
 import torch
+import warnings
 from pathlib import Path
 from ldm.invoke.globals import Globals, global_cache_dir
 from safetensors.torch import load_file
@ -44,6 +45,7 @@ from diffusers import (
     PNDMScheduler,
     StableDiffusionPipeline,
     UNet2DConditionModel,
+    logging as dlogging,
 )
 from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
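The two new imports feed the noise-suppression pattern this commit wraps around conversion and merging: silence Python warnings and diffusers' own logger, then restore the previous verbosity. A minimal standalone sketch of that pattern (the run_quietly helper name is illustrative; the diff inlines the same calls directly):

import warnings

from diffusers import logging as dlogging


def run_quietly(fn, *args, **kwargs):
    # Temporarily silence Python warnings and the diffusers logger,
    # restoring the previous verbosity afterwards, as the diff does.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        verbosity = dlogging.get_verbosity()
        dlogging.set_verbosity_error()
        try:
            return fn(*args, **kwargs)
        finally:
            dlogging.set_verbosity(verbosity)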
@ -795,8 +797,9 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     prediction_type:str=None,
     extract_ema:bool=True,
     upcast_attn:bool=False,
-    vae:AutoencoderKL=None
-)->StableDiffusionGeneratorPipeline:
+    vae:AutoencoderKL=None,
+    return_generator_pipeline:bool=False,
+)->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
     '''
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
     config file.
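The function body in the next hunk first loads the raw state dict from disk before converting it. A self-contained sketch of that loading branch (the read_state_dict helper name is illustrative; the diff performs the same checks inline):

from pathlib import Path

import torch
from safetensors.torch import load_file


def read_state_dict(checkpoint_path: str) -> dict:
    # .safetensors files go through safetensors' loader, everything else
    # through torch.load, mirroring the conversion code below.
    if Path(checkpoint_path).suffix == '.safetensors':
        checkpoint = load_file(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path)
    # Some checkpoints nest the weights under a 'state_dict' key and some do not.
    return checkpoint.get('state_dict', checkpoint)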
@ -823,166 +826,173 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
(The whole function body below is re-indented one level under the new `with warnings.catch_warnings():` block; unchanged lines are shown at their new indentation and only substantive changes carry -/+ markers.)

     :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
         running stable diffusion 2.1.
     '''

+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        verbosity = dlogging.get_verbosity()
+        dlogging.set_verbosity_error()

         checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
         cache_dir = global_cache_dir('hub')
+        pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline

         # Sometimes models don't have the global_step item
         if "global_step" in checkpoint:
             global_step = checkpoint["global_step"]
         else:
             print(" | global_step key not found in model")
             global_step = None

         # sometimes there is a state_dict key and sometimes not
         if 'state_dict' in checkpoint:
             checkpoint = checkpoint["state_dict"]

         upcast_attention = False
         if original_config_file is None:
             key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"

             if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
                 original_config_file = os.path.join(Globals.root,'configs','stable-diffusion','v2-inference-v.yaml')

                 if global_step == 110000:
                     # v2.1 needs to upcast attention
                     upcast_attention = True
             else:
                 original_config_file = os.path.join(Globals.root,'configs','stable-diffusion','v1-inference.yaml')

         original_config = OmegaConf.load(original_config_file)

         if num_in_channels is not None:
             original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

         if (
             "parameterization" in original_config["model"]["params"]
             and original_config["model"]["params"]["parameterization"] == "v"
         ):
             if prediction_type is None:
                 # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
                 # as it relies on a brittle global step parameter here
                 prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
             if image_size is None:
                 # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
                 # as it relies on a brittle global step parameter here
                 image_size = 512 if global_step == 875000 else 768
         else:
             if prediction_type is None:
                 prediction_type = "epsilon"
             if image_size is None:
                 image_size = 512

         num_train_timesteps = original_config.model.params.timesteps
         beta_start = original_config.model.params.linear_start
         beta_end = original_config.model.params.linear_end

         scheduler = DDIMScheduler(
             beta_end=beta_end,
             beta_schedule="scaled_linear",
             beta_start=beta_start,
             num_train_timesteps=num_train_timesteps,
             steps_offset=1,
             clip_sample=False,
             set_alpha_to_one=False,
             prediction_type=prediction_type,
         )
         # make sure scheduler works correctly with DDIM
         scheduler.register_to_config(clip_sample=False)

         if scheduler_type == "pndm":
             config = dict(scheduler.config)
             config["skip_prk_steps"] = True
             scheduler = PNDMScheduler.from_config(config)
         elif scheduler_type == "lms":
             scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
         elif scheduler_type == "heun":
             scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
         elif scheduler_type == "euler":
             scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
         elif scheduler_type == "euler-ancestral":
             scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
         elif scheduler_type == "dpm":
             scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
         elif scheduler_type == "ddim":
             scheduler = scheduler
         else:
             raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

         # Convert the UNet2DConditionModel model.
         unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
         unet_config["upcast_attention"] = upcast_attention
         unet = UNet2DConditionModel(**unet_config)

         converted_unet_checkpoint = convert_ldm_unet_checkpoint(
             checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
         )

         unet.load_state_dict(converted_unet_checkpoint)

         # Convert the VAE model, or use the one passed
         if not vae:
-            print(f' | Using checkpoint model\'s original VAE')
+            print(' | Using checkpoint model\'s original VAE')
             vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
             converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

             vae = AutoencoderKL(**vae_config)
             vae.load_state_dict(converted_vae_checkpoint)
         else:
-            print(f' | Using external VAE specified in config')
+            print(' | Using external VAE specified in config')

         # Convert the text model.
         model_type = pipeline_type
         if model_type is None:
             model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]

         if model_type == "FrozenOpenCLIPEmbedder":
             text_model = convert_open_clip_checkpoint(checkpoint)
             tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2",
                                                       subfolder="tokenizer",
                                                       cache_dir=global_cache_dir('diffusers')
                                                       )
-            pipe = StableDiffusionGeneratorPipeline(
+            pipe = pipeline_class(
                 vae=vae,
                 text_encoder=text_model,
                 tokenizer=tokenizer,
                 unet=unet,
                 scheduler=scheduler,
                 safety_checker=None,
                 feature_extractor=None,
                 requires_safety_checker=False,
             )
         elif model_type == "PaintByExample":
             vision_model = convert_paint_by_example_checkpoint(checkpoint)
             tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
             feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
             pipe = PaintByExamplePipeline(
                 vae=vae,
                 image_encoder=vision_model,
                 unet=unet,
                 scheduler=scheduler,
                 safety_checker=None,
                 feature_extractor=feature_extractor,
             )
         elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
             text_model = convert_ldm_clip_checkpoint(checkpoint)
             tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
             feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
-            pipe = StableDiffusionGeneratorPipeline(
+            pipe = pipeline_class(
                 vae=vae,
                 text_encoder=text_model,
                 tokenizer=tokenizer,
                 unet=unet,
                 scheduler=scheduler,
                 safety_checker=None,
                 feature_extractor=feature_extractor,
             )
         else:
             text_config = create_ldm_bert_config(original_config)
             text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
             tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
             pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

+        dlogging.set_verbosity(verbosity)

     return pipe
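The scheduler selection in the hunk above can be read as a name-to-class dispatch. A minimal standalone sketch of that logic; the beta values are illustrative SD v1 defaults standing in for original_config.model.params.{linear_start, linear_end, timesteps}, and the make_scheduler helper name is not part of the source:

from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)

# Base DDIM scheduler built the same way as in the conversion code.
ddim = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
    prediction_type="epsilon",
)

SCHEDULER_CLASSES = {
    "lms": LMSDiscreteScheduler,
    "heun": HeunDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "euler-ancestral": EulerAncestralDiscreteScheduler,
    "dpm": DPMSolverMultistepScheduler,
}


def make_scheduler(scheduler_type: str):
    # Mirrors the if/elif chain above: "ddim" keeps the base scheduler,
    # "pndm" needs skip_prk_steps, everything else is built from the DDIM config.
    if scheduler_type == "ddim":
        return ddim
    if scheduler_type == "pndm":
        config = dict(ddim.config)
        config["skip_prk_steps"] = True
        return PNDMScheduler.from_config(config)
    if scheduler_type not in SCHEDULER_CLASSES:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
    return SCHEDULER_CLASSES[scheduler_type].from_config(ddim.config)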
@ -1000,6 +1010,7 @@ def convert_ckpt_to_diffuser(
         checkpoint_path,
         **kwargs
     )

     pipe.save_pretrained(
         dump_path,
         safe_serialization=is_safetensors_available(),
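For reference, the save step above writes the converted pipeline in safetensors format whenever the package is installed. A hedged sketch, assuming is_safetensors_available still lives in diffusers.utils as it did in the diffusers versions this code targets; the repo id and output path are placeholders:

from diffusers import StableDiffusionPipeline
from diffusers.utils import is_safetensors_available

# Load any pipeline and re-save it, preferring .safetensors weights when available.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.save_pretrained(
    "converted-model",
    safe_serialization=is_safetensors_available(),
)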
@ -8,13 +8,14 @@ import argparse
 import curses
 import os
 import sys
+import traceback
+import warnings
 from argparse import Namespace
 from pathlib import Path
 from typing import List, Union

 import npyscreen
-import warnings
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, logging as dlogging
 from omegaconf import OmegaConf

 from ldm.invoke.globals import (
@ -46,18 +47,24 @@ def merge_diffusion_models(
     **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
          cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
     """
-    pipe = DiffusionPipeline.from_pretrained(
-        model_ids_or_paths[0],
-        cache_dir=kwargs.get("cache_dir", global_cache_dir()),
-        custom_pipeline="checkpoint_merger",
-    )
-    merged_pipe = pipe.merge(
-        pretrained_model_name_or_path_list=model_ids_or_paths,
-        alpha=alpha,
-        interp=interp,
-        force=force,
-        **kwargs,
-    )
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        verbosity = dlogging.get_verbosity()
+        dlogging.set_verbosity_error()
+
+        pipe = DiffusionPipeline.from_pretrained(
+            model_ids_or_paths[0],
+            cache_dir=kwargs.get("cache_dir", global_cache_dir()),
+            custom_pipeline="checkpoint_merger",
+        )
+        merged_pipe = pipe.merge(
+            pretrained_model_name_or_path_list=model_ids_or_paths,
+            alpha=alpha,
+            interp=interp,
+            force=force,
+            **kwargs,
+        )
+        dlogging.set_verbosity(verbosity)
     return merged_pipe
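The merge itself goes through diffusers' checkpoint_merger community pipeline, as the hunk above shows. A minimal standalone sketch of that call pattern; the model ids, alpha/interp values, and output path are illustrative placeholders:

from diffusers import DiffusionPipeline

# Load the first model through the checkpoint_merger community pipeline,
# then blend it with a second model; the keyword arguments mirror the diff.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="checkpoint_merger",
)
merged = pipe.merge(
    pretrained_model_name_or_path_list=[
        "runwayml/stable-diffusion-v1-5",
        "CompVis/stable-diffusion-v1-4",
    ],
    alpha=0.5,         # interpolation ratio between the models
    interp="sigmoid",  # one of the interpolation modes the merger accepts
    force=False,
)
merged.save_pretrained("merged-model")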
@ -443,22 +450,5 @@ def main():
     ] = cache_dir # because not clear the merge pipeline is honoring cache_dir
     args.cache_dir = cache_dir

-    with warnings.catch_warnings():
-        warnings.simplefilter('ignore')
-        try:
-            if args.front_end:
-                run_gui(args)
-            else:
-                run_cli(args)
-            print(f'>> Conversion successful.')
-        except Exception as e:
-            if str(e).startswith('Not enough space'):
-                print('** Not enough horizontal space! Try making the window wider, or relaunch with a smaller starting size.')
-            else:
-                print(f"** An error occurred while merging the pipelines: {str(e)}")
-            sys.exit(-1)
-        except KeyboardInterrupt:
-            sys.exit(-1)

 if __name__ == "__main__":
     main()
@ -356,6 +356,7 @@ class ModelManager(object):
             checkpoint_path = weights,
             original_config_file = config,
             vae = vae,
+            return_generator_pipeline=True,
         )
         return (
             pipeline.to(self.device).to(torch.float16 if self.precision == 'float16' else torch.float32),
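The returned pipeline is moved to the target device and precision in one step. A small sketch of that idiom; the device selection, precision string, and repo id are illustrative, not taken from the ModelManager code:

import torch
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
precision = "float16" if device == "cuda" else "float32"

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Same pattern as above: derive the dtype from the configured precision string,
# then move the whole pipeline to the device and dtype.
pipeline = pipeline.to(device).to(torch.float16 if precision == "float16" else torch.float32)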