mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
351 lines
13 KiB
Python
Executable File
351 lines
13 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
import npyscreen
|
|
import os
|
|
import sys
|
|
import re
|
|
import shutil
|
|
import traceback
|
|
import curses
|
|
from ldm.invoke.globals import Globals, global_set_root
|
|
from omegaconf import OmegaConf
|
|
from pathlib import Path
|
|
from typing import List
|
|
import argparse
|
|
|
|
# Sub-directory of the runtime root that holds per-trigger training images.
TRAINING_DATA = "text-inversion-training-data"
# Sub-directory of the runtime root that receives training output; the
# preferences file is stored here as well.
TRAINING_DIR = "text-inversion-output"
# File name used to remember the settings of the previous run.
CONF_FILE = "preferences.conf"
|
|
|
|
class textualInversionForm(npyscreen.FormMultiPageAction):
    """Multi-page npyscreen form that collects textual inversion settings.

    Widgets are pre-populated from ``saved_args`` (the values of a previous
    run) where present.  When the user confirms with OK and the fields
    validate, the collected values are stored as a dict on
    ``parentApp.ti_arguments``.
    """

    # Choices offered by the corresponding single-select widgets.
    resolutions = [512, 768, 1024]
    lr_schedulers = [
        "linear", "cosine", "cosine_with_restarts",
        "polynomial", "constant", "constant_with_warmup",
    ]
    precisions = ['no', 'fp16', 'bf16']
    learnable_properties = ['object', 'style']

    def __init__(self, parentApp, name, saved_args=None):
        # saved_args must be set before the base constructor runs: create()
        # reads it while building the widgets.
        self.saved_args = saved_args or {}
        super().__init__(parentApp, name)

    def afterEditing(self):
        """Leave the application once this form has been edited."""
        self.parentApp.setNextForm(None)

    @staticmethod
    def _choice_index(choices, wanted, fallback):
        """Return the index of ``wanted`` in ``choices``.

        Falls back to the index of ``fallback`` when ``wanted`` is not one of
        the choices (e.g. a stale or hand-edited preferences file), so that a
        bad saved value cannot crash form construction.
        """
        try:
            return choices.index(wanted)
        except ValueError:
            return choices.index(fallback)

    def create(self):
        """Build all widgets of the form, seeding them from ``saved_args``."""
        self.model_names, default = self.get_model_names()
        default_initializer_token = '★'
        default_placeholder_token = ''
        saved_args = self.saved_args

        # Prefer the previously used model when it is still available.
        saved_model = saved_args.get('model')
        if saved_model in self.model_names:
            default = self.model_names.index(saved_model)

        self.add_widget_intelligent(
            npyscreen.FixedText,
            value='Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.'
        )

        self.model = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Model Name:',
            values=self.model_names,
            value=default,
            max_height=len(self.model_names)+1
        )
        # The trigger term is deliberately not restored from saved_args; it
        # always starts out empty.
        self.placeholder_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name='Trigger Term:',
            value='',
        )
        self.placeholder_token.when_value_edited = self.initializer_changed
        # Place the hint text on the same row as the trigger term, shifted
        # 30 columns to the right, then restore the column offset.
        self.nextrely -= 1
        self.nextrelx += 30
        self.prompt_token = self.add_widget_intelligent(
            npyscreen.FixedText,
            name="Trigger term for use in prompt",
            value='',
        )
        self.nextrelx -= 30
        self.initializer_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name='Initializer:',
            value=saved_args.get('initializer_token', default_initializer_token),
        )
        self.resume_from_checkpoint = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Resume from last saved checkpoint",
            value=False,
        )
        self.learnable_property = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learnable property:",
            values=self.learnable_properties,
            value=self._choice_index(
                self.learnable_properties,
                saved_args.get('learnable_property', 'object'),
                'object',
            ),
            max_height=4,
        )
        self.train_data_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name='Data Training Directory:',
            select_dir=True,
            must_exist=False,
            value=str(saved_args.get('train_data_dir', Path(Globals.root) / TRAINING_DATA / default_placeholder_token))
        )
        self.output_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name='Output Destination Directory:',
            select_dir=True,
            must_exist=False,
            value=str(saved_args.get('output_dir', Path(Globals.root) / TRAINING_DIR / default_placeholder_token))
        )
        self.resolution = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Image resolution (pixels):',
            values=self.resolutions,
            value=self._choice_index(
                self.resolutions,
                saved_args.get('resolution', 512),
                512,
            ),
            scroll_exit=True,
            max_height=4,
        )
        self.center_crop = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Center crop images before resizing to resolution",
            value=saved_args.get('center_crop', False)
        )
        self.mixed_precision = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Mixed Precision:',
            values=self.precisions,
            value=self._choice_index(
                self.precisions,
                saved_args.get('mixed_precision', 'fp16'),
                'fp16',
            ),
            max_height=4,
        )
        self.num_train_epochs = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name='Number of training epochs:',
            out_of=1000,
            step=50,
            lowest=1,
            value=saved_args.get('num_train_epochs', 100)
        )
        self.max_train_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name='Max Training Steps:',
            out_of=10000,
            step=500,
            lowest=1,
            value=saved_args.get('max_train_steps', 3000)
        )
        self.train_batch_size = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name='Batch Size (reduce if you run out of memory):',
            out_of=50,
            step=1,
            lowest=1,
            value=saved_args.get('train_batch_size', 8),
        )
        self.gradient_accumulation_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name='Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):',
            out_of=10,
            step=1,
            lowest=1,
            value=saved_args.get('gradient_accumulation_steps', 4)
        )
        self.lr_warmup_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name='Warmup Steps:',
            out_of=100,
            step=1,
            lowest=0,
            value=saved_args.get('lr_warmup_steps', 0),
        )
        self.learning_rate = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Learning Rate:",
            value=str(saved_args.get('learning_rate', '5.0e-04'),)
        )
        self.scale_lr = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Scale learning rate by number GPUs, steps and batch size",
            value=saved_args.get('scale_lr', True),
        )
        self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Use xformers acceleration",
            value=saved_args.get('enable_xformers_memory_efficient_attention', False),
        )
        self.lr_scheduler = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Learning rate scheduler:',
            values=self.lr_schedulers,
            max_height=7,
            scroll_exit=True,
            value=self._choice_index(
                self.lr_schedulers,
                saved_args.get('lr_scheduler', 'constant'),
                'constant',
            ),
        )

    def initializer_changed(self):
        """Callback run when the trigger term widget is edited.

        Updates the prompt hint, points both directory fields at the
        sub-directory named after the trigger term, and ticks the "resume"
        checkbox when that output directory already exists.
        """
        placeholder = self.placeholder_token.value
        self.prompt_token.value = f'(Trigger by using <{placeholder}> in your prompts)'
        self.train_data_dir.value = str(Path(Globals.root) / TRAINING_DATA / placeholder)
        self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder)
        self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()

    def on_ok(self):
        """OK-button handler: validate, then hand the arguments to the app."""
        if self.validate_field_values():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.ti_arguments = self.marshall_arguments()
            npyscreen.notify('Launching textual inversion training. This will take a while...')
            # The module load takes a while, so we do it while the form and message are still up
            import ldm.invoke.textual_inversion_training
        else:
            # Keep the form open so the user can correct the reported fields.
            self.editing = True

    def ok_cancel(self):
        """Exit the program without training."""
        sys.exit(0)

    def on_cancel(self):
        """Cancel-button handler.

        npyscreen action forms invoke ``on_cancel`` (not ``ok_cancel``) when
        Cancel is pressed, so route it to the existing handler.
        """
        self.ok_cancel()

    def validate_field_values(self) -> bool:
        """Check the user's input and report every problem in one dialog.

        Returns True when all fields are acceptable, False otherwise.
        """
        bad_fields = []
        # A single-select with nothing chosen holds an empty list rather than
        # None, and there may be no models to choose from at all.
        if not self.model_names or self.model.value in (None, []):
            bad_fields.append('Model Name must correspond to a known model in models.yaml')
        if not re.match(r'^[a-zA-Z0-9.-]+$', self.placeholder_token.value or ''):
            bad_fields.append('Trigger term must only contain alphanumeric characters, the dot and hyphen')
        # Text/filename widgets yield '' when left blank, so test truthiness.
        if not self.train_data_dir.value:
            bad_fields.append('Data Training Directory cannot be empty')
        if not self.output_dir.value:
            bad_fields.append('The Output Destination Directory cannot be empty')
        # marshall_arguments() converts this field with float(); reject
        # unparsable input here instead of crashing there.
        try:
            float(self.learning_rate.value)
        except (TypeError, ValueError):
            bad_fields.append('Learning Rate must be a number (for example 5.0e-04)')

        if bad_fields:
            message = 'The following problems were detected and must be corrected:'
            for problem in bad_fields:
                message += f'\n* {problem}'
            npyscreen.notify_confirm(message)
            return False
        return True

    def get_model_names(self) -> (List[str], int):
        """Return the selectable model names and the index of the default one.

        Reads ``configs/models.yaml`` under ``Globals.root`` and keeps, sorted
        by name, the entries whose ``format`` is ``'diffusers'``.  The default
        index is that of the first kept entry containing a ``default`` key, or
        0 when no entry is marked as default.
        """
        conf = OmegaConf.load(os.path.join(Globals.root, 'configs/models.yaml'))
        model_names = [
            key for key in sorted(conf.keys())
            if conf[key].get('format', None) == 'diffusers'
        ]
        defaults = [
            position for position, key in enumerate(model_names)
            if 'default' in conf[key]
        ]
        return (model_names, defaults[0] if defaults else 0)

    def marshall_arguments(self) -> dict:
        """Collect the widget values into a dict of training arguments."""
        args = dict()

        # the choices
        args.update(
            model=self.model_names[self.model.value[0]],
            resolution=self.resolutions[self.resolution.value[0]],
            lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
            mixed_precision=self.precisions[self.mixed_precision.value[0]],
            learnable_property=self.learnable_properties[self.learnable_property.value[0]],
        )

        # all the strings and booleans
        for attr in ('initializer_token', 'placeholder_token', 'train_data_dir',
                     'output_dir', 'scale_lr', 'center_crop', 'enable_xformers_memory_efficient_attention'):
            args[attr] = getattr(self, attr).value

        # all the integers
        for attr in ('train_batch_size', 'gradient_accumulation_steps',
                     'num_train_epochs', 'max_train_steps', 'lr_warmup_steps'):
            args[attr] = int(getattr(self, attr).value)

        # the floats (just one)
        args.update(
            learning_rate=float(self.learning_rate.value)
        )

        # a special case: only request a resume when there is something on
        # disk to resume from
        if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
            args['resume_from_checkpoint'] = 'latest'

        return args
