Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
report VRAM usage stats during initial model loading (#419)
This commit is contained in:
parent f6284777e6
commit dd2aedacaf
@@ -501,12 +501,22 @@ class Generate:
     def _load_model_from_config(self, config, ckpt):
         print(f'>> Loading model from {ckpt}')
+
+        # for usage statistics
+        device_type = choose_torch_device()
+        if device_type == 'cuda':
+            torch.cuda.reset_peak_memory_stats()
+        tic = time.time()
+
+        # this does the work
         pl_sd = torch.load(ckpt, map_location='cpu')
         sd = pl_sd['state_dict']
         model = instantiate_from_config(config.model)
         m, u = model.load_state_dict(sd, strict=False)
         model.to(self.device)
         model.eval()

         if self.full_precision:
             print(
                 '>> Using slower but more accurate full-precision math (--full_precision)'
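
The added block uses PyTorch's CUDA memory instrumentation: torch.cuda.reset_peak_memory_stats() zeroes the peak-allocation counter before the load, and torch.cuda.max_memory_allocated() / torch.cuda.memory_allocated() (both in bytes) are read afterwards. A minimal standalone sketch of the same measurement pattern, with a hypothetical tensor workload standing in for model loading (requires a CUDA device):

import time
import torch

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()  # zero the peak-allocation counter
tic = time.time()

# stand-in workload: allocate and use a large tensor on the GPU
x = torch.randn(4096, 4096, device='cuda')
y = x @ x

toc = time.time()
print('elapsed: %4.2fs' % (toc - tic))
if torch.cuda.is_available():
    # peak bytes allocated since the reset vs. bytes currently held
    print('peak VRAM:    %4.2fG' % (torch.cuda.max_memory_allocated() / 1e9))
    print('current VRAM: %4.2fG' % (torch.cuda.memory_allocated() / 1e9))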
@@ -516,6 +526,20 @@ class Generate:
                 '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
             )
             model.half()
+
+        # usage statistics
+        toc = time.time()
+        print(
+            f'>> Model loaded in', '%4.2fs' % (toc - tic)
+        )
+        if device_type == 'cuda':
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+
         return model

     def _load_img(self, path, width, height, fit=False):
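
One subtlety in the added print: the '\n>> Current VRAM usage:' literal and the '%4.2fG' literal on the next line are separated only by a line break, not a comma, so Python concatenates the two adjacent string literals into a single format string before applying %. The output is still correct; the current-usage label and value simply print as one argument instead of two.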
@@ -91,11 +91,7 @@ def main():
         print(">> changed to seamless tiling mode")

     # preload the model
-    tic = time.time()
     t2i.load_model()
-    print(
-        f'>> model loaded in', '%4.2fs' % (time.time() - tic)
-    )

     if not infile:
         print(
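
The second hunk removes the script's own timing code because load_model() (via _load_model_from_config above) now reports both load time and VRAM usage itself. The device_type check in the first hunk depends on a choose_torch_device() helper imported elsewhere in the file; a hypothetical sketch of such a helper, assuming it simply prefers CUDA, then MPS, then CPU (the repository's actual implementation may differ):

import torch

def choose_torch_device() -> str:
    '''Return the best available torch device type (hypothetical sketch).'''
    if torch.cuda.is_available():
        return 'cuda'
    mps = getattr(torch.backends, 'mps', None)  # present in torch >= 1.12
    if mps is not None and mps.is_available():
        return 'mps'
    return 'cpu'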