mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply black
This commit is contained in:
@ -60,9 +60,7 @@ from invokeai.backend.install.model_install_backend import (
|
||||
InstallSelections,
|
||||
ModelInstall,
|
||||
)
|
||||
from invokeai.backend.model_management.model_probe import (
|
||||
ModelType, BaseModelType
|
||||
)
|
||||
from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
@ -77,7 +75,7 @@ Model_dir = "models"
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
PRECISION_CHOICES = ['auto','float16','float32']
|
||||
PRECISION_CHOICES = ["auto", "float16", "float32"]
|
||||
|
||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||
@ -85,7 +83,8 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# or renaming it and then running invokeai-configure again.
|
||||
"""
|
||||
|
||||
logger=InvokeAILogger.getLogger()
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
@ -108,7 +107,9 @@ Add the '--help' argument to see all of the command-line switches available for
|
||||
"""
|
||||
|
||||
else:
|
||||
message = "\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
|
||||
message = (
|
||||
"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
|
||||
)
|
||||
for err in errors:
|
||||
message += f"\t - {err}\n"
|
||||
message += "Please check the logs above and correct any issues."
|
||||
@ -169,9 +170,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
||||
logger.info(f"Installing {label} model file {model_url}...")
|
||||
if not os.path.exists(model_dest):
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
request.urlretrieve(
|
||||
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
|
||||
)
|
||||
request.urlretrieve(model_url, model_dest, ProgressBar(os.path.basename(model_dest)))
|
||||
logger.info("...downloaded successfully")
|
||||
else:
|
||||
logger.info("...exists")
|
||||
@ -182,90 +181,93 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
||||
|
||||
|
||||
def download_conversion_models():
|
||||
target_dir = config.root_path / 'models/core/convert'
|
||||
target_dir = config.root_path / "models/core/convert"
|
||||
kwargs = dict() # for future use
|
||||
try:
|
||||
logger.info('Downloading core tokenizers and text encoders')
|
||||
logger.info("Downloading core tokenizers and text encoders")
|
||||
|
||||
# bert
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
|
||||
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
|
||||
|
||||
bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True)
|
||||
|
||||
# sd-1
|
||||
repo_id = 'openai/clip-vit-large-patch14'
|
||||
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
|
||||
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
|
||||
repo_id = "openai/clip-vit-large-patch14"
|
||||
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14")
|
||||
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14")
|
||||
|
||||
# sd-2
|
||||
repo_id = "stabilityai/stable-diffusion-2"
|
||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
|
||||
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True)
|
||||
|
||||
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
|
||||
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True)
|
||||
|
||||
# sd-xl - tokenizer_2
|
||||
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
_, model_name = repo_id.split('/')
|
||||
_, model_name = repo_id.split("/")
|
||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
||||
|
||||
|
||||
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
||||
|
||||
|
||||
# VAE
|
||||
logger.info('Downloading stable diffusion VAE')
|
||||
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
|
||||
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
|
||||
logger.info("Downloading stable diffusion VAE")
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs)
|
||||
vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True)
|
||||
|
||||
# safety checking
|
||||
logger.info('Downloading safety checker')
|
||||
logger.info("Downloading safety checker")
|
||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
||||
|
||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_realesrgan():
|
||||
logger.info("Installing ESRGAN Upscaling models...")
|
||||
URLs = [
|
||||
dict(
|
||||
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
description = "RealESRGAN_x4plus.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
description="RealESRGAN_x4plus.pth",
|
||||
),
|
||||
dict(
|
||||
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
description = "RealESRGAN_x4plus_anime_6B.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
description="RealESRGAN_x4plus_anime_6B.pth",
|
||||
),
|
||||
dict(
|
||||
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
description = "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
description="ESRGAN_SRx4_DF2KOST_official.pth",
|
||||
),
|
||||
dict(
|
||||
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
description = "RealESRGAN_x2plus.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
description="RealESRGAN_x2plus.pth",
|
||||
),
|
||||
]
|
||||
for model in URLs:
|
||||
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
|
||||
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_support_models():
|
||||
download_realesrgan()
|
||||
download_conversion_models()
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
if root:
|
||||
@ -275,6 +277,7 @@ def get_root(root: str = None) -> str:
|
||||
else:
|
||||
return str(config.root_path)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
# for responsive resizing - disabled
|
||||
@ -283,14 +286,14 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
def create(self):
|
||||
program_opts = self.parentApp.program_opts
|
||||
old_opts = self.parentApp.invokeai_opts
|
||||
first_time = not (config.root_path / 'invokeai.yaml').exists()
|
||||
first_time = not (config.root_path / "invokeai.yaml").exists()
|
||||
access_token = HfFolder.get_token()
|
||||
window_width, window_height = get_terminal_size()
|
||||
label = """Configure startup settings. You can come back and change these later.
|
||||
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
||||
Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
"""
|
||||
for i in textwrap.wrap(label,width=window_width-6):
|
||||
for i in textwrap.wrap(label, width=window_width - 6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
@ -300,7 +303,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
|
||||
self.nextrely += 1
|
||||
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
|
||||
for line in textwrap.wrap(label,width=window_width-6):
|
||||
for line in textwrap.wrap(label, width=window_width - 6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=line,
|
||||
@ -343,7 +346,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
relx=50,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -=1
|
||||
self.nextrely -= 1
|
||||
self.always_use_cpu = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Force CPU to be used on GPU systems",
|
||||
@ -351,10 +354,8 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
relx=80,
|
||||
scroll_exit=True,
|
||||
)
|
||||
precision = old_opts.precision or (
|
||||
"float32" if program_opts.full_precision else "auto"
|
||||
)
|
||||
self.nextrely +=1
|
||||
precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Floating Point Precision",
|
||||
@ -363,10 +364,10 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -=1
|
||||
self.nextrely -= 1
|
||||
self.precision = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
columns = 3,
|
||||
columns=3,
|
||||
name="Precision",
|
||||
values=PRECISION_CHOICES,
|
||||
value=PRECISION_CHOICES.index(precision),
|
||||
@ -398,25 +399,25 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.autoimport_dirs = {}
|
||||
self.autoimport_dirs['autoimport_dir'] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name=f'Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models',
|
||||
value=str(config.root_path / config.autoimport_dir),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
max_height = 3,
|
||||
scroll_exit=True
|
||||
)
|
||||
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
max_height=3,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
|
||||
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT
|
||||
https://huggingface.co/spaces/CompVis/stable-diffusion-license and
|
||||
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
|
||||
"""
|
||||
for i in textwrap.wrap(label,width=window_width-6):
|
||||
for i in textwrap.wrap(label, width=window_width - 6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
@ -431,11 +432,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
label = (
|
||||
"DONE"
|
||||
if program_opts.skip_sd_weights or program_opts.default_only
|
||||
else "NEXT"
|
||||
)
|
||||
label = "DONE" if program_opts.skip_sd_weights or program_opts.default_only else "NEXT"
|
||||
self.ok_button = self.add_widget_intelligent(
|
||||
CenteredButtonPress,
|
||||
name=label,
|
||||
@ -454,13 +451,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
self.editing = False
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
|
||||
def validate_field_values(self, opt: Namespace) -> bool:
|
||||
bad_fields = []
|
||||
if not opt.license_acceptance:
|
||||
bad_fields.append(
|
||||
"Please accept the license terms before proceeding to model downloads"
|
||||
)
|
||||
bad_fields.append("Please accept the license terms before proceeding to model downloads")
|
||||
if not Path(opt.outdir).parent.exists():
|
||||
bad_fields.append(
|
||||
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
||||
@ -478,11 +473,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"outdir",
|
||||
"free_gpu_mem",
|
||||
"max_cache_size",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
"outdir",
|
||||
"free_gpu_mem",
|
||||
"max_cache_size",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
]:
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
@ -495,7 +490,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
new_opts.hf_token = self.hf_token.value
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
|
||||
|
||||
return new_opts
|
||||
|
||||
|
||||
@ -534,19 +529,20 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
||||
editApp.run()
|
||||
return editApp.new_opts()
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
return opts
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
|
||||
try:
|
||||
installer = ModelInstall(config)
|
||||
except omegaconf.errors.ConfigKeyError:
|
||||
logger.warning('Your models.yaml file is corrupt or out of date. Reinitializing')
|
||||
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
|
||||
initialize_rootdir(config.root_path, True)
|
||||
installer = ModelInstall(config)
|
||||
|
||||
|
||||
models = installer.all_models()
|
||||
return InstallSelections(
|
||||
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
|
||||
@ -556,55 +552,46 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
else list(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
logger.info("Initializing InvokeAI runtime directory")
|
||||
for name in (
|
||||
"models",
|
||||
"databases",
|
||||
"text-inversion-output",
|
||||
"text-inversion-training-data",
|
||||
"configs"
|
||||
):
|
||||
for name in ("models", "databases", "text-inversion-output", "text-inversion-training-data", "configs"):
|
||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||
for model_type in ModelType:
|
||||
Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True)
|
||||
Path(root, "autoimport", model_type.value).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
configs_src = Path(configs.__path__[0])
|
||||
configs_dest = root / "configs"
|
||||
if not os.path.samefile(configs_src, configs_dest):
|
||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||
|
||||
dest = root / 'models'
|
||||
dest = root / "models"
|
||||
for model_base in BaseModelType:
|
||||
for model_type in ModelType:
|
||||
path = dest / model_base.value / model_type.value
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = dest / 'core'
|
||||
path = dest / "core"
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
maybe_create_models_yaml(root)
|
||||
|
||||
|
||||
def maybe_create_models_yaml(root: Path):
|
||||
models_yaml = root / 'configs' / 'models.yaml'
|
||||
models_yaml = root / "configs" / "models.yaml"
|
||||
if models_yaml.exists():
|
||||
if OmegaConf.load(models_yaml).get('__metadata__'): # up to date
|
||||
if OmegaConf.load(models_yaml).get("__metadata__"): # up to date
|
||||
return
|
||||
else:
|
||||
logger.info('Creating new models.yaml, original saved as models.yaml.orig')
|
||||
models_yaml.rename(models_yaml.parent / 'models.yaml.orig')
|
||||
|
||||
with open(models_yaml,'w') as yaml_file:
|
||||
yaml_file.write(yaml.dump({'__metadata__':
|
||||
{'version':'3.0.0'}
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Creating new models.yaml, original saved as models.yaml.orig")
|
||||
models_yaml.rename(models_yaml.parent / "models.yaml.orig")
|
||||
|
||||
with open(models_yaml, "w") as yaml_file:
|
||||
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(
|
||||
program_opts: Namespace, initfile: Path = None
|
||||
) -> (Namespace, Namespace):
|
||||
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
|
||||
# parse_args() will read from init file if present
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
@ -616,8 +603,9 @@ def run_console_ui(
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
import torch
|
||||
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
@ -634,39 +622,42 @@ def write_opts(opts: Namespace, init_file: Path):
|
||||
# this will load current settings
|
||||
new_config = InvokeAIAppConfig.get_config()
|
||||
new_config.root = config.root
|
||||
|
||||
for key,value in opts.__dict__.items():
|
||||
if hasattr(new_config,key):
|
||||
setattr(new_config,key,value)
|
||||
|
||||
with open(init_file,'w', encoding='utf-8') as file:
|
||||
for key, value in opts.__dict__.items():
|
||||
if hasattr(new_config, key):
|
||||
setattr(new_config, key, value)
|
||||
|
||||
with open(init_file, "w", encoding="utf-8") as file:
|
||||
file.write(new_config.to_yaml())
|
||||
|
||||
if hasattr(opts,'hf_token') and opts.hf_token:
|
||||
if hasattr(opts, "hf_token") and opts.hf_token:
|
||||
HfLogin(opts.hf_token)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def default_output_dir() -> Path:
|
||||
return config.root_path / "outputs"
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
opt = default_startup_options(initfile)
|
||||
write_opts(opt, initfile)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
# Here we bring in
|
||||
# the legacy Args object in order to parse
|
||||
# the old init file and write out the new
|
||||
# yaml format.
|
||||
def migrate_init_file(legacy_format:Path):
|
||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||
def migrate_init_file(legacy_format: Path):
|
||||
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
|
||||
new = InvokeAIAppConfig.get_config()
|
||||
|
||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||
for attr in fields:
|
||||
if hasattr(old,attr):
|
||||
setattr(new,attr,getattr(old,attr))
|
||||
if hasattr(old, attr):
|
||||
setattr(new, attr, getattr(old, attr))
|
||||
|
||||
# a few places where the field names have changed and we have to
|
||||
# manually add in the new names/values
|
||||
@ -674,40 +665,43 @@ def migrate_init_file(legacy_format:Path):
|
||||
new.conf_path = old.conf
|
||||
new.root = legacy_format.parent.resolve()
|
||||
|
||||
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
|
||||
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
|
||||
invokeai_yaml = legacy_format.parent / "invokeai.yaml"
|
||||
with open(invokeai_yaml, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(new.to_yaml())
|
||||
|
||||
legacy_format.replace(legacy_format.parent / 'invokeai.init.orig')
|
||||
legacy_format.replace(legacy_format.parent / "invokeai.init.orig")
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def migrate_models(root: Path):
|
||||
from invokeai.backend.install.migrate_to_3 import do_migrate
|
||||
|
||||
do_migrate(root, root)
|
||||
|
||||
def migrate_if_needed(opt: Namespace, root: Path)->bool:
|
||||
# We check for to see if the runtime directory is correctly initialized.
|
||||
old_init_file = root / 'invokeai.init'
|
||||
new_init_file = root / 'invokeai.yaml'
|
||||
old_hub = root / 'models/hub'
|
||||
migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
|
||||
|
||||
if migration_needed:
|
||||
if opt.yes_to_all or \
|
||||
yes_or_no(f'{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?'):
|
||||
|
||||
logger.info('** Migrating invokeai.init to invokeai.yaml')
|
||||
def migrate_if_needed(opt: Namespace, root: Path) -> bool:
|
||||
# We check for to see if the runtime directory is correctly initialized.
|
||||
old_init_file = root / "invokeai.init"
|
||||
new_init_file = root / "invokeai.yaml"
|
||||
old_hub = root / "models/hub"
|
||||
migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
|
||||
|
||||
if migration_needed:
|
||||
if opt.yes_to_all or yes_or_no(
|
||||
f"{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?"
|
||||
):
|
||||
logger.info("** Migrating invokeai.init to invokeai.yaml")
|
||||
migrate_init_file(old_init_file)
|
||||
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||
config.parse_args(argv=[], conf=OmegaConf.load(new_init_file))
|
||||
|
||||
if old_hub.exists():
|
||||
migrate_models(config.root_path)
|
||||
else:
|
||||
print('Cannot continue without conversion. Aborting.')
|
||||
|
||||
print("Cannot continue without conversion. Aborting.")
|
||||
|
||||
return migration_needed
|
||||
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
@ -764,9 +758,9 @@ def main():
|
||||
|
||||
invoke_args = []
|
||||
if opt.root:
|
||||
invoke_args.extend(['--root',opt.root])
|
||||
invoke_args.extend(["--root", opt.root])
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(['--precision','float32'])
|
||||
invoke_args.extend(["--precision", "float32"])
|
||||
config.parse_args(invoke_args)
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
@ -782,22 +776,18 @@ def main():
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
|
||||
models_to_download = default_user_selections(opt)
|
||||
new_init_file = config.root_path / 'invokeai.yaml'
|
||||
new_init_file = config.root_path / "invokeai.yaml"
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(
|
||||
precision="float32" if opt.full_precision else "float16"
|
||||
)
|
||||
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
if init_options:
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
logger.info(
|
||||
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
||||
)
|
||||
logger.info('\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n')
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if opt.skip_support_models:
|
||||
logger.info("Skipping support models at user's request")
|
||||
else:
|
||||
@ -811,7 +801,7 @@ def main():
|
||||
|
||||
postscript(errors=errors)
|
||||
if not opt.yes_to_all:
|
||||
input('Press any key to continue...')
|
||||
input("Press any key to continue...")
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye! Come back soon.")
|
||||
|
||||
|
Reference in New Issue
Block a user