diff --git a/main.py b/main.py
index c45194db44..72aaa49c3b 100644
--- a/main.py
+++ b/main.py
@@ -40,7 +40,8 @@ def load_model_from_config(config, ckpt, verbose=False):
         print('unexpected keys:')
         print(u)

-    model.cuda()
+    if torch.cuda.is_available():
+        model.cuda()
     return model


@@ -549,23 +550,26 @@ class CUDACallback(Callback):
     # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
     def on_train_epoch_start(self, trainer, pl_module):
         # Reset the memory use counter
-        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
-        torch.cuda.synchronize(trainer.root_gpu)
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
+            torch.cuda.synchronize(trainer.root_gpu)
         self.start_time = time.time()

     def on_train_epoch_end(self, trainer, pl_module, outputs):
-        torch.cuda.synchronize(trainer.root_gpu)
-        max_memory = (
-            torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
-        )
+        if torch.cuda.is_available():
+            torch.cuda.synchronize(trainer.root_gpu)
         epoch_time = time.time() - self.start_time

         try:
-            max_memory = trainer.training_type_plugin.reduce(max_memory)
             epoch_time = trainer.training_type_plugin.reduce(epoch_time)
-
             rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds')
-            rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
+
+            if torch.cuda.is_available():
+                max_memory = (
+                    torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
+                )
+                max_memory = trainer.training_type_plugin.reduce(max_memory)
+                rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
         except AttributeError:
             pass

@@ -872,7 +876,6 @@ if __name__ == '__main__':
             config.data.params.train.params.data_root = opt.data_root
             config.data.params.validation.params.data_root = opt.data_root
         data = instantiate_from_config(config.data)
-        data = instantiate_from_config(config.data)
         # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
         # calling these ourselves should not be necessary but it is.
         # lightning still takes care of proper multiprocessing though
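
The recurring change in every hunk above is the standard guard for device-agnostic PyTorch code: each CUDA-specific call is wrapped in torch.cuda.is_available() so the script also runs on CPU-only machines. A minimal, self-contained sketch of the same pattern follows; the pick_device helper is hypothetical, for illustration only, and is not part of main.py.

# Sketch of the device-agnostic pattern applied in the diff.
# pick_device() is a hypothetical helper, not part of main.py.
import torch

def pick_device() -> torch.device:
    # Prefer CUDA when a GPU is present; otherwise fall back to CPU.
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torch.nn.Linear(4, 2).to(pick_device())

# CUDA memory introspection has no CPU equivalent, so it stays
# behind the same guard, as the CUDACallback hunks do above.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    model(torch.randn(8, 4, device='cuda'))
    peak_mib = torch.cuda.max_memory_allocated() / 2**20
    print(f'Peak memory: {peak_mib:.2f}MiB')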