mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
enhance support for model switching and editing
- Error checks for invalid model - Add !del_model command to invoke.py - Add del_model() method to model_cache - Autocompleter kept in sync with model addition/subtraction.
This commit is contained in:
parent
fe2a2cfc8b
commit
a705a5a0aa
@ -683,8 +683,7 @@ class Generate:
|
|||||||
|
|
||||||
model_data = self.model_cache.get_model(model_name)
|
model_data = self.model_cache.get_model(model_name)
|
||||||
if model_data is None or len(model_data) == 0:
|
if model_data is None or len(model_data) == 0:
|
||||||
print(f'** Model switch failed **')
|
return None
|
||||||
return self.model
|
|
||||||
|
|
||||||
self.model = model_data['model']
|
self.model = model_data['model']
|
||||||
self.width = model_data['width']
|
self.width = model_data['width']
|
||||||
|
@ -519,7 +519,7 @@ class Args(object):
|
|||||||
formatter_class=ArgFormatter,
|
formatter_class=ArgFormatter,
|
||||||
description=
|
description=
|
||||||
"""
|
"""
|
||||||
*Image generation:*
|
*Image generation*
|
||||||
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
|
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
|
||||||
|
|
||||||
*postprocessing*
|
*postprocessing*
|
||||||
@ -534,6 +534,13 @@ class Args(object):
|
|||||||
!history lists all the commands issued during the current session.
|
!history lists all the commands issued during the current session.
|
||||||
|
|
||||||
!NN retrieves the NNth command from the history
|
!NN retrieves the NNth command from the history
|
||||||
|
|
||||||
|
*Model manipulation*
|
||||||
|
!models -- list models in configs/models.yaml
|
||||||
|
!switch <model_name> -- switch to model named <model_name>
|
||||||
|
!import_model path/to/weights/file.ckpt -- adds a model to your config
|
||||||
|
!edit_model <model_name> -- edit a model's description
|
||||||
|
!del_model <model_name> -- delete a model
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
render_group = parser.add_argument_group('General rendering')
|
render_group = parser.add_argument_group('General rendering')
|
||||||
|
@ -73,7 +73,8 @@ class ModelCache(object):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||||
print(f'** restoring {self.current_model}')
|
print(f'** restoring {self.current_model}')
|
||||||
return self.get_model(self.current_model)
|
self.get_model(self.current_model)
|
||||||
|
return None
|
||||||
|
|
||||||
self.current_model = model_name
|
self.current_model = model_name
|
||||||
self._push_newest_model(model_name)
|
self._push_newest_model(model_name)
|
||||||
@ -121,6 +122,16 @@ class ModelCache(object):
|
|||||||
else:
|
else:
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
|
def del_model(self, model_name:str) ->str:
|
||||||
|
'''
|
||||||
|
Delete the named model and return the YAML
|
||||||
|
'''
|
||||||
|
omega = self.config
|
||||||
|
del omega[model_name]
|
||||||
|
if model_name in self.stack:
|
||||||
|
self.stack.remove(model_name)
|
||||||
|
return OmegaConf.to_yaml(omega)
|
||||||
|
|
||||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
|
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
|
||||||
'''
|
'''
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
|
@ -53,11 +53,12 @@ COMMANDS = (
|
|||||||
'--log_tokenization','-t',
|
'--log_tokenization','-t',
|
||||||
'--hires_fix',
|
'--hires_fix',
|
||||||
'!fix','!fetch','!history','!search','!clear',
|
'!fix','!fetch','!history','!search','!clear',
|
||||||
'!models','!switch','!import_model','!edit_model'
|
'!models','!switch','!import_model','!edit_model','!del_model',
|
||||||
)
|
)
|
||||||
MODEL_COMMANDS = (
|
MODEL_COMMANDS = (
|
||||||
'!switch',
|
'!switch',
|
||||||
'!edit_model',
|
'!edit_model',
|
||||||
|
'!del_model',
|
||||||
)
|
)
|
||||||
WEIGHT_COMMANDS = (
|
WEIGHT_COMMANDS = (
|
||||||
'!import_model',
|
'!import_model',
|
||||||
@ -205,9 +206,24 @@ class Completer(object):
|
|||||||
pydoc.pager('\n'.join(lines))
|
pydoc.pager('\n'.join(lines))
|
||||||
|
|
||||||
def set_line(self,line)->None:
|
def set_line(self,line)->None:
|
||||||
|
'''
|
||||||
|
Set the default string displayed in the next line of input.
|
||||||
|
'''
|
||||||
self.linebuffer = line
|
self.linebuffer = line
|
||||||
readline.redisplay()
|
readline.redisplay()
|
||||||
|
|
||||||
|
def add_model(self,model_name:str)->None:
|
||||||
|
'''
|
||||||
|
add a model name to the completion list
|
||||||
|
'''
|
||||||
|
self.models.append(model_name)
|
||||||
|
|
||||||
|
def del_model(self,model_name:str)->None:
|
||||||
|
'''
|
||||||
|
removes a model name from the completion list
|
||||||
|
'''
|
||||||
|
self.models.remove(model_name)
|
||||||
|
|
||||||
def _seed_completions(self, text, state):
|
def _seed_completions(self, text, state):
|
||||||
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
||||||
if m:
|
if m:
|
||||||
|
@ -381,6 +381,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
|||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
operation = None
|
operation = None
|
||||||
|
|
||||||
|
elif command.startswith('!del'):
|
||||||
|
path = shlex.split(command)
|
||||||
|
if len(path) < 2:
|
||||||
|
print('** please provide the name of a model')
|
||||||
|
else:
|
||||||
|
del_config(path[1], gen, opt, completer)
|
||||||
|
completer.add_history(command)
|
||||||
|
operation = None
|
||||||
|
|
||||||
elif command.startswith('!fetch'):
|
elif command.startswith('!fetch'):
|
||||||
file_path = command.replace('!fetch ','',1)
|
file_path = command.replace('!fetch ','',1)
|
||||||
retrieve_dream_command(opt,file_path,completer)
|
retrieve_dream_command(opt,file_path,completer)
|
||||||
@ -446,10 +455,23 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
|||||||
done = True
|
done = True
|
||||||
except:
|
except:
|
||||||
print('** Please enter a valid integer between 64 and 2048')
|
print('** Please enter a valid integer between 64 and 2048')
|
||||||
|
|
||||||
if write_config_file(opt.conf, gen, model_name, new_config):
|
if write_config_file(opt.conf, gen, model_name, new_config):
|
||||||
gen.set_model(model_name)
|
completer.add_model(model_name)
|
||||||
|
|
||||||
|
def del_config(model_name:str, gen, opt, completer):
|
||||||
|
current_model = gen.model_name
|
||||||
|
if model_name == current_model:
|
||||||
|
print("** Can't delete active model. !switch to another model first. **")
|
||||||
|
return
|
||||||
|
yaml_str = gen.model_cache.del_model(model_name)
|
||||||
|
|
||||||
|
tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
|
||||||
|
with open(tmpfile, 'w') as outfile:
|
||||||
|
outfile.write(yaml_str)
|
||||||
|
os.rename(tmpfile,opt.conf)
|
||||||
|
print(f'** {model_name} deleted')
|
||||||
|
completer.del_model(model_name)
|
||||||
|
|
||||||
def edit_config(model_name:str, gen, opt, completer):
|
def edit_config(model_name:str, gen, opt, completer):
|
||||||
config = gen.model_cache.config
|
config = gen.model_cache.config
|
||||||
|
|
||||||
@ -467,11 +489,11 @@ def edit_config(model_name:str, gen, opt, completer):
|
|||||||
new_value = input(f'{field}: ')
|
new_value = input(f'{field}: ')
|
||||||
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
||||||
completer.complete_extensions(None)
|
completer.complete_extensions(None)
|
||||||
|
write_config_file(opt.conf, gen, model_name, new_config, clobber=True)
|
||||||
if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
|
|
||||||
gen.set_model(model_name)
|
|
||||||
|
|
||||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
||||||
|
current_model = gen.model_name
|
||||||
|
|
||||||
op = 'modify' if clobber else 'import'
|
op = 'modify' if clobber else 'import'
|
||||||
print('\n>> New configuration:')
|
print('\n>> New configuration:')
|
||||||
print(yaml.dump({model_name:new_config}))
|
print(yaml.dump({model_name:new_config}))
|
||||||
@ -479,15 +501,24 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
print('>> Verifying that new model loads...')
|
||||||
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
|
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
|
||||||
|
assert gen.set_model(model_name) is not None, 'model failed to load'
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print(f'** configuration failed: {str(e)}')
|
print(f'** aborting **')
|
||||||
|
gen.model_cache.del_model(model_name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
|
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
|
||||||
with open(tmpfile, 'w') as outfile:
|
with open(tmpfile, 'w') as outfile:
|
||||||
outfile.write(yaml_str)
|
outfile.write(yaml_str)
|
||||||
os.rename(tmpfile,conf_path)
|
os.rename(tmpfile,conf_path)
|
||||||
|
|
||||||
|
do_switch = input(f'Keep model loaded? [y]')
|
||||||
|
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
gen.set_model(current_model)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def do_postprocess (gen, opt, callback):
|
def do_postprocess (gen, opt, callback):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user