Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
cleanup ldm/invoke/model_cache.py
remove duplicate import: os; ldm.util.ask_user is imported only once now
introduce the textwrap and contextlib packages to clean up the code
a bare return returns None implicitly, so it is omitted
a function returns None by default, so trailing `return None` is omitted
dict.get returns None by default if the key is not found, so the explicit default is omitted
the type of True is bool, and if the method only ever returns True then it should not return anything in the first place
added some indentation and line breaks to further improve readability

Signed-off-by: devops117 <55235206+devops117@users.noreply.github.com>
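The return-value rules the message leans on are easy to check in isolation; a minimal standalone sketch (illustrative only, not code from this commit):

    def implicit():      # no return statement at all
        pass

    def bare():
        return           # bare return

    def explicit():
        return None

    # All three calls evaluate to the same object, so the explicit
    # forms in the old code were redundant.
    assert implicit() is bare() is explicit() is None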
Parent: 8e81425e89
Commit: a095214e52
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -15,10 +15,11 @@ import psutil
 import sys
 import transformers
 import traceback
-import os
+import textwrap
+import contextlib
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
-from ldm.util import instantiate_from_config
+from ldm.util import instantiate_from_config, ask_user
 from ldm.invoke.globals import Globals
 from picklescan.scanner import scan_file_path
 
@@ -72,6 +73,7 @@ class ModelCache(object):
             width = self.models[model_name]['width']
             height = self.models[model_name]['height']
             hash = self.models[model_name]['hash']
+
         else: # we're about to load a new model, so potentially offload the least recently used one
             try:
                 requested_model, width, height, hash = self._load_model(model_name)
@@ -80,12 +82,13 @@ class ModelCache(object):
                 self.models[model_name]['width'] = width
                 self.models[model_name]['height'] = height
                 self.models[model_name]['hash'] = hash
+
             except Exception as e:
                 print(f'** model {model_name} could not be loaded: {str(e)}')
                 print(traceback.format_exc())
                 print(f'** restoring {self.current_model}')
                 self.get_model(self.current_model)
-                return None
+                return
 
         self.current_model = model_name
         self._push_newest_model(model_name)
@@ -102,9 +105,8 @@ class ModelCache(object):
        if none is defined.
        '''
        for model_name in self.config:
-            if self.config[model_name].get('default',False):
+            if self.config[model_name].get('default'):
                return model_name
-        return None
 
    def set_default_model(self,model_name:str):
        '''
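The hunk above works because dict.get (and OmegaConf's get) falls back to None when the key is absent, and None is just as falsy as False, so the shorter form takes the same branch; a quick standalone sketch:

    config = {'sd-1.5': {'default': True}, 'sd-2.0': {}}
    for name in config:
        # get('default') and get('default', False) differ only in
        # which falsy value they return for a missing key
        assert bool(config[name].get('default')) == bool(config[name].get('default', False))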
@@ -112,6 +114,7 @@ class ModelCache(object):
        effect until you call model_cache.commit()
        '''
        assert model_name in self.models,f"unknown model '{model_name}'"
+
        config = self.config
        for model in config:
            config[model].pop('default',None)
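pop('default', None) in this hunk is the mutating companion to get: the second argument turns a missing key into a no-op instead of a KeyError, so the loop can clear the flag from every model unconditionally. For example:

    config = {'sd-1.5': {'default': True}, 'sd-2.0': {}}
    for model in config:
        config[model].pop('default', None)   # safe even when the key is absent
    print(config)                            # {'sd-1.5': {}, 'sd-2.0': {}}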
@@ -125,22 +128,24 @@ class ModelCache(object):
            },
            model_name2: { etc }
        '''
-        result = {}
        for name in self.config:
            try:
                description = self.config[name].description
            except ConfigAttributeError:
                description = '<no description>'
 
            if self.current_model == name:
                status = 'active'
            elif name in self.models:
                status = 'cached'
            else:
                status = 'not loaded'
-            result[name]={}
-            result[name]['status']=status
-            result[name]['description']=description
-        return result
+            return {
+                name: {
+                    'status': status,
+                    'description': description,
+                }}
 
    def print_models(self):
        '''
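Note that the new return sits inside the for loop, so it hands back an entry for the first model only; collecting all of them needs a dict built across iterations, e.g. as a comprehension (a hypothetical sketch with free-standing arguments, not the committed method):

    def list_models(config, current_model, loaded):
        return {
            name: {
                'status': ('active' if name == current_model
                           else 'cached' if name in loaded
                           else 'not loaded'),
                'description': config[name].get('description', '<no description>'),
            }
            for name in config
        }

    print(list_models({'a': {}, 'b': {}}, 'a', set()))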
@@ -150,11 +155,10 @@ class ModelCache(object):
        for name in models:
            line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
            if models[name]['status'] == 'active':
-                print(f'\033[1m{line}\033[0m')
-            else:
-                print(line)
+                line = f'\033[1m{line}\033[0m'
+            print(line)
 
-    def del_model(self, model_name:str) ->bool:
+    def del_model(self, model_name:str):
        '''
        Delete the named model.
        '''
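\033[1m and \033[0m are the ANSI bold-on and reset escape sequences, so the rewrite decides the final string once and funnels both cases through a single print:

    line = 'stable-diffusion-1.5          active  Stable Diffusion v1.5'
    is_active = True
    if is_active:
        line = f'\033[1m{line}\033[0m'   # wrap in bold ... reset
    print(line)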
@@ -162,9 +166,8 @@ class ModelCache(object):
        del omega[model_name]
        if model_name in self.stack:
            self.stack.remove(model_name)
-        return True
 
-    def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
+    def add_model(self, model_name:str, model_attributes:dict, clobber=False):
        '''
        Update the named model with a dictionary of attributes. Will fail with an
        assertion error if the name already exists. Pass clobber=True to overwrite.
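As the message says, True is a value of type bool, not a type, so `-> True` was a meaningless (if syntactically legal) annotation. Python stores annotations without enforcing them, which is why the mistake went unnoticed:

    def bad() -> True:       # legal syntax, but annotates with a value
        pass

    def good() -> bool:
        return True

    print(bad.__annotations__)    # {'return': True}
    print(good.__annotations__)   # {'return': <class 'bool'>}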
@@ -172,12 +175,11 @@ class ModelCache(object):
        method will return True. Will fail with an assertion error if provided
        attributes are incorrect or the model name is missing.
        '''
-        omega = self.config
-        # check that all the required fields are present
        for field in ('description','weights','height','width','config'):
            assert field in model_attributes, f'required field {field} is missing'
 
        assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
 
+        omega = self.config
        config = omega[model_name] if model_name in omega else {}
        for field in model_attributes:
            config[field] = model_attributes[field]
@@ -185,18 +187,16 @@ class ModelCache(object):
        omega[model_name] = config
        if clobber:
            self._invalidate_cached_model(model_name)
-        return True
 
    def _load_model(self, model_name:str):
        """Load and initialize the model from configuration variables passed at object creation time"""
        if model_name not in self.config:
            print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
-            return None
 
        mconfig = self.config[model_name]
        config = mconfig.config
        weights = mconfig.weights
-        vae = mconfig.get('vae',None)
+        vae = mconfig.get('vae')
        width = mconfig.width
        height = mconfig.height
 
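One behavioral note on this hunk: the deleted `return None` was an early exit rather than a redundant trailing return, so an unknown model_name now falls through to the `self.config[model_name]` lookup that follows. The difference in isolation (plain-dict sketch, not the repository's OmegaConf object):

    config = {}

    def guarded(name):
        if name not in config:
            print(f'"{name}" is not a known model name')
            return                # caller gets None
        return config[name]

    def unguarded(name):
        if name not in config:
            print(f'"{name}" is not a known model name')
        return config[name]       # raises KeyError for unknown names

    guarded('missing')            # prints, returns None
    try:
        unguarded('missing')      # prints, then raises
    except KeyError:
        pass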
@@ -217,15 +217,15 @@ class ModelCache(object):
        # this does the work
        if not os.path.isabs(config):
            config = os.path.join(Globals.root,config)
-        c = OmegaConf.load(config)
+        omega_config = OmegaConf.load(config)
        with open(weights,'rb') as f:
            weight_bytes = f.read()
        model_hash = self._cached_sha256(weights,weight_bytes)
-        pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
+        sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
        del weight_bytes
-        sd = pl_sd['state_dict']
-        model = instantiate_from_config(c.model)
-        m, u = model.load_state_dict(sd, strict=False)
+        sd = sd['state_dict']
+        model = instantiate_from_config(omega_config.model)
+        model.load_state_dict(sd, strict=False)
 
        if self.precision == 'float16':
            print(' | Using faster float16 precision')
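Dropping the `m, u =` unpacking is behavior-neutral: with strict=False, load_state_dict still returns a named tuple of missing_keys and unexpected_keys, and the old code ignored both. A toy-module sketch:

    import torch

    model = torch.nn.Linear(4, 2)
    sd = {'weight': torch.zeros(2, 4), 'extra': torch.zeros(1)}  # no 'bias', stray 'extra'
    result = model.load_state_dict(sd, strict=False)
    print(result.missing_keys)     # ['bias']
    print(result.unexpected_keys)  # ['extra']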
@@ -251,13 +251,14 @@ class ModelCache(object):
 
        model.eval()
+
-        for m in model.modules():
-            if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
-                m._orig_padding_mode = m.padding_mode
+        for module in model.modules():
+            if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
+                module._orig_padding_mode = module.padding_mode
 
        # usage statistics
        toc = time.time()
        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
 
        if self._has_cuda():
            print(
                '>> Max VRAM used to load the model:',
@@ -265,6 +266,7 @@ class ModelCache(object):
                '\n>> Current VRAM usage:'
                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
            )
+
        return model, width, height, model_hash
 
    def offload_model(self, model_name:str):
@@ -272,12 +274,10 @@ class ModelCache(object):
        Offload the indicated model to CPU. Will call
        _make_cache_room() to free space if needed.
        '''
-
        if model_name not in self.models:
            return
 
-        message = f'>> Offloading {model_name} to CPU'
-        print(message)
+        print(f'>> Offloading {model_name} to CPU')
        model = self.models[model_name]['model']
        self.models[model_name]['model'] = self._model_to_cpu(model)
 
@@ -299,11 +299,8 @@ class ModelCache(object):
                sys.exit()
            else:
                print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
-                from ldm.util import ask_user
                model_safe_check_fail = ask_user('Do you want to to continue loading the model?', ['y', 'n'])
-                if model_safe_check_fail.lower() == 'y':
-                    pass
-                else:
+                if model_safe_check_fail.lower() != 'y':
                    print("### Exiting InvokeAI")
                    sys.exit()
        else:
@@ -320,7 +317,7 @@ class ModelCache(object):
 
    def print_vram_usage(self):
        if self._has_cuda:
-            print ('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
+            print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
 
    def commit(self,config_file_path:str):
        '''
@@ -337,14 +334,15 @@ class ModelCache(object):
        '''
        Returns the preamble for the config file.
        '''
-        return '''# This file describes the alternative machine learning models
-# available to InvokeAI script.
-#
-# To add a new model, follow the examples below. Each
-# model requires a model config file, a weights file,
-# and the width and height of the images it
-# was trained on.
-'''
+        return textwrap.dedent('''
+            # This file describes the alternative machine learning models
+            # available to InvokeAI script.
+            #
+            # To add a new model, follow the examples below. Each
+            # model requires a model config file, a weights file,
+            # and the width and height of the images it
+            # was trained on.
+        ''')
 
    def _invalidate_cached_model(self,model_name:str):
        self.offload_model(model_name)
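textwrap.dedent strips the whitespace prefix common to all non-blank lines, which is what lets the preamble sit at the method's indentation level; note that the leading and trailing newlines of the triple-quoted string survive, so the result is close to, but not byte-for-byte identical with, the old literal:

    import textwrap

    preamble = textwrap.dedent('''
        # This file describes the alternative machine learning models
        # available to InvokeAI script.
    ''')
    print(repr(preamble))
    # '\n# This file describes the alternative machine learning models\n# available to InvokeAI script.\n'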
@@ -383,10 +381,8 @@ class ModelCache(object):
        Maintain a simple FIFO. First element is always the
        least recent, and last element is always the most recent.
        '''
-        try:
+        with contextlib.suppress(ValueError):
            self.stack.remove(model_name)
-        except ValueError:
-            pass
        self.stack.append(model_name)
 
    def _has_cuda(self):
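contextlib.suppress(ValueError) is the standard-library one-liner for the try/remove/except-ValueError/pass pattern: the block body runs, and a matching exception is silently swallowed:

    import contextlib

    stack = ['a', 'b']
    with contextlib.suppress(ValueError):
        stack.remove('missing')    # list.remove raises ValueError; suppressed
    stack.append('c')
    print(stack)                   # ['a', 'b', 'c']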
@@ -397,10 +393,12 @@ class ModelCache(object):
        basename = os.path.basename(path)
        base, _ = os.path.splitext(basename)
        hashpath = os.path.join(dirname,base+'.sha256')
+
        if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
            with open(hashpath) as f:
                hash = f.read()
            return hash
+
        print(f'>> Calculating sha256 hash of weights file')
        tic = time.time()
        sha = hashlib.sha256()
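The surrounding method caches the digest in a .sha256 sidecar file and trusts it only while the weights file's mtime is no newer than the sidecar's. The same pattern in isolation (a hypothetical cached_sha256 helper, not the repository's method):

    import hashlib, os

    def cached_sha256(path: str) -> str:
        hashpath = os.path.splitext(path)[0] + '.sha256'
        # reuse the sidecar only if the weights file has not changed since it was written
        if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
            with open(hashpath) as f:
                return f.read()
        with open(path, 'rb') as f:
            digest = hashlib.sha256(f.read()).hexdigest()
        with open(hashpath, 'w') as f:
            f.write(digest)
        return digest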
@@ -408,6 +406,7 @@ class ModelCache(object):
        hash = sha.hexdigest()
        toc = time.time()
        print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
+
        with open(hashpath,'w') as f:
            f.write(hash)
        return hash