#!/usr/bin/env python

"""
This is the frontend to "textual_inversion_training.py".

Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""

import os
import re
import shutil
import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import List, Tuple

import npyscreen
from diffusers.utils.import_utils import is_xformers_available
from npyscreen import widget
from omegaconf import OmegaConf

from ldm.invoke.globals import Globals, global_set_root
from ldm.invoke.training.textual_inversion_training import (
    do_textual_inversion_training,
    parse_args,
)
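
# Launch notes (a sketch; the exact flags come from parse_args() in
# textual_inversion_training.py and are assumptions here, not verified
# commands): running this script with the front-end flag opens the console
# form below, while multi-GPU/distributed training launches the underlying
# "textual_inversion_training.py" directly via Hugging Face's
# `accelerate launch`, as described in the project's training documentation.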

TRAINING_DATA = "text-inversion-training-data"
TRAINING_DIR = "text-inversion-output"
CONF_FILE = "preferences.conf"
XFORMERS_AVAILABLE = is_xformers_available()
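
# xformers' memory-efficient attention substantially reduces VRAM use, so when
# it is available the form below defaults to a larger training batch size
# (8 rather than 3).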


class textualInversionForm(npyscreen.FormMultiPageAction):
    resolutions = [512, 768, 1024]
    lr_schedulers = [
        "linear",
        "cosine",
        "cosine_with_restarts",
        "polynomial",
        "constant",
        "constant_with_warmup",
    ]
    precisions = ["no", "fp16", "bf16"]
    learnable_properties = ["object", "style"]

    def __init__(self, parentApp, name, saved_args=None):
        self.saved_args = saved_args or {}
        super().__init__(parentApp, name)

    def afterEditing(self):
        self.parentApp.setNextForm(None)

    def create(self):
        self.model_names, default = self.get_model_names()
        default_initializer_token = "★"
        default_placeholder_token = ""
        saved_args = self.saved_args

        try:
            default = self.model_names.index(saved_args["model"])
        except (KeyError, ValueError):
            # no saved model, or it is no longer listed in models.yaml
            pass

        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
            editable=False,
        )

        self.model = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Model Name:",
            values=self.model_names,
            value=default,
            max_height=len(self.model_names) + 1,
            scroll_exit=True,
        )
        self.placeholder_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Trigger Term:",
            value="",  # saved_args.get('placeholder_token',''), # to restore previous term
            scroll_exit=True,
        )
        self.placeholder_token.when_value_edited = self.initializer_changed
        self.nextrely -= 1
        self.nextrelx += 30
        self.prompt_token = self.add_widget_intelligent(
            npyscreen.FixedText,
            name="Trigger term for use in prompt",
            value="",
            editable=False,
            scroll_exit=True,
        )
        self.nextrelx -= 30
        self.initializer_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Initializer:",
            value=saved_args.get("initializer_token", default_initializer_token),
            scroll_exit=True,
        )
        self.resume_from_checkpoint = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Resume from last saved checkpoint",
            value=False,
            scroll_exit=True,
        )
        self.learnable_property = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learnable property:",
            values=self.learnable_properties,
            value=self.learnable_properties.index(
                saved_args.get("learnable_property", "object")
            ),
            max_height=4,
            scroll_exit=True,
        )
        self.train_data_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Data Training Directory:",
            select_dir=True,
            must_exist=False,
            value=str(
                saved_args.get(
                    "train_data_dir",
                    Path(Globals.root) / TRAINING_DATA / default_placeholder_token,
                )
            ),
            scroll_exit=True,
        )
        self.output_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Output Destination Directory:",
            select_dir=True,
            must_exist=False,
            value=str(
                saved_args.get(
                    "output_dir",
                    Path(Globals.root) / TRAINING_DIR / default_placeholder_token,
                )
            ),
            scroll_exit=True,
        )
        self.resolution = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Image resolution (pixels):",
            values=self.resolutions,
            value=self.resolutions.index(saved_args.get("resolution", 512)),
            max_height=4,
            scroll_exit=True,
        )
        self.center_crop = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Center crop images before resizing to resolution",
            value=saved_args.get("center_crop", False),
            scroll_exit=True,
        )
        self.mixed_precision = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Mixed Precision:",
            values=self.precisions,
            value=self.precisions.index(saved_args.get("mixed_precision", "fp16")),
            max_height=4,
            scroll_exit=True,
        )
        self.num_train_epochs = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Number of training epochs:",
            out_of=1000,
            step=50,
            lowest=1,
            value=saved_args.get("num_train_epochs", 100),
            scroll_exit=True,
        )
        self.max_train_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Max Training Steps:",
            out_of=10000,
            step=500,
            lowest=1,
            value=saved_args.get("max_train_steps", 2500),
            scroll_exit=True,
        )
        # default batch size depends on whether xformers is available:
        # 8 with memory-efficient attention, 3 without
        self.train_batch_size = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Batch Size (reduce if you run out of memory):",
            out_of=50,
            step=1,
            lowest=1,
            value=saved_args.get("train_batch_size", 8 if XFORMERS_AVAILABLE else 3),
            scroll_exit=True,
        )
        self.gradient_accumulation_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
            out_of=10,
            step=1,
            lowest=1,
            value=saved_args.get("gradient_accumulation_steps", 4),
            scroll_exit=True,
        )
        self.lr_warmup_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Warmup Steps:",
            out_of=100,
            step=1,
            lowest=0,
            value=saved_args.get("lr_warmup_steps", 0),
            scroll_exit=True,
        )
        self.learning_rate = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Learning Rate:",
            value=str(
                saved_args.get("learning_rate", "5.0e-04"),
            ),
            scroll_exit=True,
        )
        self.scale_lr = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Scale learning rate by number of GPUs, steps and batch size",
            value=saved_args.get("scale_lr", True),
            scroll_exit=True,
        )
        self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Use xformers acceleration",
            value=saved_args.get(
                "enable_xformers_memory_efficient_attention", XFORMERS_AVAILABLE
            ),
            scroll_exit=True,
        )
        self.lr_scheduler = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learning rate scheduler:",
            values=self.lr_schedulers,
            max_height=7,
            value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")),
            scroll_exit=True,
        )
        self.model.editing = True

    def initializer_changed(self):
        placeholder = self.placeholder_token.value
        self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
        self.train_data_dir.value = str(
            Path(Globals.root) / TRAINING_DATA / placeholder
        )
        self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder)
        self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()

    def on_ok(self):
        if self.validate_field_values():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.ti_arguments = self.marshall_arguments()
            npyscreen.notify(
                "Launching textual inversion training. This will take a while..."
            )
        else:
            self.editing = True

    def on_cancel(self):
        # invoked by npyscreen when the user selects Cancel
        sys.exit(0)

    def validate_field_values(self) -> bool:
        bad_fields = []
        if self.model.value is None:
            bad_fields.append(
                "Model Name must correspond to a known model in models.yaml"
            )
        if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value):
            bad_fields.append(
                "Trigger term must only contain alphanumeric characters, the dot and hyphen"
            )
        if self.train_data_dir.value is None:
            bad_fields.append("Data Training Directory cannot be empty")
        if self.output_dir.value is None:
            bad_fields.append("The Output Destination Directory cannot be empty")
        if len(bad_fields) > 0:
            message = "The following problems were detected and must be corrected:"
            for problem in bad_fields:
                message += f"\n* {problem}"
            npyscreen.notify_confirm(message)
            return False
        else:
            return True

    def get_model_names(self) -> Tuple[List[str], int]:
        conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
        model_names = [
            idx
            for idx in sorted(list(conf.keys()))
            if conf[idx].get("format", None) == "diffusers"
        ]
        defaults = [
            idx
            for idx in range(len(model_names))
            if "default" in conf[model_names[idx]]
        ]
        default = defaults[0] if len(defaults) > 0 else 0
        return (model_names, default)
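
    # A models.yaml entry this method would pick up might look like the
    # following (a sketch inferred from the lookups above; the entry name is
    # hypothetical and other keys are omitted):
    #
    #   stable-diffusion-1.5:
    #     format: diffusers
    #     default: true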

    def marshall_arguments(self) -> dict:
        args = dict()

        # the choices
        args.update(
            model=self.model_names[self.model.value[0]],
            resolution=self.resolutions[self.resolution.value[0]],
            lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
            mixed_precision=self.precisions[self.mixed_precision.value[0]],
            learnable_property=self.learnable_properties[
                self.learnable_property.value[0]
            ],
        )

        # all the strings and booleans
        for attr in (
            "initializer_token",
            "placeholder_token",
            "train_data_dir",
            "output_dir",
            "scale_lr",
            "center_crop",
            "enable_xformers_memory_efficient_attention",
        ):
            args[attr] = getattr(self, attr).value

        # all the integers
        for attr in (
            "train_batch_size",
            "gradient_accumulation_steps",
            "num_train_epochs",
            "max_train_steps",
            "lr_warmup_steps",
        ):
            args[attr] = int(getattr(self, attr).value)

        # the floats (just one)
        args.update(learning_rate=float(self.learning_rate.value))

        # a special case
        if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
            args["resume_from_checkpoint"] = "latest"

        return args


class MyApplication(npyscreen.NPSAppManaged):
    def __init__(self, saved_args=None):
        super().__init__()
        self.ti_arguments = None
        self.saved_args = saved_args

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        self.main = self.addForm(
            "MAIN",
            textualInversionForm,
            name="Textual Inversion Settings",
            saved_args=self.saved_args,
        )
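

# Flow note: NPSAppManaged.run() displays the "MAIN" form defined above; when
# the user accepts it, on_ok() stashes the marshalled settings in
# ti_arguments, and do_front_end() (below) runs the training and then hands
# the result to copy_to_embeddings_folder().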


def copy_to_embeddings_folder(args: dict):
    """
    Copy learned_embeds.bin into the embeddings folder, and offer to
    delete the full model and checkpoints.
    """
    source = Path(args["output_dir"], "learned_embeds.bin")
    dest_dir_name = args["placeholder_token"].strip("<>")
    destination = Path(Globals.root, "embeddings", dest_dir_name)
    os.makedirs(destination, exist_ok=True)
    print(f">> Training completed. Copying learned_embeds.bin into {str(destination)}")
    shutil.copy(source, destination)
    if (
        input("Delete training logs and intermediate checkpoints? [y] ") or "y"
    ).startswith(("y", "Y")):
        shutil.rmtree(Path(args["output_dir"]))
    else:
        print(f'>> Keeping {args["output_dir"]}')


def save_args(args: dict):
    """
    Save the current argument values to an omegaconf file
    """
    dest_dir = Path(Globals.root) / TRAINING_DIR
    os.makedirs(dest_dir, exist_ok=True)
    conf_file = dest_dir / CONF_FILE
    conf = OmegaConf.create(args)
    OmegaConf.save(config=conf, f=conf_file)


def previous_args() -> dict:
    """
    Get the previous arguments used.
    """
    conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
    try:
        conf = OmegaConf.load(conf_file)
        conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
    except Exception:
        # no saved preferences yet, or the file is unreadable
        conf = None

    return conf
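

# The two helpers above round-trip the form values through
# text-inversion-output/preferences.conf, so that the next session starts
# with the previous settings pre-filled.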


def do_front_end(args: Namespace):
    saved_args = previous_args()
    myapplication = MyApplication(saved_args=saved_args)
    myapplication.run()

    if args := myapplication.ti_arguments:
        os.makedirs(args["output_dir"], exist_ok=True)

        # Automatically add angle brackets around the trigger
        if not re.match("^<.+>$", args["placeholder_token"]):
            args["placeholder_token"] = f"<{args['placeholder_token']}>"

        args["only_save_embeds"] = True
        save_args(args)

        try:
            do_textual_inversion_training(**args)
            copy_to_embeddings_folder(args)
        except Exception as e:
            print("** An exception occurred during training. The exception was:")
            print(str(e))
            print("** DETAILS:")
            print(traceback.format_exc())


def main():
    args = parse_args()
    global_set_root(args.root_dir or Globals.root)
    try:
        if args.front_end:
            do_front_end(args)
        else:
            do_textual_inversion_training(**vars(args))
    except AssertionError as e:
        print(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
        pass
    except (widget.NotEnoughSpaceForWidget, Exception) as e:
        if str(e).startswith("Height of 1 allocated"):
            print(
                "** You need to have at least one diffusers model defined in models.yaml in order to train"
            )
        elif str(e).startswith("addwstr"):
            print(
                "** Not enough window space for the interface. Please make your window larger and try again."
            )
        else:
            print(f"** An error has occurred: {str(e)}")
        sys.exit(-1)


if __name__ == "__main__":
    main()