From fbea657effa4d9adb34a2d1c7306b3063b79fadc Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 20 Oct 2022 20:08:24 -0400 Subject: [PATCH] fix a number of bugs in textual inversion - remove unsupported testtubelogger, use csvlogger instead - fix logic for parsing --gpus option so that it won't crash if trailing comma absent - change trainer accelerator from unsupported 'ddp' to 'auto' --- main.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/main.py b/main.py index 436b7251ba..60c091891c 100644 --- a/main.py +++ b/main.py @@ -439,7 +439,7 @@ class ImageLogger(Callback): self.rescale = rescale self.batch_freq = batch_frequency self.max_images = max_images - self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } if torch.cuda.is_available() else { } + self.logger_log_images = { } self.log_steps = [ 2**n for n in range(int(np.log2(self.batch_freq)) + 1) ] @@ -451,17 +451,6 @@ class ImageLogger(Callback): self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} self.log_first_step = log_first_step - @rank_zero_only - def _testtube(self, pl_module, images, batch_idx, split): - for k in images: - grid = torchvision.utils.make_grid(images[k]) - grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w - - tag = f'{split}/{k}' - pl_module.logger.experiment.add_image( - tag, grid, global_step=pl_module.global_step - ) - @rank_zero_only def log_local( self, save_dir, split, images, global_step, current_epoch, batch_idx @@ -714,7 +703,7 @@ if __name__ == '__main__': # merge trainer cli with config trainer_config = lightning_config.get('trainer', OmegaConf.create()) # default to ddp - trainer_config['accelerator'] = 'ddp' + trainer_config['accelerator'] = 'auto' for k in nondefault_trainer_args(opt): trainer_config[k] = getattr(opt, k) if not 'gpus' in trainer_config: @@ -751,12 +740,8 @@ if __name__ == '__main__': trainer_kwargs = dict() # default logger configs - if torch.cuda.is_available(): - def_logger = 'testtube' - def_logger_target = 'TestTubeLogger' - else: - def_logger = 'csv' - def_logger_target = 'CSVLogger' + def_logger = 'csv' + def_logger_target = 'CSVLogger' default_logger_cfgs = { 'wandb': { 'target': 'pytorch_lightning.loggers.WandbLogger', @@ -918,7 +903,8 @@ if __name__ == '__main__': config.model.base_learning_rate, ) if not cpu: - ngpu = len(lightning_config.trainer.gpus.strip(',').split(',')) + gpus = str(lightning_config.trainer.gpus).strip(', ').split(',') + ngpu = len(gpus) else: ngpu = 1 if 'accumulate_grad_batches' in lightning_config.trainer: