mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cleanup ldm/invoke/model_cache.py
remove duplicate import: os ldm.util.ask_user is imported only once now introduce textwrap and contextlib packages to clean up the code return, returns None implicitly so it is omitted a function returns None by default so it is omitted dict.get returns None by default if the value is not found so it is omitted type of True is a bool and if the module only returns True then it should not return anything in the first place added some indentations and line breaks to further improve readability Signed-off-by: devops117 <55235206+devops117@users.noreply.github.com>
This commit is contained in:
parent
8e81425e89
commit
a095214e52
@ -15,10 +15,11 @@ import psutil
|
||||
import sys
|
||||
import transformers
|
||||
import traceback
|
||||
import os
|
||||
import textwrap
|
||||
import contextlib
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.errors import ConfigAttributeError
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.util import instantiate_from_config, ask_user
|
||||
from ldm.invoke.globals import Globals
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
@ -72,6 +73,7 @@ class ModelCache(object):
|
||||
width = self.models[model_name]['width']
|
||||
height = self.models[model_name]['height']
|
||||
hash = self.models[model_name]['hash']
|
||||
|
||||
else: # we're about to load a new model, so potentially offload the least recently used one
|
||||
try:
|
||||
requested_model, width, height, hash = self._load_model(model_name)
|
||||
@ -80,12 +82,13 @@ class ModelCache(object):
|
||||
self.models[model_name]['width'] = width
|
||||
self.models[model_name]['height'] = height
|
||||
self.models[model_name]['hash'] = hash
|
||||
|
||||
except Exception as e:
|
||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||
print(traceback.format_exc())
|
||||
print(f'** restoring {self.current_model}')
|
||||
self.get_model(self.current_model)
|
||||
return None
|
||||
return
|
||||
|
||||
self.current_model = model_name
|
||||
self._push_newest_model(model_name)
|
||||
@ -102,9 +105,8 @@ class ModelCache(object):
|
||||
if none is defined.
|
||||
'''
|
||||
for model_name in self.config:
|
||||
if self.config[model_name].get('default',False):
|
||||
if self.config[model_name].get('default'):
|
||||
return model_name
|
||||
return None
|
||||
|
||||
def set_default_model(self,model_name:str):
|
||||
'''
|
||||
@ -112,6 +114,7 @@ class ModelCache(object):
|
||||
effect until you call model_cache.commit()
|
||||
'''
|
||||
assert model_name in self.models,f"unknown model '{model_name}'"
|
||||
|
||||
config = self.config
|
||||
for model in config:
|
||||
config[model].pop('default',None)
|
||||
@ -125,22 +128,24 @@ class ModelCache(object):
|
||||
},
|
||||
model_name2: { etc }
|
||||
'''
|
||||
result = {}
|
||||
for name in self.config:
|
||||
try:
|
||||
description = self.config[name].description
|
||||
except ConfigAttributeError:
|
||||
description = '<no description>'
|
||||
|
||||
if self.current_model == name:
|
||||
status = 'active'
|
||||
elif name in self.models:
|
||||
status = 'cached'
|
||||
else:
|
||||
status = 'not loaded'
|
||||
result[name]={}
|
||||
result[name]['status']=status
|
||||
result[name]['description']=description
|
||||
return result
|
||||
|
||||
return {
|
||||
name: {
|
||||
'status': status,
|
||||
'description': description,
|
||||
}}
|
||||
|
||||
def print_models(self):
|
||||
'''
|
||||
@ -150,11 +155,10 @@ class ModelCache(object):
|
||||
for name in models:
|
||||
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
|
||||
if models[name]['status'] == 'active':
|
||||
print(f'\033[1m{line}\033[0m')
|
||||
else:
|
||||
print(line)
|
||||
line = f'\033[1m{line}\033[0m')
|
||||
print(line)
|
||||
|
||||
def del_model(self, model_name:str) ->bool:
|
||||
def del_model(self, model_name:str):
|
||||
'''
|
||||
Delete the named model.
|
||||
'''
|
||||
@ -162,9 +166,8 @@ class ModelCache(object):
|
||||
del omega[model_name]
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
return True
|
||||
|
||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
|
||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False):
|
||||
'''
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@ -172,12 +175,11 @@ class ModelCache(object):
|
||||
method will return True. Will fail with an assertion error if provided
|
||||
attributes are incorrect or the model name is missing.
|
||||
'''
|
||||
omega = self.config
|
||||
# check that all the required fields are present
|
||||
for field in ('description','weights','height','width','config'):
|
||||
assert field in model_attributes, f'required field {field} is missing'
|
||||
|
||||
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
|
||||
|
||||
omega = self.config
|
||||
config = omega[model_name] if model_name in omega else {}
|
||||
for field in model_attributes:
|
||||
config[field] = model_attributes[field]
|
||||
@ -185,18 +187,16 @@ class ModelCache(object):
|
||||
omega[model_name] = config
|
||||
if clobber:
|
||||
self._invalidate_cached_model(model_name)
|
||||
return True
|
||||
|
||||
def _load_model(self, model_name:str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return None
|
||||
|
||||
mconfig = self.config[model_name]
|
||||
config = mconfig.config
|
||||
weights = mconfig.weights
|
||||
vae = mconfig.get('vae',None)
|
||||
vae = mconfig.get('vae')
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
@ -217,15 +217,15 @@ class ModelCache(object):
|
||||
# this does the work
|
||||
if not os.path.isabs(config):
|
||||
config = os.path.join(Globals.root,config)
|
||||
c = OmegaConf.load(config)
|
||||
omega_config = OmegaConf.load(config)
|
||||
with open(weights,'rb') as f:
|
||||
weight_bytes = f.read()
|
||||
model_hash = self._cached_sha256(weights,weight_bytes)
|
||||
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||
sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||
del weight_bytes
|
||||
sd = pl_sd['state_dict']
|
||||
model = instantiate_from_config(c.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
sd = sd['state_dict']
|
||||
model = instantiate_from_config(omega_config.model)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
@ -251,13 +251,14 @@ class ModelCache(object):
|
||||
|
||||
model.eval()
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
|
||||
m._orig_padding_mode = m.padding_mode
|
||||
for module in model.modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
|
||||
module._orig_padding_mode = module.padding_mode
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
|
||||
|
||||
if self._has_cuda():
|
||||
print(
|
||||
'>> Max VRAM used to load the model:',
|
||||
@ -265,6 +266,7 @@ class ModelCache(object):
|
||||
'\n>> Current VRAM usage:'
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
return model, width, height, model_hash
|
||||
|
||||
def offload_model(self, model_name:str):
|
||||
@ -272,12 +274,10 @@ class ModelCache(object):
|
||||
Offload the indicated model to CPU. Will call
|
||||
_make_cache_room() to free space if needed.
|
||||
'''
|
||||
|
||||
if model_name not in self.models:
|
||||
return
|
||||
|
||||
message = f'>> Offloading {model_name} to CPU'
|
||||
print(message)
|
||||
print(f'>> Offloading {model_name} to CPU')
|
||||
model = self.models[model_name]['model']
|
||||
self.models[model_name]['model'] = self._model_to_cpu(model)
|
||||
|
||||
@ -299,11 +299,8 @@ class ModelCache(object):
|
||||
sys.exit()
|
||||
else:
|
||||
print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
|
||||
from ldm.util import ask_user
|
||||
model_safe_check_fail = ask_user('Do you want to to continue loading the model?', ['y', 'n'])
|
||||
if model_safe_check_fail.lower() == 'y':
|
||||
pass
|
||||
else:
|
||||
if model_safe_check_fail.lower() != 'y':
|
||||
print("### Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
@ -320,7 +317,7 @@ class ModelCache(object):
|
||||
|
||||
def print_vram_usage(self):
|
||||
if self._has_cuda:
|
||||
print ('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
|
||||
print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
|
||||
|
||||
def commit(self,config_file_path:str):
|
||||
'''
|
||||
@ -337,14 +334,15 @@ class ModelCache(object):
|
||||
'''
|
||||
Returns the preamble for the config file.
|
||||
'''
|
||||
return '''# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
'''
|
||||
return textwrap.dedent('''
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
''')
|
||||
|
||||
def _invalidate_cached_model(self,model_name:str):
|
||||
self.offload_model(model_name)
|
||||
@ -383,10 +381,8 @@ class ModelCache(object):
|
||||
Maintain a simple FIFO. First element is always the
|
||||
least recent, and last element is always the most recent.
|
||||
'''
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
self.stack.remove(model_name)
|
||||
except ValueError:
|
||||
pass
|
||||
self.stack.append(model_name)
|
||||
|
||||
def _has_cuda(self):
|
||||
@ -397,10 +393,12 @@ class ModelCache(object):
|
||||
basename = os.path.basename(path)
|
||||
base, _ = os.path.splitext(basename)
|
||||
hashpath = os.path.join(dirname,base+'.sha256')
|
||||
|
||||
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
print(f'>> Calculating sha256 hash of weights file')
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
@ -408,6 +406,7 @@ class ModelCache(object):
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
|
||||
|
||||
with open(hashpath,'w') as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
|
Loading…
Reference in New Issue
Block a user