Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
refactor(model_cache): factor out load_ckpt
commit 9b274bd57c (parent dcfdb83513)
@@ -102,7 +102,7 @@ class ModelCache(object):
             'hash': hash
         }
 
-    def default_model(self) -> str:
+    def default_model(self) -> str | None:
         '''
         Returns the name of the default model, or None
         if none is defined.
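The widened return annotation uses PEP 604 union syntax. One point worth noting (not part of the commit itself): "str | None" in an annotation raises a TypeError at class-definition time on Python releases before 3.10, unless postponed evaluation of annotations is enabled. A minimal illustration:

    # `str | None` needs Python >= 3.10, or the future import below,
    # which defers annotation evaluation (PEP 563).
    from __future__ import annotations

    class ModelCache:
        def default_model(self) -> str | None:
            return None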
@@ -198,18 +198,6 @@ class ModelCache(object):
             print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
 
         mconfig = self.config[model_name]
-        config = mconfig.config
-        weights = mconfig.weights
-        vae = mconfig.get('vae')
-        width = mconfig.width
-        height = mconfig.height
-
-        if not os.path.isabs(weights):
-            weights = os.path.normpath(os.path.join(Globals.root,weights))
-        # scan model
-        self.scan_model(model_name, weights)
-
-        print(f'>> Loading {model_name} from {weights}')
 
         # for usage statistics
         if self._has_cuda():
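The deleted lines above resolved a relative weights path against Globals.root before scanning the checkpoint. A standalone illustration of that normalization, using a stand-in for Globals.root (which lives elsewhere in the repo) and a hypothetical weights path:

    import os

    root = '/home/user/invokeai'          # stand-in for Globals.root
    weights = 'models/ldm/sd-v1-4.ckpt'   # hypothetical relative path from models.yaml
    if not os.path.isabs(weights):
        weights = os.path.normpath(os.path.join(root, weights))
    print(weights)  # -> /home/user/invokeai/models/ldm/sd-v1-4.ckpt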
@@ -219,12 +207,39 @@ class ModelCache(object):
         tic = time.time()
 
         # this does the work
-        if not os.path.isabs(config):
-            config = os.path.join(Globals.root,config)
-        omega_config = OmegaConf.load(config)
-        with open(weights,'rb') as f:
+        model_format = mconfig.get('format', 'ckpt')
+        if model_format == 'ckpt':
+            weights = mconfig.weights
+            print(f'>> Loading {model_name} from {weights}')
+            model, width, height, model_hash = self._load_ckpt_model(mconfig)
+        elif model_format == 'diffusers':
+            model, width, height, model_hash = self._load_diffusers_model(mconfig)
+        else:
+            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
+
+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+        return model, width, height, model_hash
+
+    def _load_ckpt_model(self, mconfig):
+        config = mconfig.config
+        weights = mconfig.weights
+        vae = mconfig.get('vae', None)
+        width = mconfig.width
+        height = mconfig.height
+
+        c = OmegaConf.load(config)
+        with open(weights, 'rb') as f:
             weight_bytes = f.read()
-        model_hash = self._cached_sha256(weights,weight_bytes)
+        model_hash = self._cached_sha256(weights, weight_bytes)
         sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
         del weight_bytes
         sd = sd['state_dict']
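The dispatch above keys off a "format" field in each models.yaml stanza, falling back to 'ckpt' when the field is absent. A sketch of how such a stanza reads through OmegaConf; the stanza contents are illustrative, not copied from the repo, and only the fields the diff actually references (format, config, weights, vae, width, height) are known:

    from omegaconf import OmegaConf

    # Hypothetical models.yaml stanza with no explicit 'format' key.
    mconfig = OmegaConf.create("""
    config: configs/stable-diffusion/v1-inference.yaml
    weights: models/ldm/stable-diffusion-v1/model.ckpt
    width: 512
    height: 512
    """)
    print(mconfig.get('format', 'ckpt'))  # missing key -> defaults to 'ckpt'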
@@ -259,21 +274,12 @@ class ModelCache(object):
             if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 module._orig_padding_mode = module.padding_mode
 
-        # usage statistics
-        toc = time.time()
-        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
-
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
-
         return model, width, height, model_hash
 
-    def offload_model(self, model_name:str) -> None:
+    def _load_diffusers_model(self, mconfig):
+        raise NotImplementedError() # return pipeline, width, height, model_hash
+
+    def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
         _make_cache_room() to free space if needed.
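_load_diffusers_model is deliberately left as a stub; its trailing comment pins down the future contract (pipeline, width, height, model_hash). A minimal sketch of one way the stub could be filled in, assuming the Hugging Face diffusers StableDiffusionPipeline API; the 'repo_id' field, the 512-pixel defaults, and the placeholder hash are all assumptions, not part of this commit:

    from diffusers import StableDiffusionPipeline

    def _load_diffusers_model(self, mconfig):
        # 'repo_id' is a hypothetical models.yaml field naming a Hugging Face repo.
        pipeline = StableDiffusionPipeline.from_pretrained(mconfig.repo_id)
        # 512x512 defaults are assumptions; the ckpt path reads these from mconfig.
        width = mconfig.get('width', 512)
        height = mconfig.get('height', 512)
        # Placeholder: no raw weight bytes to SHA-256 here, unlike the ckpt path.
        model_hash = mconfig.repo_id
        return pipeline, width, height, model_hash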