Training optimizations ()

* Optimizations to the training model

Based on the changes made in
textual_inversion, I carried over the relevant changes that improve model training. These changes reduce the amount of memory used, significantly improve the speed at which training runs, and improve the quality of the results.

It also fixes the problem where the model trainer wouldn't automatically stop when it hit the set number of steps.

* Update main.py

Cleaned up whitespace
This commit is contained in:
David Ford 2022-08-30 14:59:32 -05:00 committed by GitHub
parent d126db2413
commit 4fad71cd8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 7 deletions
configs/stable-diffusion
main.py

@@ -52,7 +52,7 @@ model:
ddconfig:
double_z: true
z_channels: 4
resolution: 256
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
@@ -73,7 +73,7 @@ model:
data:
target: main.DataModuleFromConfig
params:
batch_size: 2
batch_size: 1
num_workers: 16
wrap: false
train:
@@ -92,6 +92,9 @@ data:
repeats: 10
lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger

26
main.py

@@ -171,8 +171,8 @@ def get_parser(**parser_kwargs):
help='Initialize embedding manager from a checkpoint',
)
parser.add_argument(
'--placeholder_tokens', type=str, nargs='+', default=['*']
)
'--placeholder_tokens', type=str, nargs='+', default=['*'],
help='Placeholder token which will be used to denote the concept in future prompts')
parser.add_argument(
'--init_word',
@@ -473,7 +473,7 @@ class ImageLogger(Callback):
self.check_frequency(check_idx)
and hasattr( # batch_idx % self.batch_freq == 0
pl_module, 'log_images'
)
)
and callable(pl_module.log_images)
and self.max_images > 0
):
@@ -569,6 +569,21 @@ class CUDACallback(Callback):
except AttributeError:
pass
class ModeSwapCallback(Callback):
def __init__(self, swap_step=2000):
super().__init__()
self.is_frozen = False
self.swap_step = swap_step
def on_train_epoch_start(self, trainer, pl_module):
if trainer.global_step < self.swap_step and not self.is_frozen:
self.is_frozen = True
trainer.optimizers = [pl_module.configure_opt_embedding()]
if trainer.global_step > self.swap_step and self.is_frozen:
self.is_frozen = False
trainer.optimizers = [pl_module.configure_opt_model()]
if __name__ == '__main__':
# custom parser to specify config files, train, test and debug mode,
@@ -663,6 +678,7 @@ if __name__ == '__main__':
if opt.datadir_in_name:
now = os.path.basename(os.path.normpath(opt.data_root)) + now
nowname = now + name + opt.postfix
logdir = os.path.join(opt.logdir, nowname)
@@ -756,7 +772,7 @@ if __name__ == '__main__':
if hasattr(model, 'monitor'):
print(f'Monitoring {model.monitor} as checkpoint metric.')
default_modelckpt_cfg['params']['monitor'] = model.monitor
default_modelckpt_cfg['params']['save_top_k'] = 3
default_modelckpt_cfg['params']['save_top_k'] = 1
if 'modelcheckpoint' in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
@@ -846,7 +862,7 @@ if __name__ == '__main__':
trainer_kwargs['callbacks'] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
]
trainer_kwargs['max_steps'] = opt.max_steps
trainer_kwargs['max_steps'] = trainer_opt.max_steps
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ###