further improvements to model loading

- code for committing config changes to models.yaml now in module
  rather than in invoke script
- model marked "default" is now loaded if model not specified on
  command line
- uncache changed models when edited, so that they reload properly
- removed liaon from models.yaml and added stable-diffusion-1.5
This commit is contained in:
Lincoln Stein 2022-10-21 00:28:54 -04:00
parent a705a5a0aa
commit 83e6ab08aa
5 changed files with 91 additions and 27 deletions

View File

@ -6,15 +6,15 @@
# and the width and height of the images it # and the width and height of the images it
# was trained on. # was trained on.
laion400m:
config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
weights: models/ldm/text2img-large/model.ckpt
description: Latent Diffusion LAION400M model
width: 256
height: 256
stable-diffusion-1.4: stable-diffusion-1.4:
config: configs/stable-diffusion/v1-inference.yaml config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/model.ckpt weights: models/ldm/stable-diffusion-v1/model.ckpt
description: Stable Diffusion inference model version 1.4 description: Stable Diffusion inference model version 1.4
width: 512 width: 512
height: 512 height: 512
stable-diffusion-1.5:
config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
description: Stable Diffusion inference model version 1.5
width: 512
height: 512

View File

@ -35,6 +35,9 @@ from ldm.invoke.devices import choose_torch_device, choose_precision
from ldm.invoke.conditioning import get_uc_and_c from ldm.invoke.conditioning import get_uc_and_c
from ldm.invoke.model_cache import ModelCache from ldm.invoke.model_cache import ModelCache
# this is fallback model in case no default is defined
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
def fix_func(orig): def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw): def new_func(*args, **kw):
@ -127,7 +130,7 @@ class Generate:
def __init__( def __init__(
self, self,
model = 'stable-diffusion-1.4', model = None,
conf = 'configs/models.yaml', conf = 'configs/models.yaml',
embedding_path = None, embedding_path = None,
sampler_name = 'k_lms', sampler_name = 'k_lms',
@ -143,7 +146,6 @@ class Generate:
free_gpu_mem=False, free_gpu_mem=False,
): ):
mconfig = OmegaConf.load(conf) mconfig = OmegaConf.load(conf)
self.model_name = model
self.height = None self.height = None
self.width = None self.width = None
self.model_cache = None self.model_cache = None
@ -188,6 +190,8 @@ class Generate:
# model caching system for fast switching # model caching system for fast switching
self.model_cache = ModelCache(mconfig,self.device,self.precision) self.model_cache = ModelCache(mconfig,self.device,self.precision)
print(f'DEBUG: model={model}, default_model={self.model_cache.default_model()}')
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
# for VRAM usage statistics # for VRAM usage statistics
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None

View File

@ -364,17 +364,16 @@ class Args(object):
deprecated_group.add_argument('--laion400m') deprecated_group.add_argument('--laion400m')
deprecated_group.add_argument('--weights') # deprecated deprecated_group.add_argument('--weights') # deprecated
model_group.add_argument( model_group.add_argument(
'--conf', '--config',
'-c', '-c',
'-conf', '-config',
dest='conf', dest='conf',
default='./configs/models.yaml', default='./configs/models.yaml',
help='Path to configuration file for alternate models.', help='Path to configuration file for alternate models.',
) )
model_group.add_argument( model_group.add_argument(
'--model', '--model',
default='stable-diffusion-1.4', help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
) )
model_group.add_argument( model_group.add_argument(
'--sampler', '--sampler',

View File

@ -85,6 +85,26 @@ class ModelCache(object):
'hash': hash 'hash': hash
} }
def default_model(self) -> str:
'''
Returns the name of the default model, or None
if none is defined.
'''
for model_name in self.config:
if self.config[model_name].get('default',False):
return model_name
return None
def set_default_model(self,model_name:str):
'''
Set the default model. The change will not take
effect until you call model_cache.commit()
'''
assert model_name in self.models,f"unknown model '{model_name}'"
for model in self.models:
self.models[model].pop('default',None)
self.models[model_name]['default'] = True
def list_models(self) -> dict: def list_models(self) -> dict:
''' '''
Return a dict of models in the format: Return a dict of models in the format:
@ -122,22 +142,23 @@ class ModelCache(object):
else: else:
print(line) print(line)
def del_model(self, model_name:str) ->str: def del_model(self, model_name:str) ->bool:
''' '''
Delete the named model and return the YAML Delete the named model.
''' '''
omega = self.config omega = self.config
del omega[model_name] del omega[model_name]
if model_name in self.stack: if model_name in self.stack:
self.stack.remove(model_name) self.stack.remove(model_name)
return OmegaConf.to_yaml(omega) return True
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str: def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
''' '''
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory and a YAML On a successful update, the config will be changed in memory and the
string will be returned. method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing.
''' '''
omega = self.config omega = self.config
# check that all the required fields are present # check that all the required fields are present
@ -150,7 +171,9 @@ class ModelCache(object):
config[field] = model_attributes[field] config[field] = model_attributes[field]
omega[model_name] = config omega[model_name] = config
return OmegaConf.to_yaml(omega) if clobber:
self._invalidate_cached_model(model_name)
return True
def _check_memory(self): def _check_memory(self):
avail_memory = psutil.virtual_memory()[1] avail_memory = psutil.virtual_memory()[1]
@ -230,6 +253,36 @@ class ModelCache(object):
if self._has_cuda(): if self._has_cuda():
torch.cuda.empty_cache() torch.cuda.empty_cache()
def commit(self,config_file_path:str):
'''
Write current configuration out to the indicated file.
'''
yaml_str = OmegaConf.to_yaml(self.config)
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
with open(tmpfile, 'w') as outfile:
outfile.write(self.preamble())
outfile.write(yaml_str)
os.rename(tmpfile,config_file_path)
def preamble(self):
'''
Returns the preamble for the config file.
'''
return '''# This file describes the alternative machine learning models
# available to the dream script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
'''
def _invalidate_cached_model(self,model_name:str):
self.unload_model(model_name)
if model_name in self.stack:
self.stack.remove(model_name)
self.models.pop(model_name,None)
def _model_to_cpu(self,model): def _model_to_cpu(self,model):
if self.device != 'cpu': if self.device != 'cpu':
model.cond_stage_model.device = 'cpu' model.cond_stage_model.device = 'cpu'

