Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
refactor(model_cache): factor out load_ckpt
parent dcfdb83513
commit 9b274bd57c
@@ -102,7 +102,7 @@ class ModelCache(object):
             'hash': hash
         }
 
-    def default_model(self) -> str:
+    def default_model(self) -> str | None:
         '''
         Returns the name of the default model, or None
         if none is defined.
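Aside on the new annotation: the PEP 604 `str | None` union is evaluated when the `def` statement runs, so it requires Python 3.10+ unless postponed evaluation of annotations is enabled. A minimal sketch of the workaround (not part of this commit):

# On Python 3.9, `str | None` in an annotation raises TypeError at def
# time; this future import keeps annotations as strings so it parses.
from __future__ import annotations

def default_model() -> str | None:
    return None

print(default_model())  # -> None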
@@ -198,18 +198,6 @@ class ModelCache(object):
             print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
 
         mconfig = self.config[model_name]
-        config = mconfig.config
-        weights = mconfig.weights
-        vae = mconfig.get('vae')
-        width = mconfig.width
-        height = mconfig.height
-
-        if not os.path.isabs(weights):
-            weights = os.path.normpath(os.path.join(Globals.root,weights))
-        # scan model
-        self.scan_model(model_name, weights)
-
-        print(f'>> Loading {model_name} from {weights}')
 
         # for usage statistics
         if self._has_cuda():
@@ -219,9 +207,36 @@ class ModelCache(object):
         tic = time.time()
 
         # this does the work
-        if not os.path.isabs(config):
-            config = os.path.join(Globals.root,config)
-        omega_config = OmegaConf.load(config)
+        model_format = mconfig.get('format', 'ckpt')
+        if model_format == 'ckpt':
+            weights = mconfig.weights
+            print(f'>> Loading {model_name} from {weights}')
+            model, width, height, model_hash = self._load_ckpt_model(mconfig)
+        elif model_format == 'diffusers':
+            model, width, height, model_hash = self._load_diffusers_model(mconfig)
+        else:
+            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
+
+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+        return model, width, height, model_hash
+
+    def _load_ckpt_model(self, mconfig):
+        config = mconfig.config
+        weights = mconfig.weights
+        vae = mconfig.get('vae', None)
+        width = mconfig.width
+        height = mconfig.height
+
+        c = OmegaConf.load(config)
         with open(weights, 'rb') as f:
             weight_bytes = f.read()
         model_hash = self._cached_sha256(weights, weight_bytes)
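Taken in isolation, the dispatch this hunk introduces reduces to the sketch below. It is a minimal, runnable illustration, assuming a plain dict in place of the OmegaConf mconfig and stub loaders in place of InvokeAI's real ones:

# Minimal sketch of the format dispatch introduced above; the _load_*
# stubs are placeholders, not the real loaders.
def _load_ckpt_model(mconfig):
    # The real helper loads the OmegaConf config, reads and hashes the
    # checkpoint weights, and builds the model; stubbed here.
    return 'model', mconfig['width'], mconfig['height'], 'hash'

def _load_diffusers_model(mconfig):
    raise NotImplementedError()  # return pipeline, width, height, model_hash

def load_model(model_name, mconfig):
    model_format = mconfig.get('format', 'ckpt')  # legacy entries default to ckpt
    if model_format == 'ckpt':
        return _load_ckpt_model(mconfig)
    elif model_format == 'diffusers':
        return _load_diffusers_model(mconfig)
    else:
        raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")

print(load_model('stable-diffusion-1.5', {'format': 'ckpt', 'width': 512, 'height': 512}))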
@@ -259,21 +274,12 @@ class ModelCache(object):
             if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 module._orig_padding_mode = module.padding_mode
-
-        # usage statistics
-        toc = time.time()
-        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
-
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
 
         return model, width, height, model_hash
 
-    def offload_model(self, model_name:str) -> None:
+    def _load_diffusers_model(self, mconfig):
+        raise NotImplementedError() # return pipeline, width, height, model_hash
+
+    def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
         _make_cache_room() to free space if needed.
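The dispatch keys off a per-model `format` field in models.yaml; entries written before this commit lack the field and fall back to 'ckpt' via `mconfig.get('format', 'ckpt')`. A sketch of how that default behaves, using a hypothetical stanza with illustrative paths:

from omegaconf import OmegaConf

# Hypothetical models.yaml content; no 'format' key is present, so the
# new dispatch treats the entry as a legacy checkpoint.
config = OmegaConf.create("""
stable-diffusion-1.5:
  config: configs/stable-diffusion/v1-inference.yaml
  weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
  width: 512
  height: 512
""")

mconfig = config['stable-diffusion-1.5']
print(mconfig.get('format', 'ckpt'))  # no 'format' key -> falls back to 'ckpt'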