cleanup ldm/invoke/model_cache.py

remove duplicate import: os
ldm.util.ask_user is now imported once at module level instead of inside the function that uses it
introduce the textwrap and contextlib modules to clean up the code
a bare return already returns None implicitly, so "return None" is shortened to "return"
a function that falls off its end returns None by default, so trailing return statements are omitted entirely
dict.get already returns None when the key is not found, so the explicit None default argument is omitted
True is a bool value, not a type, so "->True" is not a valid return annotation; a method that can only ever return True should not return anything in the first place
added some indentation and line breaks to further improve readability
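
As a standalone illustration of these idioms (a minimal sketch with hypothetical names, not code from model_cache.py):

    from typing import Optional

    def find_default(config: dict) -> Optional[str]:
        for name, attrs in config.items():
            # dict.get returns None when the key is missing, so
            # attrs.get('default') is just as falsy as attrs.get('default', False)
            if attrs.get('default'):
                return name
        # falling off the end returns None implicitly;
        # no trailing "return None" is needed

    def remove_entry(config: dict, name: str) -> None:
        # "->True" is not a type, and a method that can only ever
        # return True carries no information in its return value,
        # so annotate it -> None and return nothing
        if name not in config:
            return  # a bare return already yields None
        del config[name]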

Signed-off-by: devops117 <55235206+devops117@users.noreply.github.com>
devops117 2022-11-21 07:30:05 +05:30 committed by Lincoln Stein
parent 8e81425e89
commit a095214e52

@@ -15,10 +15,11 @@ import psutil
import sys
import transformers
import traceback
import os
import textwrap
import contextlib
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError
from ldm.util import instantiate_from_config
from ldm.util import instantiate_from_config, ask_user
from ldm.invoke.globals import Globals
from picklescan.scanner import scan_file_path
@@ -72,6 +73,7 @@ class ModelCache(object):
width = self.models[model_name]['width']
height = self.models[model_name]['height']
hash = self.models[model_name]['hash']
else: # we're about to load a new model, so potentially offload the least recently used one
try:
requested_model, width, height, hash = self._load_model(model_name)
@@ -80,12 +82,13 @@ class ModelCache(object):
self.models[model_name]['width'] = width
self.models[model_name]['height'] = height
self.models[model_name]['hash'] = hash
except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}')
print(traceback.format_exc())
print(f'** restoring {self.current_model}')
self.get_model(self.current_model)
return None
return
self.current_model = model_name
self._push_newest_model(model_name)
@@ -102,9 +105,8 @@ class ModelCache(object):
if none is defined.
'''
for model_name in self.config:
if self.config[model_name].get('default',False):
if self.config[model_name].get('default'):
return model_name
return None
def set_default_model(self,model_name:str):
'''
@@ -112,6 +114,7 @@ class ModelCache(object):
effect until you call model_cache.commit()
'''
assert model_name in self.models,f"unknown model '{model_name}'"
config = self.config
for model in config:
config[model].pop('default',None)
@@ -125,22 +128,24 @@ class ModelCache(object):
},
model_name2: { etc }
'''
result = {}
for name in self.config:
try:
description = self.config[name].description
except ConfigAttributeError:
description = '<no description>'
if self.current_model == name:
status = 'active'
elif name in self.models:
status = 'cached'
else:
status = 'not loaded'
result[name]={}
result[name]['status']=status
result[name]['description']=description
return result
return {
name: {
'status': status,
'description': description,
}}
def print_models(self):
'''
@@ -150,11 +155,10 @@ class ModelCache(object):
for name in models:
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
if models[name]['status'] == 'active':
print(f'\033[1m{line}\033[0m')
else:
print(line)
line = f'\033[1m{line}\033[0m'
print(line)
def del_model(self, model_name:str) ->bool:
def del_model(self, model_name:str):
'''
Delete the named model.
'''
@@ -162,9 +166,8 @@ class ModelCache(object):
del omega[model_name]
if model_name in self.stack:
self.stack.remove(model_name)
return True
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
def add_model(self, model_name:str, model_attributes:dict, clobber=False):
'''
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@@ -172,12 +175,11 @@ class ModelCache(object):
method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing.
'''
omega = self.config
# check that all the required fields are present
for field in ('description','weights','height','width','config'):
assert field in model_attributes, f'required field {field} is missing'
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
omega = self.config
config = omega[model_name] if model_name in omega else {}
for field in model_attributes:
config[field] = model_attributes[field]
@@ -185,18 +187,16 @@ class ModelCache(object):
omega[model_name] = config
if clobber:
self._invalidate_cached_model(model_name)
return True
def _load_model(self, model_name:str):
"""Load and initialize the model from configuration variables passed at object creation time"""
if model_name not in self.config:
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
return None
mconfig = self.config[model_name]
config = mconfig.config
weights = mconfig.weights
vae = mconfig.get('vae',None)
vae = mconfig.get('vae')
width = mconfig.width
height = mconfig.height
@@ -217,15 +217,15 @@ class ModelCache(object):
# this does the work
if not os.path.isabs(config):
config = os.path.join(Globals.root,config)
c = OmegaConf.load(config)
omega_config = OmegaConf.load(config)
with open(weights,'rb') as f:
weight_bytes = f.read()
model_hash = self._cached_sha256(weights,weight_bytes)
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
del weight_bytes
sd = pl_sd['state_dict']
model = instantiate_from_config(c.model)
m, u = model.load_state_dict(sd, strict=False)
sd = sd['state_dict']
model = instantiate_from_config(omega_config.model)
model.load_state_dict(sd, strict=False)
if self.precision == 'float16':
print(' | Using faster float16 precision')
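
The load path above reads the weights file once, hashes the bytes, and hands the same buffer to torch.load so the file is never read twice. A minimal sketch of that pattern (standalone, with a hypothetical checkpoint path):

    import hashlib
    import io

    import torch

    weights = 'model.ckpt'  # hypothetical checkpoint path
    with open(weights, 'rb') as f:
        weight_bytes = f.read()                             # read once...
    model_hash = hashlib.sha256(weight_bytes).hexdigest()   # ...hash the bytes...
    checkpoint = torch.load(io.BytesIO(weight_bytes), map_location='cpu')  # ...load from the same buffer
    del weight_bytes                                        # free the raw copy early
    state_dict = checkpoint['state_dict']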
@@ -251,13 +251,14 @@ class ModelCache(object):
model.eval()
for m in model.modules():
if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
m._orig_padding_mode = m.padding_mode
for module in model.modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
module._orig_padding_mode = module.padding_mode
# usage statistics
toc = time.time()
print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
if self._has_cuda():
print(
'>> Max VRAM used to load the model:',
@@ -265,6 +266,7 @@ class ModelCache(object):
'\n>> Current VRAM usage:'
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
)
return model, width, height, model_hash
def offload_model(self, model_name:str):
@@ -272,12 +274,10 @@ class ModelCache(object):
Offload the indicated model to CPU. Will call
_make_cache_room() to free space if needed.
'''
if model_name not in self.models:
return
message = f'>> Offloading {model_name} to CPU'
print(message)
print(f'>> Offloading {model_name} to CPU')
model = self.models[model_name]['model']
self.models[model_name]['model'] = self._model_to_cpu(model)
@@ -299,11 +299,8 @@ class ModelCache(object):
sys.exit()
else:
print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
from ldm.util import ask_user
model_safe_check_fail = ask_user('Do you want to to continue loading the model?', ['y', 'n'])
if model_safe_check_fail.lower() == 'y':
pass
else:
if model_safe_check_fail.lower() != 'y':
print("### Exiting InvokeAI")
sys.exit()
else:
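
The y/n check above collapses into a single negated test: rather than matching 'y' against an empty pass branch, test for anything other than 'y' and exit. A sketch of the flow, with a hypothetical stand-in for ldm.util.ask_user:

    import sys

    def ask_user(prompt: str, choices: list) -> str:
        # hypothetical stand-in for ldm.util.ask_user
        answer = input(f'{prompt} ({"/".join(choices)}) ')
        return answer if answer in choices else choices[-1]

    answer = ask_user('Do you want to continue loading the model?', ['y', 'n'])
    if answer.lower() != 'y':  # the negated test replaces if/pass/else
        print('### Exiting InvokeAI')
        sys.exit()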
@@ -320,7 +317,7 @@ class ModelCache(object):
def print_vram_usage(self):
if self._has_cuda:
print ('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
def commit(self,config_file_path:str):
'''
@@ -337,14 +334,15 @@ class ModelCache(object):
'''
Returns the preamble for the config file.
'''
return '''# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
'''
return textwrap.dedent('''
# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
''')
def _invalidate_cached_model(self,model_name:str):
self.offload_model(model_name)
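
textwrap.dedent strips the whitespace prefix common to all lines, so the preamble can be written indented to match the surrounding method body and still come out flush-left. A minimal illustration (standalone, same style of preamble):

    import textwrap

    def preamble() -> str:
        # the common leading whitespace of the block is stripped;
        # note the leading newline after ''' is kept in the result
        return textwrap.dedent('''
            # This file describes the alternative machine learning models
            # available to InvokeAI script.
            ''')

    print(preamble())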
@@ -383,10 +381,8 @@ class ModelCache(object):
Maintain a simple FIFO. First element is always the
least recent, and last element is always the most recent.
'''
try:
with contextlib.suppress(ValueError):
self.stack.remove(model_name)
except ValueError:
pass
self.stack.append(model_name)
def _has_cuda(self):
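
contextlib.suppress(ValueError) is the idiomatic replacement for a try/except ValueError: pass around list.remove. A minimal sketch of the FIFO update, assuming the stack is a plain Python list:

    import contextlib

    def push_newest(stack: list, model_name: str) -> None:
        # drop any earlier occurrence; suppress() swallows the
        # ValueError that list.remove raises when the item is absent
        with contextlib.suppress(ValueError):
            stack.remove(model_name)
        stack.append(model_name)  # last element is always the most recent

    stack = ['vae', 'sd-1.5']
    push_newest(stack, 'vae')     # ['sd-1.5', 'vae']
    push_newest(stack, 'sd-2.0')  # ['sd-1.5', 'vae', 'sd-2.0']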
@@ -397,10 +393,12 @@ class ModelCache(object):
basename = os.path.basename(path)
base, _ = os.path.splitext(basename)
hashpath = os.path.join(dirname,base+'.sha256')
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
with open(hashpath) as f:
hash = f.read()
return hash
print(f'>> Calculating sha256 hash of weights file')
tic = time.time()
sha = hashlib.sha256()
@@ -408,6 +406,7 @@ class ModelCache(object):
hash = sha.hexdigest()
toc = time.time()
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
with open(hashpath,'w') as f:
f.write(hash)
return hash
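
The surrounding method caches the digest in a .sha256 sidecar file next to the weights and recomputes it only when the weights file is newer than the cached hash. That pattern, sketched as a standalone helper (simplified: the real method hashes bytes it already holds in memory):

    import hashlib
    import os

    def cached_sha256(path: str) -> str:
        """Return the sha256 of path, cached in a .sha256 sidecar file."""
        hashpath = os.path.splitext(path)[0] + '.sha256'
        # reuse the cached digest unless the weights file is newer
        if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
            with open(hashpath) as f:
                return f.read()
        sha = hashlib.sha256()
        with open(path, 'rb') as f:
            sha.update(f.read())
        digest = sha.hexdigest()
        with open(hashpath, 'w') as f:
            f.write(digest)
        return digest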