|
|
|
|
class MyApplication(npyscreen.NPSAppManaged):
    """npyscreen application hosting the textual inversion settings form.

    After the application has run, ``ti_arguments`` is either None or the
    dict of training arguments stored by the form.
    """

    def __init__(self, saved_args=None):
        super().__init__()
        # Values of a previous run, forwarded to the form to pre-fill widgets.
        self.saved_args = saved_args
        # Set by the form when the user confirms; stays None otherwise.
        self.ti_arguments = None

    def onStart(self):
        """Select the theme and register the single 'MAIN' form."""
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        self.main = self.addForm(
            'MAIN',
            textualInversionForm,
            name='Textual Inversion Settings',
            saved_args=self.saved_args,
        )
|
|
|
|
def copy_to_embeddings_folder(args: dict):
    '''
    Copy learned_embeds.bin from the training output directory into a
    folder named after the trigger term (angle brackets stripped) under
    <root>/embeddings, then ask whether the training output directory
    should be deleted. An empty answer counts as "yes".
    '''
    embeds_file = Path(args['output_dir'], 'learned_embeds.bin')
    trigger_name = args['placeholder_token'].strip('<>')
    target_dir = Path(Globals.root, 'embeddings', trigger_name)

    os.makedirs(target_dir, exist_ok=True)
    print(f'>> Training completed. Copying learned_embeds.bin into {str(target_dir)}')
    shutil.copy(embeds_file, target_dir)

    answer = input('Delete training logs and intermediate checkpoints? [y] ') or 'y'
    if not answer.startswith(('y', 'Y')):
        print(f'>> Keeping {args["output_dir"]}')
        return
    shutil.rmtree(Path(args['output_dir']))
