Mirror of https://github.com/invoke-ai/InvokeAI
fix a number of bugs in textual inversion

- remove the unsupported TestTubeLogger and use CSVLogger instead
- fix the logic for parsing the --gpus option so that it won't crash if the trailing comma is absent
- change the trainer accelerator from the unsupported 'ddp' to 'auto'
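For readers following along, the replacement combination looks roughly like the sketch below. This is a minimal illustration rather than code from this commit; the save_dir/name values are placeholders, and it assumes a PyTorch Lightning release in which CSVLogger and accelerator='auto' are available.

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger

# CSVLogger writes metrics to plain CSV files and has no CUDA or test-tube
# dependency, unlike the removed TestTubeLogger.
logger = CSVLogger(save_dir='logs', name='textual_inversion')  # placeholder paths

# 'auto' lets Lightning pick the device (GPU when available, CPU otherwise);
# in recent Lightning, 'ddp' is a strategy rather than an accelerator value.
trainer = pl.Trainer(logger=logger, accelerator='auto')
# trainer.fit(model, data)  # model/data construction omitted from this sketch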
This commit is contained in:
parent d4d1014c9f
commit fbea657eff

Changed files: main.py (26 lines changed)
@@ -439,7 +439,7 @@ class ImageLogger(Callback):
         self.rescale = rescale
         self.batch_freq = batch_frequency
         self.max_images = max_images
-        self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } if torch.cuda.is_available() else { }
+        self.logger_log_images = { }
         self.log_steps = [
             2**n for n in range(int(np.log2(self.batch_freq)) + 1)
         ]
@@ -451,17 +451,6 @@ class ImageLogger(Callback):
         self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
         self.log_first_step = log_first_step
 
-    @rank_zero_only
-    def _testtube(self, pl_module, images, batch_idx, split):
-        for k in images:
-            grid = torchvision.utils.make_grid(images[k])
-            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
-
-            tag = f'{split}/{k}'
-            pl_module.logger.experiment.add_image(
-                tag, grid, global_step=pl_module.global_step
-            )
-
     @rank_zero_only
     def log_local(
         self, save_dir, split, images, global_step, current_epoch, batch_idx
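Why emptying logger_log_images is safe: the mapping presumably keys per-logger image hooks by logger type, so with an empty dict the lookup always misses and only the local on-disk logging path (log_local) runs. The snippet below is a hypothetical sketch of that dispatch pattern, not the actual log_img code from main.py.

# Hypothetical illustration of how an empty logger->hook mapping degrades to a no-op.
logger_log_images = {}  # after this commit: no per-logger image hooks

def dispatch_image_hook(pl_module, images, batch_idx, split):
    # Look up a hook keyed by the active logger's type; with {} this always
    # falls back to the no-op, so images are only saved locally by log_local().
    hook = logger_log_images.get(type(pl_module.logger), lambda *a, **kw: None)
    hook(pl_module, images, batch_idx, split)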
@@ -714,7 +703,7 @@ if __name__ == '__main__':
         # merge trainer cli with config
         trainer_config = lightning_config.get('trainer', OmegaConf.create())
         # default to ddp
-        trainer_config['accelerator'] = 'ddp'
+        trainer_config['accelerator'] = 'auto'
         for k in nondefault_trainer_args(opt):
             trainer_config[k] = getattr(opt, k)
         if not 'gpus' in trainer_config:
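As a rough sketch of where the 'auto' value ends up (assumed names, not the surrounding code from main.py): the merged trainer_config is eventually expanded into Trainer keyword arguments.

import pytorch_lightning as pl
from omegaconf import OmegaConf

trainer_config = OmegaConf.create({'accelerator': 'auto'})
# CLI overrides would be merged in here, as in the nondefault_trainer_args loop above.
trainer_kwargs = OmegaConf.to_container(trainer_config, resolve=True)
# Newer Lightning rejects accelerator='ddp' (DDP moved to the separate
# 'strategy' argument); 'auto' selects GPU when present, otherwise CPU.
trainer = pl.Trainer(**trainer_kwargs)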
@@ -751,12 +740,8 @@ if __name__ == '__main__':
         trainer_kwargs = dict()
 
         # default logger configs
-        if torch.cuda.is_available():
-            def_logger = 'testtube'
-            def_logger_target = 'TestTubeLogger'
-        else:
-            def_logger = 'csv'
-            def_logger_target = 'CSVLogger'
+        def_logger = 'csv'
+        def_logger_target = 'CSVLogger'
         default_logger_cfgs = {
             'wandb': {
                 'target': 'pytorch_lightning.loggers.WandbLogger',
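Judging from the 'wandb' entry visible above, the new 'csv' default presumably follows the same target/params layout; the entry below is a hedged guess with placeholder params, not the exact dictionary from this file.

# Assumed shape of the CSV logger default, mirroring the 'wandb' entry above.
default_logger_cfgs = {
    'csv': {
        'target': 'pytorch_lightning.loggers.CSVLogger',
        'params': {
            'name': 'csv_logs',    # placeholder
            'save_dir': 'logs',    # placeholder
        },
    },
}
default_logger_cfg = default_logger_cfgs['csv']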
@@ -918,7 +903,8 @@ if __name__ == '__main__':
                 config.model.base_learning_rate,
             )
             if not cpu:
-                ngpu = len(lightning_config.trainer.gpus.strip(',').split(','))
+                gpus = str(lightning_config.trainer.gpus).strip(', ').split(',')
+                ngpu = len(gpus)
             else:
                 ngpu = 1
             if 'accumulate_grad_batches' in lightning_config.trainer:
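To see what the parsing change buys, the snippet below runs the new expression on a few representative --gpus values (the values themselves are made-up examples): a trailing-comma string, a multi-GPU string, and a bare integer, which is the case that crashed the old .strip(',') call.

# Representative values that lightning_config.trainer.gpus might hold.
for value in ('0,', '0,1', 0):
    # Old logic: value.strip(',').split(',') raises AttributeError when the
    # option arrives as a bare int (no trailing comma on the command line).
    gpus = str(value).strip(', ').split(',')  # new logic: coerce to str first
    print(repr(value), '->', len(gpus), 'gpu id(s):', gpus)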