View File

@ -341,6 +341,7 @@ def main_loop(gen, opt, infile):
print('goodbye!') print('goodbye!')
# to do: this is ugly, fix
def do_command(command:str, gen, opt:Args, completer) -> tuple: def do_command(command:str, gen, opt:Args, completer) -> tuple:
operation = 'generate' # default operation, alternative is 'postprocess' operation = 'generate' # default operation, alternative is 'postprocess'
@ -455,7 +456,10 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
done = True done = True
except: except:
print('** Please enter a valid integer between 64 and 2048') print('** Please enter a valid integer between 64 and 2048')
if write_config_file(opt.conf, gen, model_name, new_config):
make_default = input('Make this the default model? [n] ') in ('y','Y')
if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
completer.add_model(model_name) completer.add_model(model_name)
def del_config(model_name:str, gen, opt, completer): def del_config(model_name:str, gen, opt, completer):
@ -488,14 +492,17 @@ def edit_config(model_name:str, gen, opt, completer):
completer.linebuffer = str(conf[field]) if field in conf else '' completer.linebuffer = str(conf[field]) if field in conf else ''
new_value = input(f'{field}: ') new_value = input(f'{field}: ')
new_config[field] = int(new_value) if field in ('width','height') else new_value new_config[field] = int(new_value) if field in ('width','height') else new_value
make_default = input('Make this the default model? [n] ') in ('y','Y')
completer.complete_extensions(None) completer.complete_extensions(None)
write_config_file(opt.conf, gen, model_name, new_config, clobber=True) write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
def write_config_file(conf_path, gen, model_name, new_config, clobber=False): def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
current_model = gen.model_name current_model = gen.model_name
op = 'modify' if clobber else 'import' op = 'modify' if clobber else 'import'
print('\n>> New configuration:') print('\n>> New configuration:')
if make_default:
new_config['default'] = True
print(yaml.dump({model_name:new_config})) print(yaml.dump({model_name:new_config}))
if input(f'OK to {op} [n]? ') not in ('y','Y'): if input(f'OK to {op} [n]? ') not in ('y','Y'):
return False return False
@ -509,10 +516,11 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
gen.model_cache.del_model(model_name) gen.model_cache.del_model(model_name)
return False return False
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp') if make_default:
with open(tmpfile, 'w') as outfile: print('making this default')
outfile.write(yaml_str) gen.model_cache.set_default_model(model_name)
os.rename(tmpfile,conf_path)
gen.model_cache.commit(conf_path)
do_switch = input(f'Keep model loaded? [y]') do_switch = input(f'Keep model loaded? [y]')
if len(do_switch)==0 or do_switch[0] in ('y','Y'): if len(do_switch)==0 or do_switch[0] in ('y','Y'):