report VRAM usage stats during initial model loading (#419)

commit dd2aedacaf
parent f6284777e6
Author: Lincoln Stein
Date: 2022-09-07 13:23:53 -04:00
2 changed files with 24 additions and 4 deletions


@@ -501,12 +501,22 @@ class Generate:
     def _load_model_from_config(self, config, ckpt):
         print(f'>> Loading model from {ckpt}')
+        # for usage statistics
+        device_type = choose_torch_device()
+        if device_type == 'cuda':
+            torch.cuda.reset_peak_memory_stats()
+        tic = time.time()
+        # this does the work
         pl_sd = torch.load(ckpt, map_location='cpu')
         sd = pl_sd['state_dict']
         model = instantiate_from_config(config.model)
         m, u = model.load_state_dict(sd, strict=False)
         model.to(self.device)
         model.eval()
         if self.full_precision:
             print(
                 '>> Using slower but more accurate full-precision math (--full_precision)'
@@ -516,6 +526,20 @@ class Generate:
                 '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
             )
             model.half()
+        # usage statistics
+        toc = time.time()
+        print(
+            '>> Model loaded in', '%4.2fs' % (toc - tic)
+        )
+        if device_type == 'cuda':
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:',
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
         return model

     def _load_img(self, path, width, height, fit=False):
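
For context, the CUDA calls introduced above follow a standard PyTorch pattern: clear the allocator's high-water mark before the work, then read it back afterwards. Below is a minimal standalone sketch of that pattern, not part of the commit; load_with_stats and build_model are hypothetical names, and it assumes the checkpoint stores its weights under a 'state_dict' key, as in the diff.

    import time

    import torch

    def load_with_stats(ckpt_path, build_model):
        # build_model is a hypothetical callable returning an uninitialized
        # torch.nn.Module; in the commit this role is played by
        # instantiate_from_config(config.model).
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            # Clear the high-water mark so max_memory_allocated() reflects
            # only this load, not earlier allocations.
            torch.cuda.reset_peak_memory_stats()
        tic = time.time()

        state = torch.load(ckpt_path, map_location='cpu')
        model = build_model()
        model.load_state_dict(state['state_dict'], strict=False)
        model.to('cuda' if use_cuda else 'cpu')
        model.eval()

        print('>> Model loaded in %4.2fs' % (time.time() - tic))
        if use_cuda:
            # max_memory_allocated() is the peak since the reset;
            # memory_allocated() is what is still held once loading is done.
            print('>> Max VRAM used to load the model: %4.2fG'
                  % (torch.cuda.max_memory_allocated() / 1e9))
            print('>> Current VRAM usage: %4.2fG'
                  % (torch.cuda.memory_allocated() / 1e9))
        return model

One caveat worth noting: memory_allocated() counts bytes held by live tensors through PyTorch's caching allocator, so both figures can read lower than what nvidia-smi shows, which also includes cached-but-free blocks (see torch.cuda.memory_reserved()).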


@@ -91,11 +91,7 @@ def main():
         print(">> changed to seamless tiling mode")

     # preload the model
-    tic = time.time()
     t2i.load_model()
-    print(
-        f'>> model loaded in', '%4.2fs' % (time.time() - tic)
-    )

     if not infile:
         print(
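
With the second file's hunk, the ad-hoc timer in the script's main() goes away: load timing, and on CUDA the VRAM statistics, are now reported inside the model loader itself, so every caller of load_model() gets the same report. Given the format strings in the diff, the new output looks something like this (numbers illustrative):

    >> Model loaded in 11.47s
    >> Max VRAM used to load the model: 4.25G
    >> Current VRAM usage: 2.13G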