mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' of github.com:invoke-ai/InvokeAI into development
This commit is contained in:
commit
38b1dce7c3
26
main.py
26
main.py
@ -439,7 +439,7 @@ class ImageLogger(Callback):
|
|||||||
self.rescale = rescale
|
self.rescale = rescale
|
||||||
self.batch_freq = batch_frequency
|
self.batch_freq = batch_frequency
|
||||||
self.max_images = max_images
|
self.max_images = max_images
|
||||||
self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } if torch.cuda.is_available() else { }
|
self.logger_log_images = { }
|
||||||
self.log_steps = [
|
self.log_steps = [
|
||||||
2**n for n in range(int(np.log2(self.batch_freq)) + 1)
|
2**n for n in range(int(np.log2(self.batch_freq)) + 1)
|
||||||
]
|
]
|
||||||
@ -451,17 +451,6 @@ class ImageLogger(Callback):
|
|||||||
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
||||||
self.log_first_step = log_first_step
|
self.log_first_step = log_first_step
|
||||||
|
|
||||||
@rank_zero_only
|
|
||||||
def _testtube(self, pl_module, images, batch_idx, split):
|
|
||||||
for k in images:
|
|
||||||
grid = torchvision.utils.make_grid(images[k])
|
|
||||||
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
|
||||||
|
|
||||||
tag = f'{split}/{k}'
|
|
||||||
pl_module.logger.experiment.add_image(
|
|
||||||
tag, grid, global_step=pl_module.global_step
|
|
||||||
)
|
|
||||||
|
|
||||||
@rank_zero_only
|
@rank_zero_only
|
||||||
def log_local(
|
def log_local(
|
||||||
self, save_dir, split, images, global_step, current_epoch, batch_idx
|
self, save_dir, split, images, global_step, current_epoch, batch_idx
|
||||||
@ -714,7 +703,7 @@ if __name__ == '__main__':
|
|||||||
# merge trainer cli with config
|
# merge trainer cli with config
|
||||||
trainer_config = lightning_config.get('trainer', OmegaConf.create())
|
trainer_config = lightning_config.get('trainer', OmegaConf.create())
|
||||||
# default to ddp
|
# default to ddp
|
||||||
trainer_config['accelerator'] = 'ddp'
|
trainer_config['accelerator'] = 'auto'
|
||||||
for k in nondefault_trainer_args(opt):
|
for k in nondefault_trainer_args(opt):
|
||||||
trainer_config[k] = getattr(opt, k)
|
trainer_config[k] = getattr(opt, k)
|
||||||
if not 'gpus' in trainer_config:
|
if not 'gpus' in trainer_config:
|
||||||
@ -751,12 +740,8 @@ if __name__ == '__main__':
|
|||||||
trainer_kwargs = dict()
|
trainer_kwargs = dict()
|
||||||
|
|
||||||
# default logger configs
|
# default logger configs
|
||||||
if torch.cuda.is_available():
|
def_logger = 'csv'
|
||||||
def_logger = 'testtube'
|
def_logger_target = 'CSVLogger'
|
||||||
def_logger_target = 'TestTubeLogger'
|
|
||||||
else:
|
|
||||||
def_logger = 'csv'
|
|
||||||
def_logger_target = 'CSVLogger'
|
|
||||||
default_logger_cfgs = {
|
default_logger_cfgs = {
|
||||||
'wandb': {
|
'wandb': {
|
||||||
'target': 'pytorch_lightning.loggers.WandbLogger',
|
'target': 'pytorch_lightning.loggers.WandbLogger',
|
||||||
@ -918,7 +903,8 @@ if __name__ == '__main__':
|
|||||||
config.model.base_learning_rate,
|
config.model.base_learning_rate,
|
||||||
)
|
)
|
||||||
if not cpu:
|
if not cpu:
|
||||||
ngpu = len(lightning_config.trainer.gpus.strip(',').split(','))
|
gpus = str(lightning_config.trainer.gpus).strip(', ').split(',')
|
||||||
|
ngpu = len(gpus)
|
||||||
else:
|
else:
|
||||||
ngpu = 1
|
ngpu = 1
|
||||||
if 'accumulate_grad_batches' in lightning_config.trainer:
|
if 'accumulate_grad_batches' in lightning_config.trainer:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user