Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Training optimizations (#217)
* Optimizations to the training model

  Based on the changes made in textual_inversion, I carried over the relevant changes that improve model training. These changes reduce the amount of memory used, significantly improve the speed at which training runs, and improve the quality of the results. They also fix the problem where the model trainer wouldn't automatically stop when it hit the set number of steps.

* Update main.py

  Cleaned up whitespace.
This commit is contained in:
  parent  d126db2413
  commit  4fad71cd8c
Training YAML config (file name lost in page extraction; the hunks correspond to the textual-inversion finetune config):

@@ -52,7 +52,7 @@ model:
       ddconfig:
         double_z: true
         z_channels: 4
-        resolution: 256
+        resolution: 512
         in_channels: 3
         out_ch: 3
         ch: 128
@@ -73,7 +73,7 @@ model:
 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 2
+    batch_size: 1
     num_workers: 16
     wrap: false
     train:
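The data block above follows the target/params convention used throughout these configs: target names a class by dotted path, and params become its constructor arguments. A minimal sketch of that instantiation pattern, as a simplified stand-in for the instantiate_from_config helper the repo uses (the real DataModuleFromConfig lives in main.py):

    import importlib

    def instantiate_from_config(config):
        # 'target' is a dotted path such as 'main.DataModuleFromConfig';
        # 'params' (if present) becomes the constructor's keyword arguments.
        module_name, cls_name = config['target'].rsplit('.', 1)
        cls = getattr(importlib.import_module(module_name), cls_name)
        return cls(**config.get('params', {}))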
@@ -92,6 +92,9 @@ data:
         repeats: 10
 
 lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 500
   callbacks:
     image_logger:
       target: main.ImageLogger
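The added modelcheckpoint block feeds PyTorch Lightning's ModelCheckpoint callback, so checkpoints are written every 500 optimizer steps instead of only at epoch boundaries. A minimal sketch of the equivalent direct call; the dirpath value is illustrative, not from the config:

    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_cb = ModelCheckpoint(
        dirpath='logs/checkpoints',   # illustrative path, not from the config
        every_n_train_steps=500,      # the value added by this commit
    )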
main.py (26 lines changed):
@@ -171,8 +171,8 @@ def get_parser(**parser_kwargs):
         help='Initialize embedding manager from a checkpoint',
     )
     parser.add_argument(
-        '--placeholder_tokens', type=str, nargs='+', default=['*']
-    )
+        '--placeholder_tokens', type=str, nargs='+', default=['*'],
+        help='Placeholder token which will be used to denote the concept in future prompts')
 
     parser.add_argument(
         '--init_word',
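Because --placeholder_tokens is declared with nargs='+', it always parses to a list, even for a single token. A quick standalone illustration (not the full get_parser() from main.py):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--placeholder_tokens', type=str, nargs='+', default=['*'],
        help='Placeholder token which will be used to denote the concept in future prompts')

    args = parser.parse_args(['--placeholder_tokens', '*'])
    print(args.placeholder_tokens)   # ['*'] -- a list even for one token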
(The +/- markers for this hunk were lost in extraction; the changed line appears to be the whitespace cleanup mentioned in the commit message.)

@@ -473,7 +473,7 @@ class ImageLogger(Callback):
             self.check_frequency(check_idx)
             and hasattr(  # batch_idx % self.batch_freq == 0
                 pl_module, 'log_images'
-            )
+            )
             and callable(pl_module.log_images)
             and self.max_images > 0
         ):
@@ -569,6 +569,21 @@ class CUDACallback(Callback):
         except AttributeError:
             pass
 
+class ModeSwapCallback(Callback):
+
+    def __init__(self, swap_step=2000):
+        super().__init__()
+        self.is_frozen = False
+        self.swap_step = swap_step
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        if trainer.global_step < self.swap_step and not self.is_frozen:
+            self.is_frozen = True
+            trainer.optimizers = [pl_module.configure_opt_embedding()]
+
+        if trainer.global_step > self.swap_step and self.is_frozen:
+            self.is_frozen = False
+            trainer.optimizers = [pl_module.configure_opt_model()]
 
 if __name__ == '__main__':
     # custom parser to specify config files, train, test and debug mode,
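ModeSwapCallback swaps the active optimizer at a step threshold: before swap_step it trains only the embedding, afterwards the full model. A hedged usage sketch; it assumes the LightningModule provides the configure_opt_embedding() and configure_opt_model() methods the callback calls, and 'model' and 'data' stand in for objects built elsewhere in main.py:

    from pytorch_lightning import Trainer

    trainer = Trainer(
        max_steps=4000,
        callbacks=[ModeSwapCallback(swap_step=2000)],  # defined above in main.py
    )
    # trainer.fit(model, data)  # embedding-only until step 2000, full model after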
@@ -663,6 +678,7 @@ if __name__ == '__main__':
         if opt.datadir_in_name:
             now = os.path.basename(os.path.normpath(opt.data_root)) + now
 
+
         nowname = now + name + opt.postfix
         logdir = os.path.join(opt.logdir, nowname)
 
@@ -756,7 +772,7 @@ if __name__ == '__main__':
         if hasattr(model, 'monitor'):
             print(f'Monitoring {model.monitor} as checkpoint metric.')
             default_modelckpt_cfg['params']['monitor'] = model.monitor
-            default_modelckpt_cfg['params']['save_top_k'] = 3
+            default_modelckpt_cfg['params']['save_top_k'] = 1
 
         if 'modelcheckpoint' in lightning_config:
             modelckpt_cfg = lightning_config.modelcheckpoint
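This is also where the YAML addition from the config hunk above lands: main.py merges lightning_config.modelcheckpoint over these defaults, so every_n_train_steps: 500 from the config is combined with the save_top_k = 1 set here. A simplified standalone sketch of that merge with OmegaConf (the monitor value is illustrative):

    from omegaconf import OmegaConf

    default_modelckpt_cfg = OmegaConf.create(
        {'params': {'monitor': 'val/loss', 'save_top_k': 1}}  # 'val/loss' is illustrative
    )
    lightning_config = OmegaConf.create(
        {'modelcheckpoint': {'params': {'every_n_train_steps': 500}}}
    )
    if 'modelcheckpoint' in lightning_config:
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, lightning_config.modelcheckpoint)
    print(OmegaConf.to_yaml(modelckpt_cfg))  # both keys survive the merge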
@@ -846,7 +862,7 @@ if __name__ == '__main__':
         trainer_kwargs['callbacks'] = [
             instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
         ]
-        trainer_kwargs['max_steps'] = opt.max_steps
+        trainer_kwargs['max_steps'] = trainer_opt.max_steps
 
         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir  ###
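This one-line change is the fix the commit message credits for the trainer stopping on its own: max_steps is now read from trainer_opt, which carries the merged trainer configuration, so the configured step count actually reaches the Trainer. A minimal illustration with a hypothetical value:

    from pytorch_lightning import Trainer

    trainer_kwargs = {'max_steps': 6100}   # hypothetical value from trainer_opt
    trainer = Trainer(**trainer_kwargs)    # training halts by itself at step 6100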