cleanup ldm/invoke/model_cache.py

remove duplicate import: os
ldm.util.ask_user is imported only once now
introduce textwrap and contextlib packages to clean up the code
return, returns None implicitly so it is omitted
a function returns None by default so it is omitted
dict.get returns None by default if the value is not found so it is omitted
type of True is a bool and if the module only returns True then it should not return anything in the first place
added some indentations and line breaks to further improve readability

Signed-off-by: devops117 <55235206+devops117@users.noreply.github.com>
This commit is contained in:
devops117 2022-11-21 07:30:05 +05:30 committed by Lincoln Stein
parent 8e81425e89
commit a095214e52

View File

@ -15,10 +15,11 @@ import psutil
import sys import sys
import transformers import transformers
import traceback import traceback
import os import textwrap
import contextlib
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError from omegaconf.errors import ConfigAttributeError
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config, ask_user
from ldm.invoke.globals import Globals from ldm.invoke.globals import Globals
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
@ -72,6 +73,7 @@ class ModelCache(object):
width = self.models[model_name]['width'] width = self.models[model_name]['width']
height = self.models[model_name]['height'] height = self.models[model_name]['height']
hash = self.models[model_name]['hash'] hash = self.models[model_name]['hash']
else: # we're about to load a new model, so potentially offload the least recently used one else: # we're about to load a new model, so potentially offload the least recently used one
try: try:
requested_model, width, height, hash = self._load_model(model_name) requested_model, width, height, hash = self._load_model(model_name)
@ -80,12 +82,13 @@ class ModelCache(object):
self.models[model_name]['width'] = width self.models[model_name]['width'] = width
self.models[model_name]['height'] = height self.models[model_name]['height'] = height
self.models[model_name]['hash'] = hash self.models[model_name]['hash'] = hash
except Exception as e: except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}') print(f'** model {model_name} could not be loaded: {str(e)}')
print(traceback.format_exc()) print(traceback.format_exc())
print(f'** restoring {self.current_model}') print(f'** restoring {self.current_model}')
self.get_model(self.current_model) self.get_model(self.current_model)
return None return
self.current_model = model_name self.current_model = model_name
self._push_newest_model(model_name) self._push_newest_model(model_name)
@ -102,9 +105,8 @@ class ModelCache(object):
if none is defined. if none is defined.
''' '''
for model_name in self.config: for model_name in self.config:
if self.config[model_name].get('default',False): if self.config[model_name].get('default'):
return model_name return model_name
return None
def set_default_model(self,model_name:str): def set_default_model(self,model_name:str):
''' '''
@ -112,6 +114,7 @@ class ModelCache(object):
effect until you call model_cache.commit() effect until you call model_cache.commit()
''' '''
assert model_name in self.models,f"unknown model '{model_name}'" assert model_name in self.models,f"unknown model '{model_name}'"
config = self.config config = self.config
for model in config: for model in config:
config[model].pop('default',None) config[model].pop('default',None)
@ -125,22 +128,24 @@ class ModelCache(object):
}, },
model_name2: { etc } model_name2: { etc }
''' '''
result = {}
for name in self.config: for name in self.config:
try: try:
description = self.config[name].description description = self.config[name].description
except ConfigAttributeError: except ConfigAttributeError:
description = '<no description>' description = '<no description>'
if self.current_model == name: if self.current_model == name:
status = 'active' status = 'active'
elif name in self.models: elif name in self.models:
status = 'cached' status = 'cached'
else: else:
status = 'not loaded' status = 'not loaded'
result[name]={}
result[name]['status']=status return {
result[name]['description']=description name: {
return result 'status': status,
'description': description,
}}
def print_models(self): def print_models(self):
''' '''
@ -150,11 +155,10 @@ class ModelCache(object):
for name in models: for name in models:
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}' line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
if models[name]['status'] == 'active': if models[name]['status'] == 'active':
print(f'\033[1m{line}\033[0m') line = f'\033[1m{line}\033[0m')
else: print(line)
print(line)
def del_model(self, model_name:str) ->bool: def del_model(self, model_name:str):
''' '''
Delete the named model. Delete the named model.
''' '''
@ -162,9 +166,8 @@ class ModelCache(object):
del omega[model_name] del omega[model_name]
if model_name in self.stack: if model_name in self.stack:
self.stack.remove(model_name) self.stack.remove(model_name)
return True
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True: def add_model(self, model_name:str, model_attributes:dict, clobber=False):
''' '''
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
@ -172,12 +175,11 @@ class ModelCache(object):
method will return True. Will fail with an assertion error if provided method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing. attributes are incorrect or the model name is missing.
''' '''
omega = self.config
# check that all the required fields are present
for field in ('description','weights','height','width','config'): for field in ('description','weights','height','width','config'):
assert field in model_attributes, f'required field {field} is missing' assert field in model_attributes, f'required field {field} is missing'
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"' assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
omega = self.config
config = omega[model_name] if model_name in omega else {} config = omega[model_name] if model_name in omega else {}
for field in model_attributes: for field in model_attributes:
config[field] = model_attributes[field] config[field] = model_attributes[field]
@ -185,18 +187,16 @@ class ModelCache(object):
omega[model_name] = config omega[model_name] = config
if clobber: if clobber:
self._invalidate_cached_model(model_name) self._invalidate_cached_model(model_name)
return True
def _load_model(self, model_name:str): def _load_model(self, model_name:str):
"""Load and initialize the model from configuration variables passed at object creation time""" """Load and initialize the model from configuration variables passed at object creation time"""
if model_name not in self.config: if model_name not in self.config:
print(f'"{model_name}" is not a known model name. Please check your models.yaml file') print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
return None
mconfig = self.config[model_name] mconfig = self.config[model_name]
config = mconfig.config config = mconfig.config
weights = mconfig.weights weights = mconfig.weights
vae = mconfig.get('vae',None) vae = mconfig.get('vae')
width = mconfig.width width = mconfig.width
height = mconfig.height height = mconfig.height
@ -217,15 +217,15 @@ class ModelCache(object):
# this does the work # this does the work
if not os.path.isabs(config): if not os.path.isabs(config):
config = os.path.join(Globals.root,config) config = os.path.join(Globals.root,config)
c = OmegaConf.load(config) omega_config = OmegaConf.load(config)
with open(weights,'rb') as f: with open(weights,'rb') as f:
weight_bytes = f.read() weight_bytes = f.read()
model_hash = self._cached_sha256(weights,weight_bytes) model_hash = self._cached_sha256(weights,weight_bytes)
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu') sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
del weight_bytes del weight_bytes
sd = pl_sd['state_dict'] sd = sd['state_dict']
model = instantiate_from_config(c.model) model = instantiate_from_config(omega_config.model)
m, u = model.load_state_dict(sd, strict=False) model.load_state_dict(sd, strict=False)
if self.precision == 'float16': if self.precision == 'float16':
print(' | Using faster float16 precision') print(' | Using faster float16 precision')
@ -251,13 +251,14 @@ class ModelCache(object):
model.eval() model.eval()
for m in model.modules(): for module in model.modules():
if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
m._orig_padding_mode = m.padding_mode module._orig_padding_mode = module.padding_mode
# usage statistics # usage statistics
toc = time.time() toc = time.time()
print(f'>> Model loaded in', '%4.2fs' % (toc - tic)) print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
if self._has_cuda(): if self._has_cuda():
print( print(
'>> Max VRAM used to load the model:', '>> Max VRAM used to load the model:',
@ -265,6 +266,7 @@ class ModelCache(object):
'\n>> Current VRAM usage:' '\n>> Current VRAM usage:'
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9), '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
) )
return model, width, height, model_hash return model, width, height, model_hash
def offload_model(self, model_name:str): def offload_model(self, model_name:str):
@ -272,12 +274,10 @@ class ModelCache(object):
Offload the indicated model to CPU. Will call Offload the indicated model to CPU. Will call
_make_cache_room() to free space if needed. _make_cache_room() to free space if needed.
''' '''
if model_name not in self.models: if model_name not in self.models:
return return
message = f'>> Offloading {model_name} to CPU' print(f'>> Offloading {model_name} to CPU')
print(message)
model = self.models[model_name]['model'] model = self.models[model_name]['model']
self.models[model_name]['model'] = self._model_to_cpu(model) self.models[model_name]['model'] = self._model_to_cpu(model)
@ -299,11 +299,8 @@ class ModelCache(object):
sys.exit() sys.exit()
else: else:
print('\n### WARNING: InvokeAI was unable to scan the model you are using.') print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
from ldm.util import ask_user
model_safe_check_fail = ask_user('Do you want to to continue loading the model?', ['y', 'n']) model_safe_check_fail = ask_user('Do you want to to continue loading the model?', ['y', 'n'])
if model_safe_check_fail.lower() == 'y': if model_safe_check_fail.lower() != 'y':
pass
else:
print("### Exiting InvokeAI") print("### Exiting InvokeAI")
sys.exit() sys.exit()
else: else:
@ -320,7 +317,7 @@ class ModelCache(object):
def print_vram_usage(self): def print_vram_usage(self):
if self._has_cuda: if self._has_cuda:
print ('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9)) print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
def commit(self,config_file_path:str): def commit(self,config_file_path:str):
''' '''
@ -337,14 +334,15 @@ class ModelCache(object):
''' '''
Returns the preamble for the config file. Returns the preamble for the config file.
''' '''
return '''# This file describes the alternative machine learning models return textwrap.dedent('''
# available to InvokeAI script. # This file describes the alternative machine learning models
# # available to InvokeAI script.
# To add a new model, follow the examples below. Each #
# model requires a model config file, a weights file, # To add a new model, follow the examples below. Each
# and the width and height of the images it # model requires a model config file, a weights file,
# was trained on. # and the width and height of the images it
''' # was trained on.
''')
def _invalidate_cached_model(self,model_name:str): def _invalidate_cached_model(self,model_name:str):
self.offload_model(model_name) self.offload_model(model_name)
@ -383,10 +381,8 @@ class ModelCache(object):
Maintain a simple FIFO. First element is always the Maintain a simple FIFO. First element is always the
least recent, and last element is always the most recent. least recent, and last element is always the most recent.
''' '''
try: with contextlib.suppress(ValueError):
self.stack.remove(model_name) self.stack.remove(model_name)
except ValueError:
pass
self.stack.append(model_name) self.stack.append(model_name)
def _has_cuda(self): def _has_cuda(self):
@ -397,10 +393,12 @@ class ModelCache(object):
basename = os.path.basename(path) basename = os.path.basename(path)
base, _ = os.path.splitext(basename) base, _ = os.path.splitext(basename)
hashpath = os.path.join(dirname,base+'.sha256') hashpath = os.path.join(dirname,base+'.sha256')
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath): if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
with open(hashpath) as f: with open(hashpath) as f:
hash = f.read() hash = f.read()
return hash return hash
print(f'>> Calculating sha256 hash of weights file') print(f'>> Calculating sha256 hash of weights file')
tic = time.time() tic = time.time()
sha = hashlib.sha256() sha = hashlib.sha256()
@ -408,6 +406,7 @@ class ModelCache(object):
hash = sha.hexdigest() hash = sha.hexdigest()
toc = time.time() toc = time.time()
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic)) print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
with open(hashpath,'w') as f: with open(hashpath,'w') as f:
f.write(hash) f.write(hash)
return hash return hash