refactor(model_cache): factor out load_ckpt

Kevin Turner 2022-11-09 15:25:56 -08:00
parent dcfdb83513
commit 9b274bd57c


@@ -102,7 +102,7 @@ class ModelCache(object):
             'hash': hash
         }
 
-    def default_model(self) -> str:
+    def default_model(self) -> str | None:
         '''
         Returns the name of the default model, or None
         if none is defined.
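
A side note on the new annotation above (not part of the commit): the PEP 604 `str | None` union syntax is only evaluated at runtime on Python 3.10+, so on older interpreters it relies on deferred annotation evaluation. A minimal standalone sketch, with a hypothetical function body:

    # Hypothetical example, not from model_cache.py:
    # PEP 604 unions in annotations need Python >= 3.10 at runtime,
    # or deferred evaluation via this __future__ import on 3.7+.
    from __future__ import annotations

    def default_model() -> str | None:
        """Return the default model name, or None if none is defined."""
        return None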
@@ -198,18 +198,6 @@ class ModelCache(object):
             print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
 
         mconfig = self.config[model_name]
-        config = mconfig.config
-        weights = mconfig.weights
-        vae = mconfig.get('vae')
-        width = mconfig.width
-        height = mconfig.height
-
-        if not os.path.isabs(weights):
-            weights = os.path.normpath(os.path.join(Globals.root,weights))
-        # scan model
-        self.scan_model(model_name, weights)
-
-        print(f'>> Loading {model_name} from {weights}')
 
         # for usage statistics
         if self._has_cuda():
@@ -219,9 +207,36 @@ class ModelCache(object):
         tic = time.time()
 
         # this does the work
-        if not os.path.isabs(config):
-            config = os.path.join(Globals.root,config)
-        omega_config = OmegaConf.load(config)
+        model_format = mconfig.get('format', 'ckpt')
+        if model_format == 'ckpt':
+            weights = mconfig.weights
+            print(f'>> Loading {model_name} from {weights}')
+            model, width, height, model_hash = self._load_ckpt_model(mconfig)
+        elif model_format == 'diffusers':
+            model, width, height, model_hash = self._load_diffusers_model(mconfig)
+        else:
+            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
+
+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+        return model, width, height, model_hash
+
+    def _load_ckpt_model(self, mconfig):
+        config = mconfig.config
+        weights = mconfig.weights
+        vae = mconfig.get('vae', None)
+        width = mconfig.width
+        height = mconfig.height
+
+        c = OmegaConf.load(config)
         with open(weights, 'rb') as f:
             weight_bytes = f.read()
         model_hash = self._cached_sha256(weights, weight_bytes)
@@ -259,21 +274,12 @@ class ModelCache(object):
             if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 module._orig_padding_mode = module.padding_mode
 
-        # usage statistics
-        toc = time.time()
-        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
-
         return model, width, height, model_hash
 
-    def offload_model(self, model_name:str) -> None:
+    def _load_diffusers_model(self, mconfig):
+        raise NotImplementedError() # return pipeline, width, height, model_hash
+
+    def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
         _make_cache_room() to free space if needed.
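
Taken together, the loading path now dispatches on a per-model `format` key from models.yaml, defaulting to `'ckpt'`, with `'diffusers'` routed to the new stub. A minimal sketch of that dispatch, assuming an OmegaConf-style `mconfig`; the entry values and paths below are placeholders, not part of the commit:

    from omegaconf import OmegaConf

    # Hypothetical models.yaml-style entry; paths are placeholders.
    # The keys mirror the ones the diff reads: format, config, weights,
    # width, height (vae is optional and omitted here).
    mconfig = OmegaConf.create({
        'format': 'ckpt',
        'config': 'configs/stable-diffusion/v1-inference.yaml',
        'weights': 'models/ldm/stable-diffusion-v1/model.ckpt',
        'width': 512,
        'height': 512,
    })

    model_format = mconfig.get('format', 'ckpt')   # same default as the diff
    if model_format == 'ckpt':
        print('ckpt path: handled by _load_ckpt_model(mconfig)')
    elif model_format == 'diffusers':
        print('diffusers path: _load_diffusers_model(mconfig) is still a stub')
    else:
        raise NotImplementedError(f'Unknown model format: {model_format}')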