#!/usr/bin/env python

"""
This is the frontend to "textual_inversion_training.py".

Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
"""

import os
import re
import shutil
import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import npyscreen
from npyscreen import widget
from omegaconf import OmegaConf

import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import initialize_installer
from invokeai.backend.model_manager import ModelType
from invokeai.backend.training import do_textual_inversion_training, parse_args

TRAINING_DATA = "text-inversion-training-data"
TRAINING_DIR = "text-inversion-output"
CONF_FILE = "preferences.conf"

# Application configuration; assigned by main() before any form is built.
config = None


class textualInversionForm(npyscreen.FormMultiPageAction):
    """Form that collects the settings for a textual inversion training run."""

    # Choices offered by the select-one widgets; the widgets report the index
    # of the selected entry, which marshall_arguments() maps back to a value.
    resolutions = [512, 768, 1024]
    lr_schedulers = [
        "linear",
        "cosine",
        "cosine_with_restarts",
        "polynomial",
        "constant",
        "constant_with_warmup",
    ]
    precisions = ["no", "fp16", "bf16"]
    learnable_properties = ["object", "style"]

    def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None):
        """Remember previously saved settings (if any) and build the form."""
        # Must be set before super().__init__(), which is passed no saved_args
        # and is expected to build the widgets via create().
        self.saved_args = saved_args or {}
        super().__init__(parentApp, name)

    def afterEditing(self) -> None:
        """Tell the application there is no form to show after this one."""
        self.parentApp.setNextForm(None)

    def _saved_index(self, choices: list, key: str, fallback) -> int:
        """
        Return the position within `choices` of the saved value for `key`.

        Falls back to the position of `fallback` when nothing was saved or when
        the saved value is not one of `choices`, so that a stale preferences
        entry cannot raise ValueError while the form is being built.
        """
        saved = self.saved_args.get(key, fallback)
        return choices.index(saved if saved in choices else fallback)

    def create(self) -> None:
        """Create all the widgets, pre-populated from the saved settings."""
        model_names, default = self.get_model_names()
        # Sort once and keep the sorted list: the widget displays this list and
        # marshall_arguments() indexes into it, so both must share one order.
        self.model_names = sorted(model_names)
        default_initializer_token = "★"
        default_placeholder_token = ""
        saved_args = self.saved_args

        assert config is not None

        # Preselect the previously used model when it is still available.
        saved_model = saved_args.get("model")
        if saved_model in self.model_names:
            default = self.model_names.index(saved_model)

        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
            editable=False,
        )

        self.model = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Model Name:",
            values=self.model_names,
            value=default,
            max_height=len(self.model_names) + 1,
            scroll_exit=True,
        )
        self.placeholder_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Trigger Term:",
            # Starts blank; the saved placeholder_token is not restored here.
            value="",
            scroll_exit=True,
        )
        self.placeholder_token.when_value_edited = self.initializer_changed
        # Place the hint text on the same row as the trigger term, 30 columns in.
        self.nextrely -= 1
        self.nextrelx += 30
        self.prompt_token = self.add_widget_intelligent(
            npyscreen.FixedText,
            name="Trigger term for use in prompt",
            value="",
            editable=False,
            scroll_exit=True,
        )
        self.nextrelx -= 30
        self.initializer_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Initializer:",
            value=saved_args.get("initializer_token", default_initializer_token),
            scroll_exit=True,
        )
        self.resume_from_checkpoint = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Resume from last saved checkpoint",
            value=False,
            scroll_exit=True,
        )
        self.learnable_property = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learnable property:",
            values=self.learnable_properties,
            value=self._saved_index(self.learnable_properties, "learnable_property", "object"),
            max_height=4,
            scroll_exit=True,
        )
        self.train_data_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Data Training Directory:",
            select_dir=True,
            must_exist=False,
            value=str(
                saved_args.get(
                    "train_data_dir",
                    config.root_dir / TRAINING_DATA / default_placeholder_token,
                )
            ),
            scroll_exit=True,
        )
        self.output_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Output Destination Directory:",
            select_dir=True,
            must_exist=False,
            value=str(
                saved_args.get(
                    "output_dir",
                    config.root_dir / TRAINING_DIR / default_placeholder_token,
                )
            ),
            scroll_exit=True,
        )
        self.resolution = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Image resolution (pixels):",
            values=self.resolutions,
            value=self._saved_index(self.resolutions, "resolution", 512),
            max_height=4,
            scroll_exit=True,
        )
        self.center_crop = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Center crop images before resizing to resolution",
            value=saved_args.get("center_crop", False),
            scroll_exit=True,
        )
        self.mixed_precision = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Mixed Precision:",
            values=self.precisions,
            value=self._saved_index(self.precisions, "mixed_precision", "fp16"),
            max_height=4,
            scroll_exit=True,
        )
        self.num_train_epochs = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Number of training epochs:",
            out_of=1000,
            step=50,
            lowest=1,
            value=saved_args.get("num_train_epochs", 100),
            scroll_exit=True,
        )
        self.max_train_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Max Training Steps:",
            out_of=10000,
            step=500,
            lowest=1,
            value=saved_args.get("max_train_steps", 3000),
            scroll_exit=True,
        )
        self.train_batch_size = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Batch Size (reduce if you run out of memory):",
            out_of=50,
            step=1,
            lowest=1,
            value=saved_args.get("train_batch_size", 8),
            scroll_exit=True,
        )
        self.gradient_accumulation_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
            out_of=10,
            step=1,
            lowest=1,
            value=saved_args.get("gradient_accumulation_steps", 4),
            scroll_exit=True,
        )
        self.lr_warmup_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Warmup Steps:",
            out_of=100,
            step=1,
            lowest=0,
            value=saved_args.get("lr_warmup_steps", 0),
            scroll_exit=True,
        )
        self.learning_rate = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Learning Rate:",
            value=str(
                saved_args.get("learning_rate", "5.0e-04"),
            ),
            scroll_exit=True,
        )
        self.scale_lr = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Scale learning rate by number GPUs, steps and batch size",
            value=saved_args.get("scale_lr", True),
            scroll_exit=True,
        )
        self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Use xformers acceleration",
            value=saved_args.get("enable_xformers_memory_efficient_attention", False),
            scroll_exit=True,
        )
        self.lr_scheduler = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learning rate scheduler:",
            values=self.lr_schedulers,
            max_height=7,
            value=self._saved_index(self.lr_schedulers, "lr_scheduler", "constant"),
            scroll_exit=True,
        )
        self.model.editing = True

    def initializer_changed(self) -> None:
        """
        React to an edit of the trigger term.

        Updates the prompt hint, points the training-data and output
        directories at sub-directories named after the trigger term, and ticks
        "resume from checkpoint" when that output directory already exists.
        """
        assert config is not None
        placeholder = self.placeholder_token.value
        self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
        self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder)
        self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder)
        self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()

    def on_ok(self):
        """
        Handle the OK button.

        When the fields validate, store the collected arguments on the parent
        application as `ti_arguments` and stop editing; otherwise keep the form
        open so the user can correct the reported problems.
        """
        if self.validate_field_values():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.ti_arguments = self.marshall_arguments()
            npyscreen.notify("Launching textual inversion training. This will take a while...")
        else:
            self.editing = True

    def ok_cancel(self):
        """Handle the Cancel button by exiting the program with status 0."""
        sys.exit(0)

    def validate_field_values(self) -> bool:
        """
        Check the fields that marshall_arguments() cannot handle when invalid.

        Shows a confirmation dialog listing every problem found.

        Returns:
            True when all checked fields are acceptable, False otherwise.
        """
        bad_fields = []
        # An unselected list may be None or empty; either would break the
        # `self.model.value[0]` lookup in marshall_arguments().
        if not self.model.value:
            bad_fields.append("Model Name must correspond to a known model in invokeai.db")
        if not re.fullmatch(r"[a-zA-Z0-9.-]+", self.placeholder_token.value or ""):
            bad_fields.append("Trigger term must only contain alphanumeric characters, the dot and hyphen")
        if not self.train_data_dir.value:
            bad_fields.append("Data Training Directory cannot be empty")
        if not self.output_dir.value:
            bad_fields.append("The Output Destination Directory cannot be empty")
        # marshall_arguments() converts this field with float(); reject text
        # that would make that conversion raise.
        try:
            float(self.learning_rate.value)
        except (TypeError, ValueError):
            bad_fields.append("Learning Rate must be a number (for example 5.0e-04)")

        if not bad_fields:
            return True

        message = "\n".join(
            [
                "The following problems were detected and must be corrected:",
                *(f"* {problem}" for problem in bad_fields),
            ]
        )
        npyscreen.notify_confirm(message)
        return False

    def get_model_names(self) -> Tuple[List[str], int]:
        """
        List the main models stored in diffusers format.

        Returns:
            A tuple of (names formatted as "base/type/name", default index 0).
        """
        assert config is not None
        installer = initialize_installer(config)
        store = installer.record_store
        main_models = store.search_by_attr(model_type=ModelType.Main)
        model_names = [f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"]
        default = 0
        return (model_names, default)

    def marshall_arguments(self) -> dict:
        """
        Convert the widget values into a dict of training arguments.

        Returns:
            A dict keyed by argument name. "resume_from_checkpoint" is only
            present (as "latest") when the box is ticked and the output
            directory exists.
        """
        args = {}

        # the choices: select-one widgets hold a list whose first item is the
        # index of the chosen entry
        args.update(
            model=self.model_names[self.model.value[0]],
            resolution=self.resolutions[self.resolution.value[0]],
            lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
            mixed_precision=self.precisions[self.mixed_precision.value[0]],
            learnable_property=self.learnable_properties[self.learnable_property.value[0]],
        )

        # all the strings and booleans
        for attr in (
            "initializer_token",
            "placeholder_token",
            "train_data_dir",
            "output_dir",
            "scale_lr",
            "center_crop",
            "enable_xformers_memory_efficient_attention",
        ):
            args[attr] = getattr(self, attr).value

        # all the integers
        for attr in (
            "train_batch_size",
            "gradient_accumulation_steps",
            "num_train_epochs",
            "max_train_steps",
            "lr_warmup_steps",
        ):
            args[attr] = int(getattr(self, attr).value)

        # the floats (just one)
        args.update(learning_rate=float(self.learning_rate.value))

        # a special case
        if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
            args["resume_from_checkpoint"] = "latest"

        return args


class MyApplication(npyscreen.NPSAppManaged):
    """npyscreen application that shows the settings form and keeps its result."""

    def __init__(self, saved_args: Optional[Dict[str, str]] = None):
        """Store the previously saved settings to hand to the form."""
        super().__init__()
        # Filled in by textualInversionForm.on_ok(); stays None otherwise.
        self.ti_arguments = None
        self.saved_args = saved_args

    def onStart(self):
        """Set the theme and register the settings form as the main form."""
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        self.main = self.addForm(
            "MAIN",
            textualInversionForm,
            name="Textual Inversion Settings",
            saved_args=self.saved_args,
        )


def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
    """
    Copy learned_embeds.bin into the embeddings folder, and offer to delete the
    full model and checkpoints.
    """
    assert config is not None
    source = Path(args["output_dir"], "learned_embeds.bin")
    # The embeddings sub-directory is named after the trigger term without its
    # surrounding angle brackets.
    dest_dir_name = args["placeholder_token"].strip("<>")
    destination = config.root_dir / "embeddings" / dest_dir_name
    os.makedirs(destination, exist_ok=True)
    logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
    shutil.copy(source, destination)
    # An empty answer counts as "y".
    if (input("Delete training logs and intermediate checkpoints? [y] ") or "y").startswith(("y", "Y")):
        shutil.rmtree(Path(args["output_dir"]))
    else:
        logger.info(f'Keeping {args["output_dir"]}')


def save_args(args: dict) -> None:
    """
    Save the current argument values to an omegaconf file
    """
    assert config is not None
    dest_dir = config.root_dir / TRAINING_DIR
    os.makedirs(dest_dir, exist_ok=True)
    conf_file = dest_dir / CONF_FILE
    conf = OmegaConf.create(args)
    OmegaConf.save(config=conf, f=conf_file)


def previous_args() -> Optional[dict]:
    """
    Get the previous arguments used.

    Returns:
        The configuration loaded from the preferences file, with the angle
        brackets stripped from "placeholder_token", or None when it cannot be
        loaded.
    """
    assert config is not None
    conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
    try:
        conf = OmegaConf.load(conf_file)
        conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
    except Exception:
        # Best effort: a missing or unusable preferences file just means the
        # form starts from its defaults.
        conf = None

    return conf


def do_front_end() -> None:
    """
    Run the settings form and, if the user confirmed it, run the training.

    The chosen arguments are saved for next time before training starts.
    Exceptions raised by training or by the copy step are logged, not raised.
    """
    saved_args = previous_args()
    myapplication = MyApplication(saved_args=saved_args)
    myapplication.run()

    if my_args := myapplication.ti_arguments:
        os.makedirs(my_args["output_dir"], exist_ok=True)

        # Automatically add angle brackets around the trigger
        if not re.match("^<.+>$", my_args["placeholder_token"]):
            my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>"

        my_args["only_save_embeds"] = True
        save_args(my_args)

        try:
            print(my_args)
            do_textual_inversion_training(config, **my_args)
            copy_to_embeddings_folder(my_args)
        except Exception as e:
            logger.error("An exception occurred during training. The exception was:")
            logger.error(str(e))
            logger.error("DETAILS:")
            logger.error(traceback.format_exc())


def main() -> None:
    """
    Command-line entry point.

    Parses the arguments, sets up the global configuration, then either runs
    the interactive front end or starts training directly from the arguments.
    Exits with status -1 on an assertion failure or any other exception; a
    keyboard interrupt ends the program quietly.
    """
    global config

    args: Namespace = parse_args()
    config = InvokeAIAppConfig.get_config()
    config.parse_args([])

    # change root if needed
    if args.root_dir:
        config.root = args.root_dir

    try:
        if args.front_end:
            do_front_end()
        else:
            do_textual_inversion_training(config, **vars(args))
    except AssertionError as e:
        logger.error(e)
        sys.exit(-1)
    except KeyboardInterrupt:
        pass
    except (widget.NotEnoughSpaceForWidget, Exception) as e:
        # Recognise two known failure messages and replace them with advice
        # the user can act on; anything else is logged as-is.
        if str(e).startswith("Height of 1 allocated"):
            logger.error("You need to have at least one diffusers models defined in invokeai.db in order to train")
        elif str(e).startswith("addwstr"):
            logger.error("Not enough window space for the interface. Please make your window larger and try again.")
        else:
            logger.error(e)
        sys.exit(-1)


if __name__ == "__main__":
    main()