mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use cuda only when available in main.py. (#567)
Allows testing textual inversion / training flow on cpu only (very slow though). Context: #508
This commit is contained in:
parent
30e69f8b32
commit
ccb2b7c2fb
25
main.py
25
main.py
@ -40,7 +40,8 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||
print('unexpected keys:')
|
||||
print(u)
|
||||
|
||||
model.cuda()
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
return model
|
||||
|
||||
|
||||
@ -549,23 +550,26 @@ class CUDACallback(Callback):
|
||||
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
# Reset the memory use counter
|
||||
torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
|
||||
torch.cuda.synchronize(trainer.root_gpu)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
|
||||
torch.cuda.synchronize(trainer.root_gpu)
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
||||
torch.cuda.synchronize(trainer.root_gpu)
|
||||
max_memory = (
|
||||
torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize(trainer.root_gpu)
|
||||
epoch_time = time.time() - self.start_time
|
||||
|
||||
try:
|
||||
max_memory = trainer.training_type_plugin.reduce(max_memory)
|
||||
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
|
||||
|
||||
rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds')
|
||||
rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
max_memory = (
|
||||
torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
|
||||
)
|
||||
max_memory = trainer.training_type_plugin.reduce(max_memory)
|
||||
rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@ -872,7 +876,6 @@ if __name__ == '__main__':
|
||||
config.data.params.validation.params.data_root = opt.data_root
|
||||
data = instantiate_from_config(config.data)
|
||||
|
||||
data = instantiate_from_config(config.data)
|
||||
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
||||
# calling these ourselves should not be necessary but it is.
|
||||
# lightning still takes care of proper multiprocessing though
|
||||
|
Loading…
Reference in New Issue
Block a user