refactor(model_cache): factor out load_ckpt

Kevin Turner 2022-11-09 15:25:56 -08:00
parent dcfdb83513
commit 9b274bd57c


@@ -1,5 +1,5 @@
'''
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be
cleared and loaded from disk when next needed.
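A quick illustration of the behaviour described above (not part of this commit; the import path, constructor arguments, and model name are assumptions):

    from omegaconf import OmegaConf
    from ldm.invoke.model_cache import ModelCache

    config = OmegaConf.load('configs/models.yaml')
    cache = ModelCache(config, device_type='cuda', precision='float16')

    # Load from disk (or from the system-RAM cache) and move onto the GPU.
    info = cache.get_model('stable-diffusion-1.5')
    model, width, height = info['model'], info['width'], info['height']

    # Push the active model back to CPU RAM; a later get_model() restores it.
    cache.offload_model('stable-diffusion-1.5')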
@@ -51,7 +51,7 @@ class ModelCache(object):
        identifier.
        '''
        return model_name in self.config

    def get_model(self, model_name:str):
        '''
        Given a model name identified in models.yaml, return
@@ -66,7 +66,7 @@ class ModelCache(object):
            if model_name not in self.models: # make room for a new one
                self._make_cache_room()
            self.offload_model(self.current_model)

        if model_name in self.models:
            requested_model = self.models[model_name]['model']
            print(f'>> Retrieving model {model_name} from system RAM cache')
@@ -92,7 +92,7 @@ class ModelCache(object):
                print(f'** restoring {self.current_model}')
                self.get_model(self.current_model)
                return

        self.current_model = model_name
        self._push_newest_model(model_name)
        return {
@@ -102,7 +102,7 @@ class ModelCache(object):
            'hash': hash
        }

    def default_model(self) -> str | None:
        '''
        Returns the name of the default model, or None
        if none is defined.
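One way a caller might handle the Optional return introduced here (illustrative only; `cache` and `config` are assumed to be a ModelCache and its OmegaConf config, as in the sketch above):

    name = cache.default_model()
    if name is None:
        name = next(iter(config.keys()))  # fall back to any configured model
    cache.get_model(name)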
@@ -191,25 +191,13 @@ class ModelCache(object):
        omega[model_name] = config
        if clobber:
            self._invalidate_cached_model(model_name)

    def _load_model(self, model_name:str):
        """Load and initialize the model from configuration variables passed at object creation time"""
        if model_name not in self.config:
            print(f'"{model_name}" is not a known model name. Please check your models.yaml file')

        mconfig = self.config[model_name]

        # for usage statistics
        if self._has_cuda():
@@ -219,12 +207,39 @@ class ModelCache(object):
        tic = time.time()

        # this does the work
        model_format = mconfig.get('format', 'ckpt')
        if model_format == 'ckpt':
            weights = mconfig.weights
            print(f'>> Loading {model_name} from {weights}')
            model, width, height, model_hash = self._load_ckpt_model(mconfig)
        elif model_format == 'diffusers':
            model, width, height, model_hash = self._load_diffusers_model(mconfig)
        else:
            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")

        # usage statistics
        toc = time.time()
        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
        if self._has_cuda():
            print(
                '>> Max VRAM used to load the model:',
                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
                '\n>> Current VRAM usage:'
                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
            )
        return model, width, height, model_hash
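    # Illustration only (not part of this commit): the dispatch above keys off an
    # optional `format` field in each models.yaml stanza, defaulting to 'ckpt'.
    # Hypothetical stanzas (the diffusers-side field names are assumptions):
    #
    #   my-diffusers-model:
    #     format: diffusers
    #     repo_id: some-org/some-repo
    #
    #   my-legacy-model:
    #     config: configs/stable-diffusion/v1-inference.yaml
    #     weights: models/ldm/stable-diffusion-v1/model.ckpt
    #     width: 512
    #     height: 512
    #
    # mconfig.get('format', 'ckpt') yields 'diffusers' for the first stanza and
    # falls back to 'ckpt' for the second, routing to _load_diffusers_model() or
    # _load_ckpt_model() respectively.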

    def _load_ckpt_model(self, mconfig):
        config = mconfig.config
        weights = mconfig.weights
        vae = mconfig.get('vae', None)
        width = mconfig.width
        height = mconfig.height

        c = OmegaConf.load(config)
        with open(weights, 'rb') as f:
            weight_bytes = f.read()
        model_hash = self._cached_sha256(weights, weight_bytes)
        sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
        del weight_bytes
        sd = sd['state_dict']
@@ -252,28 +267,19 @@ class ModelCache(object):
        model.to(self.device)
        # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
        model.cond_stage_model.device = self.device

        model.eval()
        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                module._orig_padding_mode = module.padding_mode
        return model, width, height, model_hash

    def _load_diffusers_model(self, mconfig):
        raise NotImplementedError() # return pipeline, width, height, model_hash
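    # --- Hypothetical sketch, not part of this commit ---------------------------
    # One possible shape for the loader stubbed out above, using the Hugging Face
    # `diffusers` package. The config field names, the precision handling, and the
    # way width/height and the hash are derived are assumptions, not InvokeAI's API.
    def _load_diffusers_model_sketch(self, mconfig):
        from diffusers import StableDiffusionPipeline  # requires the `diffusers` package

        repo_or_path = mconfig.get('repo_id', mconfig.get('path'))  # field names assumed
        dtype = torch.float16 if getattr(self, 'precision', None) == 'float16' else torch.float32

        pipeline = StableDiffusionPipeline.from_pretrained(repo_or_path, torch_dtype=dtype)
        pipeline.to(self.device)

        # Assumption: the UNet's latent sample_size times 8 is the native pixel resolution.
        width = height = pipeline.unet.config.sample_size * 8
        # Placeholder hash over the identifier; a real implementation would hash the weights.
        model_hash = self._cached_sha256(repo_or_path, repo_or_path.encode('utf-8'))
        return pipeline, width, height, model_hash
    # -----------------------------------------------------------------------------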

    def offload_model(self, model_name:str):
        '''
        Offload the indicated model to CPU. Will call
        _make_cache_room() to free space if needed.
@@ -288,7 +294,7 @@ class ModelCache(object):
        gc.collect()
        if self._has_cuda():
            torch.cuda.empty_cache()

    def scan_model(self, model_name, checkpoint):
        # scan model
        print(f'>> Scanning Model: {model_name}')
@@ -318,7 +324,7 @@ class ModelCache(object):
        if least_recent_model is not None:
            del self.models[least_recent_model]
            gc.collect()

    def print_vram_usage(self) -> None:
        if self._has_cuda():
            print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
@@ -353,12 +359,12 @@ class ModelCache(object):
        if model_name in self.stack:
            self.stack.remove(model_name)
        self.models.pop(model_name,None)

    def _model_to_cpu(self,model):
        if self.device != 'cpu':
            model.cond_stage_model.device = 'cpu'
            model.first_stage_model.to('cpu')
            model.cond_stage_model.to('cpu')
            model.model.to('cpu')
            return model.to('cpu')
        else:
@@ -388,7 +394,7 @@ class ModelCache(object):
        with contextlib.suppress(ValueError):
            self.stack.remove(model_name)
        self.stack.append(model_name)

    def _has_cuda(self) -> bool:
        return self.device.type == 'cuda'