From 4fad71cd8c639cb2c0dc0fb68ade47a566079173 Mon Sep 17 00:00:00 2001
From: David Ford <2772469+david-ford@users.noreply.github.com>
Date: Tue, 30 Aug 2022 14:59:32 -0500
Subject: [PATCH] Training optimizations (#217)

* Optimizations to the training model

Based on the changes made in textual_inversion, I carried over the relevant
changes that improve model training. These changes reduce the amount of
memory used, significantly improve the speed at which training runs, and
improve the quality of the results. They also fix the problem where the
model trainer wouldn't automatically stop once it reached the configured
number of steps.

* Update main.py

Cleaned up whitespace
---
 configs/stable-diffusion/v1-finetune.yaml |  7 ++++--
 main.py                                   | 26 ++++++++++++++++++-----
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/configs/stable-diffusion/v1-finetune.yaml b/configs/stable-diffusion/v1-finetune.yaml
index da6a5e775b..bf303cbdae 100644
--- a/configs/stable-diffusion/v1-finetune.yaml
+++ b/configs/stable-diffusion/v1-finetune.yaml
@@ -52,7 +52,7 @@ model:
       ddconfig:
         double_z: true
         z_channels: 4
-        resolution: 256
+        resolution: 512
         in_channels: 3
         out_ch: 3
         ch: 128
@@ -73,7 +73,7 @@ model:
 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 2
+    batch_size: 1
     num_workers: 16
     wrap: false
     train:
@@ -92,6 +92,9 @@ data:
       repeats: 10

 lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 500
   callbacks:
     image_logger:
       target: main.ImageLogger
diff --git a/main.py b/main.py
index 8c36c270b1..c45194db44 100644
--- a/main.py
+++ b/main.py
@@ -171,8 +171,8 @@ def get_parser(**parser_kwargs):
         help='Initialize embedding manager from a checkpoint',
     )
     parser.add_argument(
-        '--placeholder_tokens', type=str, nargs='+', default=['*']
-    )
+        '--placeholder_tokens', type=str, nargs='+', default=['*'],
+        help='Placeholder token which will be used to denote the concept in future prompts')

     parser.add_argument(
         '--init_word',
@@ -473,7 +473,7 @@ class ImageLogger(Callback):
             self.check_frequency(check_idx)
             and hasattr(  # batch_idx % self.batch_freq == 0
                 pl_module, 'log_images'
-            )
+            ) and callable(pl_module.log_images)
             and self.max_images > 0
         ):
@@ -569,6 +569,21 @@ class CUDACallback(Callback):
         except AttributeError:
             pass

+class ModeSwapCallback(Callback):
+
+    def __init__(self, swap_step=2000):
+        super().__init__()
+        self.is_frozen = False
+        self.swap_step = swap_step
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        if trainer.global_step < self.swap_step and not self.is_frozen:
+            self.is_frozen = True
+            trainer.optimizers = [pl_module.configure_opt_embedding()]
+
+        if trainer.global_step > self.swap_step and self.is_frozen:
+            self.is_frozen = False
+            trainer.optimizers = [pl_module.configure_opt_model()]

 if __name__ == '__main__':
     # custom parser to specify config files, train, test and debug mode,
@@ -663,6 +678,7 @@ if __name__ == '__main__':

         if opt.datadir_in_name:
             now = os.path.basename(os.path.normpath(opt.data_root)) + now
+
         nowname = now + name + opt.postfix
         logdir = os.path.join(opt.logdir, nowname)

@@ -756,7 +772,7 @@ if __name__ == '__main__':
         if hasattr(model, 'monitor'):
             print(f'Monitoring {model.monitor} as checkpoint metric.')
             default_modelckpt_cfg['params']['monitor'] = model.monitor
-            default_modelckpt_cfg['params']['save_top_k'] = 3
+            default_modelckpt_cfg['params']['save_top_k'] = 1

         if 'modelcheckpoint' in lightning_config:
             modelckpt_cfg = lightning_config.modelcheckpoint
@@ -846,7 +862,7 @@ if __name__ == '__main__':
         trainer_kwargs['callbacks'] = [
             instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
         ]
-        trainer_kwargs['max_steps'] = opt.max_steps
+        trainer_kwargs['max_steps'] = trainer_opt.max_steps

         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir  ###
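
Note (not part of the patch): the new ModeSwapCallback implements a two-phase
schedule: before swap_step it installs an optimizer that updates only the
placeholder embedding, and once the global step passes swap_step it switches
to an optimizer that also covers the model weights. The sketch below is a
minimal, self-contained illustration of that swap logic using plain PyTorch
objects; TrainerStub and the *_stub functions are hypothetical stand-ins for
the Lightning trainer and the model's configure_opt_embedding() /
configure_opt_model() methods, not code from this repository.

# A minimal sketch (assumed names, not repository code) of the optimizer-swap
# schedule that ModeSwapCallback implements.
import torch


class TrainerStub:
    """Holds only the fields that ModeSwapCallback touches."""

    def __init__(self):
        self.global_step = 0
        self.optimizers = []


embedding = torch.nn.Parameter(torch.zeros(768))   # learned placeholder embedding
model_weights = torch.nn.Linear(768, 768)          # stand-in for the rest of the model


def configure_opt_embedding_stub():
    # Phase 1: optimize only the placeholder embedding.
    return torch.optim.AdamW([embedding], lr=5e-3)


def configure_opt_model_stub():
    # Phase 2: optimize the model weights together with the embedding.
    return torch.optim.AdamW(list(model_weights.parameters()) + [embedding], lr=1e-6)


swap_step = 2000
is_frozen = False
trainer = TrainerStub()

for step in (0, 1999, 2000, 2001):
    trainer.global_step = step
    # Mirrors ModeSwapCallback.on_train_epoch_start:
    if trainer.global_step < swap_step and not is_frozen:
        is_frozen = True
        trainer.optimizers = [configure_opt_embedding_stub()]
    if trainer.global_step > swap_step and is_frozen:
        is_frozen = False
        trainer.optimizers = [configure_opt_model_stub()]
    n_tensors = sum(len(g['params']) for g in trainer.optimizers[0].param_groups)
    print(f'step {step}: optimizing {n_tensors} tensor(s)')

Running the sketch prints one optimized tensor for steps 0-2000 and three
tensors from step 2001 on, which is the intended behavior: the swap to full
model training only happens after the global step exceeds swap_step.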