From ccb2b7c2fbfda243ec7fd3c523e9f6eb5dd46c65 Mon Sep 17 00:00:00 2001 From: Mihai <299015+mh-dm@users.noreply.github.com> Date: Thu, 15 Sep 2022 14:41:24 +0300 Subject: [PATCH] Use cuda only when available in main.py. (#567) Allows testing textual inversion / training flow on cpu only (very slow though). Context: #508 --- main.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index c45194db44..72aaa49c3b 100644 --- a/main.py +++ b/main.py @@ -40,7 +40,8 @@ def load_model_from_config(config, ckpt, verbose=False): print('unexpected keys:') print(u) - model.cuda() + if torch.cuda.is_available(): + model.cuda() return model @@ -549,23 +550,26 @@ class CUDACallback(Callback): # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py def on_train_epoch_start(self, trainer, pl_module): # Reset the memory use counter - torch.cuda.reset_peak_memory_stats(trainer.root_gpu) - torch.cuda.synchronize(trainer.root_gpu) + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats(trainer.root_gpu) + torch.cuda.synchronize(trainer.root_gpu) self.start_time = time.time() def on_train_epoch_end(self, trainer, pl_module, outputs): - torch.cuda.synchronize(trainer.root_gpu) - max_memory = ( - torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20 - ) + if torch.cuda.is_available(): + torch.cuda.synchronize(trainer.root_gpu) epoch_time = time.time() - self.start_time try: - max_memory = trainer.training_type_plugin.reduce(max_memory) epoch_time = trainer.training_type_plugin.reduce(epoch_time) - rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds') - rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB') + + if torch.cuda.is_available(): + max_memory = ( + torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20 + ) + max_memory = trainer.training_type_plugin.reduce(max_memory) + rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB') except AttributeError: pass @@ -872,7 +876,6 @@ if __name__ == '__main__': config.data.params.validation.params.data_root = opt.data_root data = instantiate_from_config(config.data) - data = instantiate_from_config(config.data) # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html # calling these ourselves should not be necessary but it is. # lightning still takes care of proper multiprocessing though