|
|
|
|
def save_args(args: dict):
    '''
    Write the given argument values to the preferences file inside the
    training output directory, creating that directory if necessary.
    '''
    prefs_dir = Path(Globals.root) / TRAINING_DIR
    os.makedirs(prefs_dir, exist_ok=True)
    OmegaConf.save(config=OmegaConf.create(args), f=prefs_dir / CONF_FILE)
|
|
|
|
def previous_args() -> dict:
    '''
    Get the previous arguments used.

    Loads the preferences file from the training output directory and
    strips the angle brackets from its placeholder_token, if it has one.
    Returns the loaded OmegaConf configuration, or None when the file is
    missing or cannot be read.
    '''
    conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
    # Best effort: previous settings are a convenience, so any failure to
    # read them just means "no saved settings". Exception (rather than a
    # bare except) lets KeyboardInterrupt and SystemExit propagate.
    try:
        conf = OmegaConf.load(conf_file)
        token = conf.get('placeholder_token')
        # Only normalize the token when present; a file without it must not
        # cause the remaining saved settings to be thrown away.
        if isinstance(token, str):
            conf['placeholder_token'] = token.strip('<>')
    except Exception:
        conf = None

    return conf
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the runtime root, show the settings form,
    # then run training with the values the user confirmed.
    parser = argparse.ArgumentParser(description='InvokeAI textual inversion training')
    parser.add_argument(
        '--root_dir', '--root-dir',
        type=Path,
        default=Globals.root,
        help='Path to the invokeai runtime directory',
    )
    cli_args = parser.parse_args()
    global_set_root(cli_args.root_dir)

    saved_args = previous_args()
    myapplication = MyApplication(saved_args=saved_args)
    myapplication.run()

    # ti_arguments stays None unless the form was confirmed with OK.
    if ti_args := myapplication.ti_arguments:
        # Imported only when training will actually run; the module is slow
        # to load and is not needed when the user backed out of the form.
        from ldm.invoke.textual_inversion_training import do_textual_inversion_training

        os.makedirs(ti_args['output_dir'], exist_ok=True)

        # Automatically add angle brackets around the trigger
        if not re.match(r'^<.+>$', ti_args['placeholder_token']):
            ti_args['placeholder_token'] = f"<{ti_args['placeholder_token']}>"

        ti_args['only_save_embeds'] = True
        save_args(ti_args)

        try:
            print(f'DEBUG: args = {ti_args}')
            do_textual_inversion_training(**ti_args)
            copy_to_embeddings_folder(ti_args)
        except Exception as e:
            print('** An exception occurred during training. The exception was:')
            print(str(e))
            print('** DETAILS:')
            print(traceback.format_exc())
|