Merge branch 'model-switching' into development

commit b2bf2b08ff
@@ -6,15 +6,15 @@
 # and the width and height of the images it
 # was trained on.
 
 laion400m:
     config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
     weights: models/ldm/text2img-large/model.ckpt
     description: Latent Diffusion LAION400M model
     width: 256
     height: 256
-stable-diffusion-1.4:
-    config: configs/stable-diffusion/v1-inference.yaml
-    weights: models/ldm/stable-diffusion-v1/model.ckpt
-    description: Stable Diffusion inference model version 1.4
-    width: 512
-    height: 512
+stable-diffusion-1.5:
+    config: configs/stable-diffusion/v1-inference.yaml
+    weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
+    description: Stable Diffusion inference model version 1.5
+    width: 512
+    height: 512
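Each stanza in configs/models.yaml names a model and records its config file, weights file, description, and native resolution. A minimal sketch of reading a stanza back with OmegaConf, the same loader the Generate constructor uses below (the path and stanza name mirror the hunk above; this sketch is not part of the commit):

    from omegaconf import OmegaConf

    config = OmegaConf.load('configs/models.yaml')
    stanza = config['stable-diffusion-1.5']   # hyphenated keys need item access
    print(stanza.weights)                     # models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
    print(stanza.width, stanza.height)        # 512 512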
@@ -55,6 +55,9 @@ torch.randint_like = fix_func(torch.randint_like)
 torch.bernoulli = fix_func(torch.bernoulli)
 torch.multinomial = fix_func(torch.multinomial)
 
+# this is fallback model in case no default is defined
+FALLBACK_MODEL_NAME='stable-diffusion-1.4'
+
 def fix_func(orig):
     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
         def new_func(*args, **kw):
@@ -147,7 +150,7 @@ class Generate:
 
     def __init__(
             self,
-            model = 'stable-diffusion-1.4',
+            model = None,
             conf = 'configs/models.yaml',
             embedding_path = None,
             sampler_name = 'k_lms',
@@ -163,7 +166,6 @@ class Generate:
             free_gpu_mem=False,
     ):
         mconfig = OmegaConf.load(conf)
-        self.model_name = model
         self.height = None
         self.width = None
         self.model_cache = None
@@ -210,6 +212,8 @@ class Generate:
 
         # model caching system for fast switching
         self.model_cache = ModelCache(mconfig,self.device,self.precision)
+        print(f'DEBUG: model={model}, default_model={self.model_cache.default_model()}')
+        self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
 
         # for VRAM usage statistics
         self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
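The two added lines establish the selection order for the starting model: an explicit model argument wins, otherwise the default: true stanza from models.yaml, otherwise FALLBACK_MODEL_NAME from the earlier hunk. A standalone sketch of the same or-chain (illustrative, not part of the commit):

    FALLBACK_MODEL_NAME = 'stable-diffusion-1.4'

    def resolve_model_name(model, config_default):
        # `or` skips falsy values, so None falls through to the next choice
        return model or config_default or FALLBACK_MODEL_NAME

    assert resolve_model_name('laion400m', 'stable-diffusion-1.5') == 'laion400m'
    assert resolve_model_name(None, 'stable-diffusion-1.5') == 'stable-diffusion-1.5'
    assert resolve_model_name(None, None) == 'stable-diffusion-1.4'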
@@ -715,8 +719,7 @@ class Generate:
 
         model_data = self.model_cache.get_model(model_name)
         if model_data is None or len(model_data) == 0:
             print(f'** Model switch failed **')
-            return self.model
-
+            return None
         self.model = model_data['model']
         self.width = model_data['width']
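With this change, set_model signals failure by returning None instead of handing back the old model object, which is what lets write_config_file (further down) verify that a freshly imported checkpoint actually loads. A runnable sketch of the contract, with a stub standing in for Generate.set_model (names illustrative):

    def set_model_stub(name, known=('stable-diffusion-1.4',)):
        # stand-in for Generate.set_model: model object on success, None on failure
        return object() if name in known else None

    model = set_model_stub('stable-diffusion-1.5')
    if model is None:
        print('switch failed; the previous model is still loaded')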
@@ -366,17 +366,16 @@ class Args(object):
         deprecated_group.add_argument('--laion400m')
         deprecated_group.add_argument('--weights') # deprecated
         model_group.add_argument(
             '--conf',
             '--config',
             '-c',
             '-conf',
             '-config',
             dest='conf',
             default='./configs/models.yaml',
             help='Path to configuration file for alternate models.',
         )
         model_group.add_argument(
             '--model',
-            default='stable-diffusion-1.4',
-            help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
+            help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
         )
         model_group.add_argument(
             '--png_compression','-z',
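Dropping the hard-coded argparse default is what allows the "default" stanza to take effect: if --model always defaulted to a name, the constructor's or-chain above would never reach the config. The --conf option also registers several aliases that all land in a single dest; a small standalone argparse demo of that pattern (separate from the commit):

    import argparse

    parser = argparse.ArgumentParser()
    # argparse accepts any number of option strings for one logical argument
    parser.add_argument('--conf', '--config', '-c', dest='conf',
                        default='./configs/models.yaml')
    print(parser.parse_args(['--config', 'alt.yaml']).conf)   # alt.yaml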
@@ -529,7 +528,7 @@ class Args(object):
             formatter_class=ArgFormatter,
             description=
             """
-            *Image generation:*
+            *Image generation*
                  invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
 
             *postprocessing*
@@ -544,6 +543,13 @@ class Args(object):
             !history lists all the commands issued during the current session.
 
             !NN retrieves the NNth command from the history
+
+            *Model manipulation*
+            !models -- list models in configs/models.yaml
+            !switch <model_name> -- switch to model named <model_name>
+            !import_model path/to/weights/file.ckpt -- adds a model to your config
+            !edit_model <model_name> -- edit a model's description
+            !del_model <model_name> -- delete a model
             """
         )
         render_group = parser.add_argument_group('General rendering')
@@ -73,7 +73,8 @@ class ModelCache(object):
         except Exception as e:
             print(f'** model {model_name} could not be loaded: {str(e)}')
             print(f'** restoring {self.current_model}')
-            return self.get_model(self.current_model)
+            self.get_model(self.current_model)
+            return None
 
         self.current_model = model_name
         self._push_newest_model(model_name)
@@ -84,6 +85,26 @@ class ModelCache(object):
             'hash': hash
         }
 
+    def default_model(self) -> str:
+        '''
+        Returns the name of the default model, or None
+        if none is defined.
+        '''
+        for model_name in self.config:
+            if self.config[model_name].get('default',False):
+                return model_name
+        return None
+
+    def set_default_model(self,model_name:str):
+        '''
+        Set the default model. The change will not take
+        effect until you call model_cache.commit()
+        '''
+        assert model_name in self.models,f"unknown model '{model_name}'"
+        for model in self.models:
+            self.models[model].pop('default',None)
+        self.models[model_name]['default'] = True
+
     def list_models(self) -> dict:
         '''
         Return a dict of models in the format:
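set_default_model maintains the invariant that at most one stanza carries default: true: it strips the flag everywhere, then sets it on the chosen model. The same logic on plain dicts (a sketch; the real code operates on OmegaConf nodes):

    models = {
        'stable-diffusion-1.4': {'default': True},
        'stable-diffusion-1.5': {},
    }
    for name in models:                 # clear any existing default
        models[name].pop('default', None)
    models['stable-diffusion-1.5']['default'] = True

    defaults = [n for n, m in models.items() if m.get('default', False)]
    assert defaults == ['stable-diffusion-1.5']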
@@ -121,12 +142,23 @@ class ModelCache(object):
         else:
             print(line)
 
-    def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
+    def del_model(self, model_name:str) ->bool:
+        '''
+        Delete the named model.
+        '''
+        omega = self.config
+        del omega[model_name]
+        if model_name in self.stack:
+            self.stack.remove(model_name)
+        return True
+
+    def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
         '''
         Update the named model with a dictionary of attributes. Will fail with an
         assertion error if the name already exists. Pass clobber=True to overwrite.
-        On a successful update, the config will be changed in memory and a YAML
-        string will be returned.
+        On a successful update, the config will be changed in memory and the
+        method will return True. Will fail with an assertion error if provided
+        attributes are incorrect or the model name is missing.
         '''
         omega = self.config
         # check that all the required fields are present
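Both methods edit only the in-memory OmegaConf tree; nothing touches disk until commit() (added further down) writes the file. A hypothetical round trip, with an illustrative model name and paths (`cache` is a ModelCache instance):

    # cache = ModelCache(mconfig, device, precision)   # as in Generate.__init__
    cache.add_model('my-finetune', {
        'config':      'configs/stable-diffusion/v1-inference.yaml',
        'weights':     'models/ldm/stable-diffusion-v1/my-finetune.ckpt',
        'description': 'personal fine-tune',
        'width':  512,
        'height': 512,
    })
    cache.del_model('my-finetune')   # both return True; changes stay memory-only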
@@ -139,7 +171,9 @@ class ModelCache(object):
             config[field] = model_attributes[field]
 
         omega[model_name] = config
-        return OmegaConf.to_yaml(omega)
+        if clobber:
+            self._invalidate_cached_model(model_name)
+        return True
 
     def _check_memory(self):
         avail_memory = psutil.virtual_memory()[1]
@@ -219,6 +253,36 @@ class ModelCache(object):
         if self._has_cuda():
             torch.cuda.empty_cache()
 
+    def commit(self,config_file_path:str):
+        '''
+        Write current configuration out to the indicated file.
+        '''
+        yaml_str = OmegaConf.to_yaml(self.config)
+        tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
+        with open(tmpfile, 'w') as outfile:
+            outfile.write(self.preamble())
+            outfile.write(yaml_str)
+        os.rename(tmpfile,config_file_path)
+
+    def preamble(self):
+        '''
+        Returns the preamble for the config file.
+        '''
+        return '''# This file describes the alternative machine learning models
+# available to the dream script.
+#
+# To add a new model, follow the examples below. Each
+# model requires a model config file, a weights file,
+# and the width and height of the images it
+# was trained on.
+'''
+
+    def _invalidate_cached_model(self,model_name:str):
+        self.unload_model(model_name)
+        if model_name in self.stack:
+            self.stack.remove(model_name)
+        self.models.pop(model_name,None)
+
     def _model_to_cpu(self,model):
         if self.device != 'cpu':
             model.cond_stage_model.device = 'cpu'
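commit() writes the preamble plus the YAML dump to a temp file in the same directory and then renames it over the target, so readers never observe a half-written config. A minimal, self-contained sketch of that write-temp-then-rename idiom (function name is illustrative):

    import os

    def atomic_write(path: str, text: str) -> None:
        tmp = os.path.join(os.path.dirname(path) or '.', 'new_config.tmp')
        with open(tmp, 'w') as f:
            f.write(text)
        # os.rename is atomic on POSIX when tmp and path share a filesystem
        os.rename(tmp, path)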
@@ -57,12 +57,13 @@ COMMANDS = (
     '--png_compression','-z',
     '--text_mask','-tm',
     '!fix','!fetch','!replay','!history','!search','!clear',
+    '!models','!switch','!import_model','!edit_model','!del_model',
     '!mask',
-    '!models','!switch','!import_model','!edit_model'
 )
 MODEL_COMMANDS = (
     '!switch',
     '!edit_model',
+    '!del_model',
 )
 WEIGHT_COMMANDS = (
     '!import_model',
@@ -218,9 +219,24 @@ class Completer(object):
         pydoc.pager('\n'.join(lines))
 
     def set_line(self,line)->None:
+        '''
+        Set the default string displayed in the next line of input.
+        '''
         self.linebuffer = line
         readline.redisplay()
 
+    def add_model(self,model_name:str)->None:
+        '''
+        add a model name to the completion list
+        '''
+        self.models.append(model_name)
+
+    def del_model(self,model_name:str)->None:
+        '''
+        removes a model name from the completion list
+        '''
+        self.models.remove(model_name)
+
     def _seed_completions(self, text, state):
         m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
         if m:
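add_model and del_model keep the completer's model list in step with the config, so tab completion learns a name the moment !import_model succeeds and forgets it after !del_model. The bookkeeping in isolation (a toy stand-in for the real Completer, not part of the commit):

    class ToyCompleter:
        def __init__(self, models):
            self.models = list(models)
        def add_model(self, name): self.models.append(name)
        def del_model(self, name): self.models.remove(name)

    c = ToyCompleter(['stable-diffusion-1.4'])
    c.add_model('stable-diffusion-1.5')
    c.del_model('stable-diffusion-1.4')
    assert c.models == ['stable-diffusion-1.5']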
@@ -424,6 +424,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
         completer.add_history(command)
         operation = None
 
+    elif command.startswith('!del'):
+        path = shlex.split(command)
+        if len(path) < 2:
+            print('** please provide the name of a model')
+        else:
+            del_config(path[1], gen, opt, completer)
+        completer.add_history(command)
+        operation = None
+
     elif command.startswith('!fetch'):
         file_path = command.replace('!fetch','',1).strip()
         retrieve_dream_command(opt,file_path,completer)
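The handler splits the command line with shlex rather than str.split, so a quoted model name containing spaces survives as a single argument. For example:

    import shlex

    print(shlex.split('!del_model "my model"'))   # ['!del_model', 'my model']
    print('!del_model "my model"'.split())        # ['!del_model', '"my', 'model"']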
@@ -498,9 +507,25 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
         except:
             print('** Please enter a valid integer between 64 and 2048')
 
-    if write_config_file(opt.conf, gen, model_name, new_config):
-        gen.set_model(model_name)
+    make_default = input('Make this the default model? [n] ') in ('y','Y')
+
+    if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
         completer.add_model(model_name)
+
+def del_config(model_name:str, gen, opt, completer):
+    current_model = gen.model_name
+    if model_name == current_model:
+        print("** Can't delete active model. !switch to another model first. **")
+        return
+    yaml_str = gen.model_cache.del_model(model_name)
+
+    tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
+    with open(tmpfile, 'w') as outfile:
+        outfile.write(yaml_str)
+    os.rename(tmpfile,opt.conf)
+    print(f'** {model_name} deleted')
+    completer.del_model(model_name)
 
 def edit_config(model_name:str, gen, opt, completer):
     config = gen.model_cache.config
 
@@ -517,28 +542,41 @@ def edit_config(model_name:str, gen, opt, completer):
         completer.linebuffer = str(conf[field]) if field in conf else ''
         new_value = input(f'{field}: ')
         new_config[field] = int(new_value) if field in ('width','height') else new_value
+    make_default = input('Make this the default model? [n] ') in ('y','Y')
     completer.complete_extensions(None)
 
-    if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
-        gen.set_model(model_name)
+    write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
 
-def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
+def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
+    current_model = gen.model_name
+
     op = 'modify' if clobber else 'import'
     print('\n>> New configuration:')
+    if make_default:
+        new_config['default'] = True
     print(yaml.dump({model_name:new_config}))
     if input(f'OK to {op} [n]? ') not in ('y','Y'):
         return False
 
     try:
         print('>> Verifying that new model loads...')
         yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
         assert gen.set_model(model_name) is not None, 'model failed to load'
     except AssertionError as e:
         print(f'** configuration failed: {str(e)}')
         print(f'** aborting **')
+        gen.model_cache.del_model(model_name)
         return False
 
-    tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
-    with open(tmpfile, 'w') as outfile:
-        outfile.write(yaml_str)
-    os.rename(tmpfile,conf_path)
+    if make_default:
+        print('making this default')
+        gen.model_cache.set_default_model(model_name)
+
+    gen.model_cache.commit(conf_path)
+
+    do_switch = input(f'Keep model loaded? [y]')
+    if len(do_switch)==0 or do_switch[0] in ('y','Y'):
+        pass
+    else:
+        gen.set_model(current_model)
     return True
 
 def do_textmask(gen, opt, callback):
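Taken together, write_config_file now stages the new stanza in memory, proves the checkpoint actually loads, rolls back on failure, and only then commits the YAML to disk. A condensed sketch of that verify-then-commit flow (hypothetical helper; `gen` as in the real function):

    def try_import(gen, conf_path, name, new_config, clobber=False):
        gen.model_cache.add_model(name, new_config, clobber)  # stage in memory
        if gen.set_model(name) is None:                       # verify it loads
            gen.model_cache.del_model(name)                   # roll back the stanza
            return False
        gen.model_cache.commit(conf_path)                     # persist to disk
        return True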