Use cuda only when available in main.py. (#567)

Allows testing the textual inversion / training flow on CPU only (very slow, though).
Context: #508
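
In other words, every call that assumes a GPU (model.cuda(), the torch.cuda.* memory and synchronization helpers) is wrapped in a torch.cuda.is_available() check, so the same script falls back to the CPU when no GPU is present. Below is a minimal, self-contained sketch of that pattern — illustrative only, not code from this commit; load_dummy_model is a made-up stand-in for load_model_from_config:

    import torch
    import torch.nn as nn


    def load_dummy_model():
        # Hypothetical stand-in for load_model_from_config(); any nn.Module works.
        return nn.Linear(4, 4)


    def load_model():
        model = load_dummy_model()
        # Only move the model to the GPU when one is actually present;
        # otherwise it simply stays on the CPU.
        if torch.cuda.is_available():
            model.cuda()
        return model


    if __name__ == '__main__':
        model = load_model()
        device = next(model.parameters()).device
        print(f'model is on: {device}')  # 'cuda:0' with a GPU, 'cpu' without
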
Mihai committed 2022-09-15 14:41:24 +03:00 (via GitHub)
parent 30e69f8b32
commit ccb2b7c2fb

main.py (15 changed lines)

@@ -40,6 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
         print('unexpected keys:')
         print(u)
-    model.cuda()
+    if torch.cuda.is_available():
+        model.cuda()
     return model
@@ -549,22 +550,25 @@ class CUDACallback(Callback):
     # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
     def on_train_epoch_start(self, trainer, pl_module):
         # Reset the memory use counter
-        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
-        torch.cuda.synchronize(trainer.root_gpu)
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
+            torch.cuda.synchronize(trainer.root_gpu)
         self.start_time = time.time()

     def on_train_epoch_end(self, trainer, pl_module, outputs):
-        torch.cuda.synchronize(trainer.root_gpu)
-        max_memory = (
-            torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
-        )
+        if torch.cuda.is_available():
+            torch.cuda.synchronize(trainer.root_gpu)
         epoch_time = time.time() - self.start_time

         try:
-            max_memory = trainer.training_type_plugin.reduce(max_memory)
             epoch_time = trainer.training_type_plugin.reduce(epoch_time)
             rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds')
-            rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
+            if torch.cuda.is_available():
+                max_memory = (
+                    torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
+                )
+                max_memory = trainer.training_type_plugin.reduce(max_memory)
+                rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
         except AttributeError:
             pass
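
In the callback, the same guard splits the per-epoch statistics in two: epoch timing is device-agnostic and always reported, while peak-memory accounting only exists for CUDA and is skipped on CPU. A rough standalone illustration of that split, using plain functions instead of a Lightning Callback and assuming device index 0:

    import time
    import torch


    def epoch_start():
        # Peak-memory counters only exist for CUDA devices.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(0)
            torch.cuda.synchronize(0)
        return time.time()


    def epoch_end(start_time):
        if torch.cuda.is_available():
            torch.cuda.synchronize(0)

        # Timing works on any device, so it is always reported.
        print(f'Average Epoch time: {time.time() - start_time:.2f} seconds')

        # Memory reporting is CUDA-only and silently skipped on CPU.
        if torch.cuda.is_available():
            max_memory = torch.cuda.max_memory_allocated(0) / 2**20
            print(f'Average Peak memory {max_memory:.2f}MiB')
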
@@ -872,7 +876,6 @@ if __name__ == '__main__':
         config.data.params.validation.params.data_root = opt.data_root
         data = instantiate_from_config(config.data)
-        data = instantiate_from_config(config.data)
         # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
         # calling these ourselves should not be necessary but it is.
         # lightning still takes care of proper multiprocessing though