mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
further improvements to model loading
- code for committing config changes to models.yaml now in module rather than in invoke script - model marked "default" is now loaded if model not specified on command line - uncache changed models when edited, so that they reload properly - removed liaon from models.yaml and added stable-diffusion-1.5
This commit is contained in:
parent
a705a5a0aa
commit
83e6ab08aa
@ -6,15 +6,15 @@
|
|||||||
# and the width and height of the images it
|
# and the width and height of the images it
|
||||||
# was trained on.
|
# was trained on.
|
||||||
|
|
||||||
laion400m:
|
|
||||||
config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
|
|
||||||
weights: models/ldm/text2img-large/model.ckpt
|
|
||||||
description: Latent Diffusion LAION400M model
|
|
||||||
width: 256
|
|
||||||
height: 256
|
|
||||||
stable-diffusion-1.4:
|
stable-diffusion-1.4:
|
||||||
config: configs/stable-diffusion/v1-inference.yaml
|
config: configs/stable-diffusion/v1-inference.yaml
|
||||||
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
description: Stable Diffusion inference model version 1.4
|
description: Stable Diffusion inference model version 1.4
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
|
stable-diffusion-1.5:
|
||||||
|
config: configs/stable-diffusion/v1-inference.yaml
|
||||||
|
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||||
|
description: Stable Diffusion inference model version 1.5
|
||||||
|
width: 512
|
||||||
|
height: 512
|
||||||
|
@ -35,6 +35,9 @@ from ldm.invoke.devices import choose_torch_device, choose_precision
|
|||||||
from ldm.invoke.conditioning import get_uc_and_c
|
from ldm.invoke.conditioning import get_uc_and_c
|
||||||
from ldm.invoke.model_cache import ModelCache
|
from ldm.invoke.model_cache import ModelCache
|
||||||
|
|
||||||
|
# this is fallback model in case no default is defined
|
||||||
|
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
|
||||||
|
|
||||||
def fix_func(orig):
|
def fix_func(orig):
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
def new_func(*args, **kw):
|
def new_func(*args, **kw):
|
||||||
@ -127,7 +130,7 @@ class Generate:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model = 'stable-diffusion-1.4',
|
model = None,
|
||||||
conf = 'configs/models.yaml',
|
conf = 'configs/models.yaml',
|
||||||
embedding_path = None,
|
embedding_path = None,
|
||||||
sampler_name = 'k_lms',
|
sampler_name = 'k_lms',
|
||||||
@ -143,7 +146,6 @@ class Generate:
|
|||||||
free_gpu_mem=False,
|
free_gpu_mem=False,
|
||||||
):
|
):
|
||||||
mconfig = OmegaConf.load(conf)
|
mconfig = OmegaConf.load(conf)
|
||||||
self.model_name = model
|
|
||||||
self.height = None
|
self.height = None
|
||||||
self.width = None
|
self.width = None
|
||||||
self.model_cache = None
|
self.model_cache = None
|
||||||
@ -188,6 +190,8 @@ class Generate:
|
|||||||
|
|
||||||
# model caching system for fast switching
|
# model caching system for fast switching
|
||||||
self.model_cache = ModelCache(mconfig,self.device,self.precision)
|
self.model_cache = ModelCache(mconfig,self.device,self.precision)
|
||||||
|
print(f'DEBUG: model={model}, default_model={self.model_cache.default_model()}')
|
||||||
|
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
|
||||||
|
|
||||||
# for VRAM usage statistics
|
# for VRAM usage statistics
|
||||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||||
|
@ -364,17 +364,16 @@ class Args(object):
|
|||||||
deprecated_group.add_argument('--laion400m')
|
deprecated_group.add_argument('--laion400m')
|
||||||
deprecated_group.add_argument('--weights') # deprecated
|
deprecated_group.add_argument('--weights') # deprecated
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--conf',
|
'--config',
|
||||||
'-c',
|
'-c',
|
||||||
'-conf',
|
'-config',
|
||||||
dest='conf',
|
dest='conf',
|
||||||
default='./configs/models.yaml',
|
default='./configs/models.yaml',
|
||||||
help='Path to configuration file for alternate models.',
|
help='Path to configuration file for alternate models.',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--model',
|
'--model',
|
||||||
default='stable-diffusion-1.4',
|
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||||
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
|
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--sampler',
|
'--sampler',
|
||||||
|
@ -85,6 +85,26 @@ class ModelCache(object):
|
|||||||
'hash': hash
|
'hash': hash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def default_model(self) -> str:
|
||||||
|
'''
|
||||||
|
Returns the name of the default model, or None
|
||||||
|
if none is defined.
|
||||||
|
'''
|
||||||
|
for model_name in self.config:
|
||||||
|
if self.config[model_name].get('default',False):
|
||||||
|
return model_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_default_model(self,model_name:str):
|
||||||
|
'''
|
||||||
|
Set the default model. The change will not take
|
||||||
|
effect until you call model_cache.commit()
|
||||||
|
'''
|
||||||
|
assert model_name in self.models,f"unknown model '{model_name}'"
|
||||||
|
for model in self.models:
|
||||||
|
self.models[model].pop('default',None)
|
||||||
|
self.models[model_name]['default'] = True
|
||||||
|
|
||||||
def list_models(self) -> dict:
|
def list_models(self) -> dict:
|
||||||
'''
|
'''
|
||||||
Return a dict of models in the format:
|
Return a dict of models in the format:
|
||||||
@ -122,22 +142,23 @@ class ModelCache(object):
|
|||||||
else:
|
else:
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
def del_model(self, model_name:str) ->str:
|
def del_model(self, model_name:str) ->bool:
|
||||||
'''
|
'''
|
||||||
Delete the named model and return the YAML
|
Delete the named model.
|
||||||
'''
|
'''
|
||||||
omega = self.config
|
omega = self.config
|
||||||
del omega[model_name]
|
del omega[model_name]
|
||||||
if model_name in self.stack:
|
if model_name in self.stack:
|
||||||
self.stack.remove(model_name)
|
self.stack.remove(model_name)
|
||||||
return OmegaConf.to_yaml(omega)
|
return True
|
||||||
|
|
||||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
|
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
|
||||||
'''
|
'''
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
On a successful update, the config will be changed in memory and a YAML
|
On a successful update, the config will be changed in memory and the
|
||||||
string will be returned.
|
method will return True. Will fail with an assertion error if provided
|
||||||
|
attributes are incorrect or the model name is missing.
|
||||||
'''
|
'''
|
||||||
omega = self.config
|
omega = self.config
|
||||||
# check that all the required fields are present
|
# check that all the required fields are present
|
||||||
@ -150,7 +171,9 @@ class ModelCache(object):
|
|||||||
config[field] = model_attributes[field]
|
config[field] = model_attributes[field]
|
||||||
|
|
||||||
omega[model_name] = config
|
omega[model_name] = config
|
||||||
return OmegaConf.to_yaml(omega)
|
if clobber:
|
||||||
|
self._invalidate_cached_model(model_name)
|
||||||
|
return True
|
||||||
|
|
||||||
def _check_memory(self):
|
def _check_memory(self):
|
||||||
avail_memory = psutil.virtual_memory()[1]
|
avail_memory = psutil.virtual_memory()[1]
|
||||||
@ -230,6 +253,36 @@ class ModelCache(object):
|
|||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def commit(self,config_file_path:str):
|
||||||
|
'''
|
||||||
|
Write current configuration out to the indicated file.
|
||||||
|
'''
|
||||||
|
yaml_str = OmegaConf.to_yaml(self.config)
|
||||||
|
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
|
||||||
|
with open(tmpfile, 'w') as outfile:
|
||||||
|
outfile.write(self.preamble())
|
||||||
|
outfile.write(yaml_str)
|
||||||
|
os.rename(tmpfile,config_file_path)
|
||||||
|
|
||||||
|
def preamble(self):
|
||||||
|
'''
|
||||||
|
Returns the preamble for the config file.
|
||||||
|
'''
|
||||||
|
return '''# This file describes the alternative machine learning models
|
||||||
|
# available to the dream script.
|
||||||
|
#
|
||||||
|
# To add a new model, follow the examples below. Each
|
||||||
|
# model requires a model config file, a weights file,
|
||||||
|
# and the width and height of the images it
|
||||||
|
# was trained on.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def _invalidate_cached_model(self,model_name:str):
|
||||||
|
self.unload_model(model_name)
|
||||||
|
if model_name in self.stack:
|
||||||
|
self.stack.remove(model_name)
|
||||||
|
self.models.pop(model_name,None)
|
||||||
|
|
||||||
def _model_to_cpu(self,model):
|
def _model_to_cpu(self,model):
|
||||||
if self.device != 'cpu':
|
if self.device != 'cpu':
|
||||||
model.cond_stage_model.device = 'cpu'
|
model.cond_stage_model.device = 'cpu'
|
||||||
|
@ -341,6 +341,7 @@ def main_loop(gen, opt, infile):
|
|||||||
|
|
||||||
print('goodbye!')
|
print('goodbye!')
|
||||||
|
|
||||||
|
# to do: this is ugly, fix
|
||||||
def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
||||||
operation = 'generate' # default operation, alternative is 'postprocess'
|
operation = 'generate' # default operation, alternative is 'postprocess'
|
||||||
|
|
||||||
@ -455,7 +456,10 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
|||||||
done = True
|
done = True
|
||||||
except:
|
except:
|
||||||
print('** Please enter a valid integer between 64 and 2048')
|
print('** Please enter a valid integer between 64 and 2048')
|
||||||
if write_config_file(opt.conf, gen, model_name, new_config):
|
|
||||||
|
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||||
|
|
||||||
|
if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
|
||||||
completer.add_model(model_name)
|
completer.add_model(model_name)
|
||||||
|
|
||||||
def del_config(model_name:str, gen, opt, completer):
|
def del_config(model_name:str, gen, opt, completer):
|
||||||
@ -488,14 +492,17 @@ def edit_config(model_name:str, gen, opt, completer):
|
|||||||
completer.linebuffer = str(conf[field]) if field in conf else ''
|
completer.linebuffer = str(conf[field]) if field in conf else ''
|
||||||
new_value = input(f'{field}: ')
|
new_value = input(f'{field}: ')
|
||||||
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
||||||
|
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||||
completer.complete_extensions(None)
|
completer.complete_extensions(None)
|
||||||
write_config_file(opt.conf, gen, model_name, new_config, clobber=True)
|
write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
|
||||||
|
|
||||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
|
||||||
current_model = gen.model_name
|
current_model = gen.model_name
|
||||||
|
|
||||||
op = 'modify' if clobber else 'import'
|
op = 'modify' if clobber else 'import'
|
||||||
print('\n>> New configuration:')
|
print('\n>> New configuration:')
|
||||||
|
if make_default:
|
||||||
|
new_config['default'] = True
|
||||||
print(yaml.dump({model_name:new_config}))
|
print(yaml.dump({model_name:new_config}))
|
||||||
if input(f'OK to {op} [n]? ') not in ('y','Y'):
|
if input(f'OK to {op} [n]? ') not in ('y','Y'):
|
||||||
return False
|
return False
|
||||||
@ -509,10 +516,11 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
|||||||
gen.model_cache.del_model(model_name)
|
gen.model_cache.del_model(model_name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
|
if make_default:
|
||||||
with open(tmpfile, 'w') as outfile:
|
print('making this default')
|
||||||
outfile.write(yaml_str)
|
gen.model_cache.set_default_model(model_name)
|
||||||
os.rename(tmpfile,conf_path)
|
|
||||||
|
gen.model_cache.commit(conf_path)
|
||||||
|
|
||||||
do_switch = input(f'Keep model loaded? [y]')
|
do_switch = input(f'Keep model loaded? [y]')
|
||||||
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user