Merge branch 'development' of github.com:invoke-ai/InvokeAI into development

This commit is contained in:
Lincoln Stein 2022-10-21 12:58:51 -04:00
commit 38b1dce7c3

26
main.py
View File

@ -439,7 +439,7 @@ class ImageLogger(Callback):
self.rescale = rescale self.rescale = rescale
self.batch_freq = batch_frequency self.batch_freq = batch_frequency
self.max_images = max_images self.max_images = max_images
self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } if torch.cuda.is_available() else { } self.logger_log_images = { }
self.log_steps = [ self.log_steps = [
2**n for n in range(int(np.log2(self.batch_freq)) + 1) 2**n for n in range(int(np.log2(self.batch_freq)) + 1)
] ]
@ -451,17 +451,6 @@ class ImageLogger(Callback):
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step self.log_first_step = log_first_step
@rank_zero_only
def _testtube(self, pl_module, images, batch_idx, split):
for k in images:
grid = torchvision.utils.make_grid(images[k])
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
tag = f'{split}/{k}'
pl_module.logger.experiment.add_image(
tag, grid, global_step=pl_module.global_step
)
@rank_zero_only @rank_zero_only
def log_local( def log_local(
self, save_dir, split, images, global_step, current_epoch, batch_idx self, save_dir, split, images, global_step, current_epoch, batch_idx
@ -714,7 +703,7 @@ if __name__ == '__main__':
# merge trainer cli with config # merge trainer cli with config
trainer_config = lightning_config.get('trainer', OmegaConf.create()) trainer_config = lightning_config.get('trainer', OmegaConf.create())
# default to ddp # default to ddp
trainer_config['accelerator'] = 'ddp' trainer_config['accelerator'] = 'auto'
for k in nondefault_trainer_args(opt): for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k) trainer_config[k] = getattr(opt, k)
if not 'gpus' in trainer_config: if not 'gpus' in trainer_config:
@ -751,12 +740,8 @@ if __name__ == '__main__':
trainer_kwargs = dict() trainer_kwargs = dict()
# default logger configs # default logger configs
if torch.cuda.is_available(): def_logger = 'csv'
def_logger = 'testtube' def_logger_target = 'CSVLogger'
def_logger_target = 'TestTubeLogger'
else:
def_logger = 'csv'
def_logger_target = 'CSVLogger'
default_logger_cfgs = { default_logger_cfgs = {
'wandb': { 'wandb': {
'target': 'pytorch_lightning.loggers.WandbLogger', 'target': 'pytorch_lightning.loggers.WandbLogger',
@ -918,7 +903,8 @@ if __name__ == '__main__':
config.model.base_learning_rate, config.model.base_learning_rate,
) )
if not cpu: if not cpu:
ngpu = len(lightning_config.trainer.gpus.strip(',').split(',')) gpus = str(lightning_config.trainer.gpus).strip(', ').split(',')
ngpu = len(gpus)
else: else:
ngpu = 1 ngpu = 1
if 'accumulate_grad_batches' in lightning_config.trainer: if 'accumulate_grad_batches' in lightning_config.trainer: