enhance support for model switching and editing

- Error checks for invalid model
- Add !del_model command to invoke.py
- Add del_model() method to model_cache
- Autocompleter kept in sync with model addition/subtraction.
This commit is contained in:
Lincoln Stein 2022-10-15 15:46:29 -04:00
parent fe2a2cfc8b
commit a705a5a0aa
5 changed files with 75 additions and 11 deletions

View File

@ -683,8 +683,7 @@ class Generate:
model_data = self.model_cache.get_model(model_name) model_data = self.model_cache.get_model(model_name)
if model_data is None or len(model_data) == 0: if model_data is None or len(model_data) == 0:
print(f'** Model switch failed **') return None
return self.model
self.model = model_data['model'] self.model = model_data['model']
self.width = model_data['width'] self.width = model_data['width']

View File

@ -519,7 +519,7 @@ class Args(object):
formatter_class=ArgFormatter, formatter_class=ArgFormatter,
description= description=
""" """
*Image generation:* *Image generation*
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4 invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
*postprocessing* *postprocessing*
@ -534,6 +534,13 @@ class Args(object):
!history lists all the commands issued during the current session. !history lists all the commands issued during the current session.
!NN retrieves the NNth command from the history !NN retrieves the NNth command from the history
*Model manipulation*
!models -- list models in configs/models.yaml
!switch <model_name> -- switch to model named <model_name>
!import_model path/to/weights/file.ckpt -- adds a model to your config
!edit_model <model_name> -- edit a model's description
!del_model <model_name> -- delete a model
""" """
) )
render_group = parser.add_argument_group('General rendering') render_group = parser.add_argument_group('General rendering')

View File

@ -73,7 +73,8 @@ class ModelCache(object):
except Exception as e: except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}') print(f'** model {model_name} could not be loaded: {str(e)}')
print(f'** restoring {self.current_model}') print(f'** restoring {self.current_model}')
return self.get_model(self.current_model) self.get_model(self.current_model)
return None
self.current_model = model_name self.current_model = model_name
self._push_newest_model(model_name) self._push_newest_model(model_name)
@ -121,6 +122,16 @@ class ModelCache(object):
else: else:
print(line) print(line)
def del_model(self, model_name:str) ->str:
'''
Delete the named model and return the YAML
'''
omega = self.config
del omega[model_name]
if model_name in self.stack:
self.stack.remove(model_name)
return OmegaConf.to_yaml(omega)
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str: def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
''' '''
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an

View File

@ -53,11 +53,12 @@ COMMANDS = (
'--log_tokenization','-t', '--log_tokenization','-t',
'--hires_fix', '--hires_fix',
'!fix','!fetch','!history','!search','!clear', '!fix','!fetch','!history','!search','!clear',
'!models','!switch','!import_model','!edit_model' '!models','!switch','!import_model','!edit_model','!del_model',
) )
MODEL_COMMANDS = ( MODEL_COMMANDS = (
'!switch', '!switch',
'!edit_model', '!edit_model',
'!del_model',
) )
WEIGHT_COMMANDS = ( WEIGHT_COMMANDS = (
'!import_model', '!import_model',
@ -205,9 +206,24 @@ class Completer(object):
pydoc.pager('\n'.join(lines)) pydoc.pager('\n'.join(lines))
def set_line(self,line)->None: def set_line(self,line)->None:
'''
Set the default string displayed in the next line of input.
'''
self.linebuffer = line self.linebuffer = line
readline.redisplay() readline.redisplay()
def add_model(self,model_name:str)->None:
'''
add a model name to the completion list
'''
self.models.append(model_name)
def del_model(self,model_name:str)->None:
'''
removes a model name from the completion list
'''
self.models.remove(model_name)
def _seed_completions(self, text, state): def _seed_completions(self, text, state):
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text) m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
if m: if m:

View File

@ -381,6 +381,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
completer.add_history(command) completer.add_history(command)
operation = None operation = None
elif command.startswith('!del'):
path = shlex.split(command)
if len(path) < 2:
print('** please provide the name of a model')
else:
del_config(path[1], gen, opt, completer)
completer.add_history(command)
operation = None
elif command.startswith('!fetch'): elif command.startswith('!fetch'):
file_path = command.replace('!fetch ','',1) file_path = command.replace('!fetch ','',1)
retrieve_dream_command(opt,file_path,completer) retrieve_dream_command(opt,file_path,completer)
@ -446,10 +455,23 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
done = True done = True
except: except:
print('** Please enter a valid integer between 64 and 2048') print('** Please enter a valid integer between 64 and 2048')
if write_config_file(opt.conf, gen, model_name, new_config): if write_config_file(opt.conf, gen, model_name, new_config):
gen.set_model(model_name) completer.add_model(model_name)
def del_config(model_name:str, gen, opt, completer):
current_model = gen.model_name
if model_name == current_model:
print("** Can't delete active model. !switch to another model first. **")
return
yaml_str = gen.model_cache.del_model(model_name)
tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
with open(tmpfile, 'w') as outfile:
outfile.write(yaml_str)
os.rename(tmpfile,opt.conf)
print(f'** {model_name} deleted')
completer.del_model(model_name)
def edit_config(model_name:str, gen, opt, completer): def edit_config(model_name:str, gen, opt, completer):
config = gen.model_cache.config config = gen.model_cache.config
@ -467,11 +489,11 @@ def edit_config(model_name:str, gen, opt, completer):
new_value = input(f'{field}: ') new_value = input(f'{field}: ')
new_config[field] = int(new_value) if field in ('width','height') else new_value new_config[field] = int(new_value) if field in ('width','height') else new_value
completer.complete_extensions(None) completer.complete_extensions(None)
write_config_file(opt.conf, gen, model_name, new_config, clobber=True)
if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
gen.set_model(model_name)
def write_config_file(conf_path, gen, model_name, new_config, clobber=False): def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
current_model = gen.model_name
op = 'modify' if clobber else 'import' op = 'modify' if clobber else 'import'
print('\n>> New configuration:') print('\n>> New configuration:')
print(yaml.dump({model_name:new_config})) print(yaml.dump({model_name:new_config}))
@ -479,15 +501,24 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
return False return False
try: try:
print('>> Verifying that new model loads...')
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber) yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
assert gen.set_model(model_name) is not None, 'model failed to load'
except AssertionError as e: except AssertionError as e:
print(f'** configuration failed: {str(e)}') print(f'** aborting **')
gen.model_cache.del_model(model_name)
return False return False
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp') tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
with open(tmpfile, 'w') as outfile: with open(tmpfile, 'w') as outfile:
outfile.write(yaml_str) outfile.write(yaml_str)
os.rename(tmpfile,conf_path) os.rename(tmpfile,conf_path)
do_switch = input(f'Keep model loaded? [y]')
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
pass
else:
gen.set_model(current_model)
return True return True
def do_postprocess (gen, opt, callback): def do_postprocess (gen, opt, callback):