mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[enhancement] Reorganize form for textual inversion training
- Add num_train_epochs - Reorganize widgets so all sliders that control # of steps are together
This commit is contained in:
parent
ce17051b28
commit
9b1843307b
@ -115,6 +115,14 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
value=self.precisions.index(saved_args.get('mixed_precision','fp16')),
|
value=self.precisions.index(saved_args.get('mixed_precision','fp16')),
|
||||||
max_height=4,
|
max_height=4,
|
||||||
)
|
)
|
||||||
|
self.num_train_epochs = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name='Number of training epochs:',
|
||||||
|
out_of=1000,
|
||||||
|
step=50,
|
||||||
|
lowest=1,
|
||||||
|
value=saved_args.get('num_train_epochs',100)
|
||||||
|
)
|
||||||
self.max_train_steps = self.add_widget_intelligent(
|
self.max_train_steps = self.add_widget_intelligent(
|
||||||
npyscreen.TitleSlider,
|
npyscreen.TitleSlider,
|
||||||
name='Max Training Steps:',
|
name='Max Training Steps:',
|
||||||
@ -131,6 +139,22 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
lowest=1,
|
lowest=1,
|
||||||
value=saved_args.get('train_batch_size',8),
|
value=saved_args.get('train_batch_size',8),
|
||||||
)
|
)
|
||||||
|
self.gradient_accumulation_steps = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name='Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):',
|
||||||
|
out_of=10,
|
||||||
|
step=1,
|
||||||
|
lowest=1,
|
||||||
|
value=saved_args.get('gradient_accumulation_steps',4)
|
||||||
|
)
|
||||||
|
self.lr_warmup_steps = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name='Warmup Steps:',
|
||||||
|
out_of=100,
|
||||||
|
step=1,
|
||||||
|
lowest=0,
|
||||||
|
value=saved_args.get('lr_warmup_steps',0),
|
||||||
|
)
|
||||||
self.learning_rate = self.add_widget_intelligent(
|
self.learning_rate = self.add_widget_intelligent(
|
||||||
npyscreen.TitleText,
|
npyscreen.TitleText,
|
||||||
name="Learning Rate:",
|
name="Learning Rate:",
|
||||||
@ -154,22 +178,6 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
scroll_exit = True,
|
scroll_exit = True,
|
||||||
value=self.lr_schedulers.index(saved_args.get('lr_scheduler','constant')),
|
value=self.lr_schedulers.index(saved_args.get('lr_scheduler','constant')),
|
||||||
)
|
)
|
||||||
self.gradient_accumulation_steps = self.add_widget_intelligent(
|
|
||||||
npyscreen.TitleSlider,
|
|
||||||
name='Gradient Accumulation Steps:',
|
|
||||||
out_of=10,
|
|
||||||
step=1,
|
|
||||||
lowest=1,
|
|
||||||
value=saved_args.get('gradient_accumulation_steps',4)
|
|
||||||
)
|
|
||||||
self.lr_warmup_steps = self.add_widget_intelligent(
|
|
||||||
npyscreen.TitleSlider,
|
|
||||||
name='Warmup Steps:',
|
|
||||||
out_of=100,
|
|
||||||
step=1,
|
|
||||||
lowest=0,
|
|
||||||
value=saved_args.get('lr_warmup_steps',0),
|
|
||||||
)
|
|
||||||
|
|
||||||
def initializer_changed(self):
|
def initializer_changed(self):
|
||||||
placeholder = self.placeholder_token.value
|
placeholder = self.placeholder_token.value
|
||||||
@ -236,7 +244,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
|
|
||||||
# all the integers
|
# all the integers
|
||||||
for attr in ('train_batch_size','gradient_accumulation_steps',
|
for attr in ('train_batch_size','gradient_accumulation_steps',
|
||||||
'max_train_steps','lr_warmup_steps'):
|
'num_train_epochs','max_train_steps','lr_warmup_steps'):
|
||||||
args[attr] = int(getattr(self,attr).value)
|
args[attr] = int(getattr(self,attr).value)
|
||||||
|
|
||||||
# the floats (just one)
|
# the floats (just one)
|
||||||
@ -324,6 +332,7 @@ if __name__ == '__main__':
|
|||||||
save_args(args)
|
save_args(args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
print(f'DEBUG: args = {args}')
|
||||||
do_textual_inversion_training(**args)
|
do_textual_inversion_training(**args)
|
||||||
copy_to_embeddings_folder(args)
|
copy_to_embeddings_folder(args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
Loading…
Reference in New Issue
Block a user