mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use CUDA only when available in main.py. (#567)
Allows testing the textual inversion / training flow on CPU only (very slow, though). Context: #508
This commit is contained in:
parent
30e69f8b32
commit
ccb2b7c2fb
15
main.py
15
main.py
@ -40,6 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
|
|||||||
print('unexpected keys:')
|
print('unexpected keys:')
|
||||||
print(u)
|
print(u)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
model.cuda()
|
model.cuda()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -549,22 +550,25 @@ class CUDACallback(Callback):
|
|||||||
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
||||||
def on_train_epoch_start(self, trainer, pl_module):
|
def on_train_epoch_start(self, trainer, pl_module):
|
||||||
# Reset the memory use counter
|
# Reset the memory use counter
|
||||||
|
if torch.cuda.is_available():
|
||||||
torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
|
torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
|
||||||
torch.cuda.synchronize(trainer.root_gpu)
|
torch.cuda.synchronize(trainer.root_gpu)
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
||||||
|
if torch.cuda.is_available():
|
||||||
torch.cuda.synchronize(trainer.root_gpu)
|
torch.cuda.synchronize(trainer.root_gpu)
|
||||||
max_memory = (
|
|
||||||
torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
|
|
||||||
)
|
|
||||||
epoch_time = time.time() - self.start_time
|
epoch_time = time.time() - self.start_time
|
||||||
|
|
||||||
try:
|
try:
|
||||||
max_memory = trainer.training_type_plugin.reduce(max_memory)
|
|
||||||
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
|
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
|
||||||
|
|
||||||
rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds')
|
rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds')
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
max_memory = (
|
||||||
|
torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
|
||||||
|
)
|
||||||
|
max_memory = trainer.training_type_plugin.reduce(max_memory)
|
||||||
rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
|
rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
@ -872,7 +876,6 @@ if __name__ == '__main__':
|
|||||||
config.data.params.validation.params.data_root = opt.data_root
|
config.data.params.validation.params.data_root = opt.data_root
|
||||||
data = instantiate_from_config(config.data)
|
data = instantiate_from_config(config.data)
|
||||||
|
|
||||||
data = instantiate_from_config(config.data)
|
|
||||||
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
||||||
# calling these ourselves should not be necessary but it is.
|
# calling these ourselves should not be necessary but it is.
|
||||||
# lightning still takes care of proper multiprocessing though
|
# lightning still takes care of proper multiprocessing though
|
||||||
|
Loading…
Reference in New Issue
Block a user