mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Port the command-line tools to use model_manager2 (#5546)
* Port the command-line tools to use model_manager2 1.Reimplement the following: - invokeai-model-install - invokeai-merge - invokeai-ti To avoid breaking the original modeal manager, the udpated tools have been renamed invokeai-model-install2 and invokeai-merge2. The textual inversion training script should continue to work with existing installations. The "starter" models now live in `invokeai/configs/INITIAL_MODELS2.yaml`. When the full model manager 2 is in place and working, I'll rename these files and commands. 2. Add the `merge` route to the web API. This will merge two or three models, resulting a new one. - Note that because the model installer selectively installs the `fp16` variant of models (rather than both 16- and 32-bit versions as previous), the diffusers merge script will choke on any huggingface diffuserse models that were downloaded with the new installer. Previously-downloaded models should continue to merge correctly. I have a PR upstream https://github.com/huggingface/diffusers/pull/6670 to fix this. 3. (more important!) During implementation of the CLI tools, found and fixed a number of small runtime bugs in the model_manager2 implementation: - During model database migration, if a registered models file was not found on disk, the migration would be aborted. Now the offending model is skipped with a log warning. - Caught and fixed a condition in which the installer would download the entire diffusers repo when the user provided a single `.safetensors` file URL. - Caught and fixed a condition in which the installer would raise an exception and stop the app when a request for an unknown model's metadata was passed to Civitai. Now an error is logged and the installer continues. - Replaced the LoWRA starter LoRA with FlatColor. The former has been removed from Civitai. * fix ruff issue --------- Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
@ -3,7 +3,7 @@
|
||||
"""
|
||||
This is the frontend to "textual_inversion_training.py".
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ import sys
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
@ -22,8 +22,9 @@ from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from ...backend.training import do_textual_inversion_training, parse_args
|
||||
from invokeai.backend.install.install_helper import initialize_installer
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.training import do_textual_inversion_training, parse_args
|
||||
|
||||
TRAINING_DATA = "text-inversion-training-data"
|
||||
TRAINING_DIR = "text-inversion-output"
|
||||
@ -44,19 +45,21 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
precisions = ["no", "fp16", "bf16"]
|
||||
learnable_properties = ["object", "style"]
|
||||
|
||||
def __init__(self, parentApp, name, saved_args=None):
|
||||
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None):
|
||||
self.saved_args = saved_args or {}
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
def afterEditing(self):
|
||||
def afterEditing(self) -> None:
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
def create(self) -> None:
|
||||
self.model_names, default = self.get_model_names()
|
||||
default_initializer_token = "★"
|
||||
default_placeholder_token = ""
|
||||
saved_args = self.saved_args
|
||||
|
||||
assert config is not None
|
||||
|
||||
try:
|
||||
default = self.model_names.index(saved_args["model"])
|
||||
except Exception:
|
||||
@ -71,7 +74,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
self.model = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Model Name:",
|
||||
values=self.model_names,
|
||||
values=sorted(self.model_names),
|
||||
value=default,
|
||||
max_height=len(self.model_names) + 1,
|
||||
scroll_exit=True,
|
||||
@ -236,7 +239,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
)
|
||||
self.model.editing = True
|
||||
|
||||
def initializer_changed(self):
|
||||
def initializer_changed(self) -> None:
|
||||
placeholder = self.placeholder_token.value
|
||||
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
|
||||
self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder)
|
||||
@ -275,10 +278,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
return True
|
||||
|
||||
def get_model_names(self) -> Tuple[List[str], int]:
|
||||
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
|
||||
model_names = [idx for idx in sorted(conf.keys()) if conf[idx].get("format", None) == "diffusers"]
|
||||
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
|
||||
default = defaults[0] if len(defaults) > 0 else 0
|
||||
global config
|
||||
assert config is not None
|
||||
installer = initialize_installer(config)
|
||||
store = installer.record_store
|
||||
main_models = store.search_by_attr(model_type=ModelType.Main)
|
||||
model_names = [f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"]
|
||||
default = 0
|
||||
return (model_names, default)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
@ -326,7 +332,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
|
||||
class MyApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, saved_args=None):
|
||||
def __init__(self, saved_args: Optional[Dict[str, str]] = None):
|
||||
super().__init__()
|
||||
self.ti_arguments = None
|
||||
self.saved_args = saved_args
|
||||
@ -341,11 +347,12 @@ class MyApplication(npyscreen.NPSAppManaged):
|
||||
)
|
||||
|
||||
|
||||
def copy_to_embeddings_folder(args: dict):
|
||||
def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
|
||||
"""
|
||||
Copy learned_embeds.bin into the embeddings folder, and offer to
|
||||
delete the full model and checkpoints.
|
||||
"""
|
||||
assert config is not None
|
||||
source = Path(args["output_dir"], "learned_embeds.bin")
|
||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||
destination = config.root_dir / "embeddings" / dest_dir_name
|
||||
@ -358,10 +365,11 @@ def copy_to_embeddings_folder(args: dict):
|
||||
logger.info(f'Keeping {args["output_dir"]}')
|
||||
|
||||
|
||||
def save_args(args: dict):
|
||||
def save_args(args: dict) -> None:
|
||||
"""
|
||||
Save the current argument values to an omegaconf file
|
||||
"""
|
||||
assert config is not None
|
||||
dest_dir = config.root_dir / TRAINING_DIR
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
conf_file = dest_dir / CONF_FILE
|
||||
@ -373,6 +381,7 @@ def previous_args() -> dict:
|
||||
"""
|
||||
Get the previous arguments used.
|
||||
"""
|
||||
assert config is not None
|
||||
conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
|
||||
try:
|
||||
conf = OmegaConf.load(conf_file)
|
||||
@ -383,24 +392,26 @@ def previous_args() -> dict:
|
||||
return conf
|
||||
|
||||
|
||||
def do_front_end(args: Namespace):
|
||||
def do_front_end() -> None:
|
||||
global config
|
||||
saved_args = previous_args()
|
||||
myapplication = MyApplication(saved_args=saved_args)
|
||||
myapplication.run()
|
||||
|
||||
if args := myapplication.ti_arguments:
|
||||
os.makedirs(args["output_dir"], exist_ok=True)
|
||||
if my_args := myapplication.ti_arguments:
|
||||
os.makedirs(my_args["output_dir"], exist_ok=True)
|
||||
|
||||
# Automatically add angle brackets around the trigger
|
||||
if not re.match("^<.+>$", args["placeholder_token"]):
|
||||
args["placeholder_token"] = f"<{args['placeholder_token']}>"
|
||||
if not re.match("^<.+>$", my_args["placeholder_token"]):
|
||||
my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>"
|
||||
|
||||
args["only_save_embeds"] = True
|
||||
save_args(args)
|
||||
my_args["only_save_embeds"] = True
|
||||
save_args(my_args)
|
||||
|
||||
try:
|
||||
do_textual_inversion_training(InvokeAIAppConfig.get_config(), **args)
|
||||
copy_to_embeddings_folder(args)
|
||||
print(my_args)
|
||||
do_textual_inversion_training(config, **my_args)
|
||||
copy_to_embeddings_folder(my_args)
|
||||
except Exception as e:
|
||||
logger.error("An exception occurred during training. The exception was:")
|
||||
logger.error(str(e))
|
||||
@ -408,11 +419,12 @@ def do_front_end(args: Namespace):
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
global config
|
||||
|
||||
args = parse_args()
|
||||
args: Namespace = parse_args()
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args([])
|
||||
|
||||
# change root if needed
|
||||
if args.root_dir:
|
||||
@ -420,7 +432,7 @@ def main():
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
do_front_end(args)
|
||||
do_front_end()
|
||||
else:
|
||||
do_textual_inversion_training(config, **vars(args))
|
||||
except AssertionError as e:
|
||||
|
454
invokeai/frontend/training/textual_inversion2.py
Normal file
454
invokeai/frontend/training/textual_inversion2.py
Normal file
@ -0,0 +1,454 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
This is the frontend to "textual_inversion_training.py".
|
||||
|
||||
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.install.install_helper import initialize_installer
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.training import do_textual_inversion_training, parse_args
|
||||
|
||||
TRAINING_DATA = "text-inversion-training-data"
|
||||
TRAINING_DIR = "text-inversion-output"
|
||||
CONF_FILE = "preferences.conf"
|
||||
config = None
|
||||
|
||||
|
||||
class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
resolutions = [512, 768, 1024]
|
||||
lr_schedulers = [
|
||||
"linear",
|
||||
"cosine",
|
||||
"cosine_with_restarts",
|
||||
"polynomial",
|
||||
"constant",
|
||||
"constant_with_warmup",
|
||||
]
|
||||
precisions = ["no", "fp16", "bf16"]
|
||||
learnable_properties = ["object", "style"]
|
||||
|
||||
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None):
|
||||
self.saved_args = saved_args or {}
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
def afterEditing(self) -> None:
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self) -> None:
|
||||
self.model_names, default = self.get_model_names()
|
||||
default_initializer_token = "★"
|
||||
default_placeholder_token = ""
|
||||
saved_args = self.saved_args
|
||||
|
||||
assert config is not None
|
||||
|
||||
try:
|
||||
default = self.model_names.index(saved_args["model"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
|
||||
editable=False,
|
||||
)
|
||||
|
||||
self.model = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Model Name:",
|
||||
values=sorted(self.model_names),
|
||||
value=default,
|
||||
max_height=len(self.model_names) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.placeholder_token = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Trigger Term:",
|
||||
value="", # saved_args.get('placeholder_token',''), # to restore previous term
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.placeholder_token.when_value_edited = self.initializer_changed
|
||||
self.nextrely -= 1
|
||||
self.nextrelx += 30
|
||||
self.prompt_token = self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
name="Trigger term for use in prompt",
|
||||
value="",
|
||||
editable=False,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrelx -= 30
|
||||
self.initializer_token = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Initializer:",
|
||||
value=saved_args.get("initializer_token", default_initializer_token),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.resume_from_checkpoint = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Resume from last saved checkpoint",
|
||||
value=False,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.learnable_property = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Learnable property:",
|
||||
values=self.learnable_properties,
|
||||
value=self.learnable_properties.index(saved_args.get("learnable_property", "object")),
|
||||
max_height=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.train_data_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="Data Training Directory:",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
value=str(
|
||||
saved_args.get(
|
||||
"train_data_dir",
|
||||
config.root_dir / TRAINING_DATA / default_placeholder_token,
|
||||
)
|
||||
),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.output_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="Output Destination Directory:",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
value=str(
|
||||
saved_args.get(
|
||||
"output_dir",
|
||||
config.root_dir / TRAINING_DIR / default_placeholder_token,
|
||||
)
|
||||
),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.resolution = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Image resolution (pixels):",
|
||||
values=self.resolutions,
|
||||
value=self.resolutions.index(saved_args.get("resolution", 512)),
|
||||
max_height=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.center_crop = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Center crop images before resizing to resolution",
|
||||
value=saved_args.get("center_crop", False),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.mixed_precision = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Mixed Precision:",
|
||||
values=self.precisions,
|
||||
value=self.precisions.index(saved_args.get("mixed_precision", "fp16")),
|
||||
max_height=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.num_train_epochs = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Number of training epochs:",
|
||||
out_of=1000,
|
||||
step=50,
|
||||
lowest=1,
|
||||
value=saved_args.get("num_train_epochs", 100),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.max_train_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Max Training Steps:",
|
||||
out_of=10000,
|
||||
step=500,
|
||||
lowest=1,
|
||||
value=saved_args.get("max_train_steps", 3000),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.train_batch_size = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Batch Size (reduce if you run out of memory):",
|
||||
out_of=50,
|
||||
step=1,
|
||||
lowest=1,
|
||||
value=saved_args.get("train_batch_size", 8),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.gradient_accumulation_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
|
||||
out_of=10,
|
||||
step=1,
|
||||
lowest=1,
|
||||
value=saved_args.get("gradient_accumulation_steps", 4),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.lr_warmup_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Warmup Steps:",
|
||||
out_of=100,
|
||||
step=1,
|
||||
lowest=0,
|
||||
value=saved_args.get("lr_warmup_steps", 0),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.learning_rate = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Learning Rate:",
|
||||
value=str(
|
||||
saved_args.get("learning_rate", "5.0e-04"),
|
||||
),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.scale_lr = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Scale learning rate by number GPUs, steps and batch size",
|
||||
value=saved_args.get("scale_lr", True),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Use xformers acceleration",
|
||||
value=saved_args.get("enable_xformers_memory_efficient_attention", False),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.lr_scheduler = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Learning rate scheduler:",
|
||||
values=self.lr_schedulers,
|
||||
max_height=7,
|
||||
value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.model.editing = True
|
||||
|
||||
def initializer_changed(self) -> None:
|
||||
placeholder = self.placeholder_token.value
|
||||
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
|
||||
self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder)
|
||||
self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder)
|
||||
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
self.parentApp.ti_arguments = self.marshall_arguments()
|
||||
npyscreen.notify("Launching textual inversion training. This will take a while...")
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def ok_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def validate_field_values(self) -> bool:
|
||||
bad_fields = []
|
||||
if self.model.value is None:
|
||||
bad_fields.append("Model Name must correspond to a known model in models.yaml")
|
||||
if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value):
|
||||
bad_fields.append("Trigger term must only contain alphanumeric characters, the dot and hyphen")
|
||||
if self.train_data_dir.value is None:
|
||||
bad_fields.append("Data Training Directory cannot be empty")
|
||||
if self.output_dir.value is None:
|
||||
bad_fields.append("The Output Destination Directory cannot be empty")
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:"
|
||||
for problem in bad_fields:
|
||||
message += f"\n* {problem}"
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self) -> Tuple[List[str], int]:
|
||||
global config
|
||||
assert config is not None
|
||||
installer = initialize_installer(config)
|
||||
store = installer.record_store
|
||||
main_models = store.search_by_attr(model_type=ModelType.Main)
|
||||
model_names = [f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"]
|
||||
default = 0
|
||||
return (model_names, default)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
args = {}
|
||||
|
||||
# the choices
|
||||
args.update(
|
||||
model=self.model_names[self.model.value[0]],
|
||||
resolution=self.resolutions[self.resolution.value[0]],
|
||||
lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
|
||||
mixed_precision=self.precisions[self.mixed_precision.value[0]],
|
||||
learnable_property=self.learnable_properties[self.learnable_property.value[0]],
|
||||
)
|
||||
|
||||
# all the strings and booleans
|
||||
for attr in (
|
||||
"initializer_token",
|
||||
"placeholder_token",
|
||||
"train_data_dir",
|
||||
"output_dir",
|
||||
"scale_lr",
|
||||
"center_crop",
|
||||
"enable_xformers_memory_efficient_attention",
|
||||
):
|
||||
args[attr] = getattr(self, attr).value
|
||||
|
||||
# all the integers
|
||||
for attr in (
|
||||
"train_batch_size",
|
||||
"gradient_accumulation_steps",
|
||||
"num_train_epochs",
|
||||
"max_train_steps",
|
||||
"lr_warmup_steps",
|
||||
):
|
||||
args[attr] = int(getattr(self, attr).value)
|
||||
|
||||
# the floats (just one)
|
||||
args.update(learning_rate=float(self.learning_rate.value))
|
||||
|
||||
# a special case
|
||||
if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
|
||||
args["resume_from_checkpoint"] = "latest"
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class MyApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, saved_args: Optional[Dict[str, str]] = None):
|
||||
super().__init__()
|
||||
self.ti_arguments = None
|
||||
self.saved_args = saved_args
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main = self.addForm(
|
||||
"MAIN",
|
||||
textualInversionForm,
|
||||
name="Textual Inversion Settings",
|
||||
saved_args=self.saved_args,
|
||||
)
|
||||
|
||||
|
||||
def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
|
||||
"""
|
||||
Copy learned_embeds.bin into the embeddings folder, and offer to
|
||||
delete the full model and checkpoints.
|
||||
"""
|
||||
assert config is not None
|
||||
source = Path(args["output_dir"], "learned_embeds.bin")
|
||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||
destination = config.root_dir / "embeddings" / dest_dir_name
|
||||
os.makedirs(destination, exist_ok=True)
|
||||
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||
shutil.copy(source, destination)
|
||||
if (input("Delete training logs and intermediate checkpoints? [y] ") or "y").startswith(("y", "Y")):
|
||||
shutil.rmtree(Path(args["output_dir"]))
|
||||
else:
|
||||
logger.info(f'Keeping {args["output_dir"]}')
|
||||
|
||||
|
||||
def save_args(args: dict) -> None:
|
||||
"""
|
||||
Save the current argument values to an omegaconf file
|
||||
"""
|
||||
assert config is not None
|
||||
dest_dir = config.root_dir / TRAINING_DIR
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
conf_file = dest_dir / CONF_FILE
|
||||
conf = OmegaConf.create(args)
|
||||
OmegaConf.save(config=conf, f=conf_file)
|
||||
|
||||
|
||||
def previous_args() -> dict:
|
||||
"""
|
||||
Get the previous arguments used.
|
||||
"""
|
||||
assert config is not None
|
||||
conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
|
||||
try:
|
||||
conf = OmegaConf.load(conf_file)
|
||||
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
|
||||
except Exception:
|
||||
conf = None
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
def do_front_end() -> None:
|
||||
global config
|
||||
saved_args = previous_args()
|
||||
myapplication = MyApplication(saved_args=saved_args)
|
||||
myapplication.run()
|
||||
|
||||
if my_args := myapplication.ti_arguments:
|
||||
os.makedirs(my_args["output_dir"], exist_ok=True)
|
||||
|
||||
# Automatically add angle brackets around the trigger
|
||||
if not re.match("^<.+>$", my_args["placeholder_token"]):
|
||||
my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>"
|
||||
|
||||
my_args["only_save_embeds"] = True
|
||||
save_args(my_args)
|
||||
|
||||
try:
|
||||
print(my_args)
|
||||
do_textual_inversion_training(config, **my_args)
|
||||
copy_to_embeddings_folder(my_args)
|
||||
except Exception as e:
|
||||
logger.error("An exception occurred during training. The exception was:")
|
||||
logger.error(str(e))
|
||||
logger.error("DETAILS:")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def main() -> None:
|
||||
global config
|
||||
|
||||
args: Namespace = parse_args()
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args([])
|
||||
|
||||
# change root if needed
|
||||
if args.root_dir:
|
||||
config.root = args.root_dir
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
do_front_end()
|
||||
else:
|
||||
do_textual_inversion_training(config, **vars(args))
|
||||
except AssertionError as e:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except (widget.NotEnoughSpaceForWidget, Exception) as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
logger.error("You need to have at least one diffusers models defined in models.yaml in order to train")
|
||||
elif str(e).startswith("addwstr"):
|
||||
logger.error("Not enough window space for the interface. Please make your window larger and try again.")
|
||||
else:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user