mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
enhance support for model switching and editing
- Error checks for invalid model - Add !del_model command to invoke.py - Add del_model() method to model_cache - Autocompleter kept in sync with model addition/subtraction.
This commit is contained in:
@ -381,6 +381,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!del'):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print('** please provide the name of a model')
|
||||
else:
|
||||
del_config(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!fetch'):
|
||||
file_path = command.replace('!fetch ','',1)
|
||||
retrieve_dream_command(opt,file_path,completer)
|
||||
@ -446,10 +455,23 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
||||
done = True
|
||||
except:
|
||||
print('** Please enter a valid integer between 64 and 2048')
|
||||
|
||||
if write_config_file(opt.conf, gen, model_name, new_config):
|
||||
gen.set_model(model_name)
|
||||
completer.add_model(model_name)
|
||||
|
||||
def del_config(model_name:str, gen, opt, completer):
|
||||
current_model = gen.model_name
|
||||
if model_name == current_model:
|
||||
print("** Can't delete active model. !switch to another model first. **")
|
||||
return
|
||||
yaml_str = gen.model_cache.del_model(model_name)
|
||||
|
||||
tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
|
||||
with open(tmpfile, 'w') as outfile:
|
||||
outfile.write(yaml_str)
|
||||
os.rename(tmpfile,opt.conf)
|
||||
print(f'** {model_name} deleted')
|
||||
completer.del_model(model_name)
|
||||
|
||||
def edit_config(model_name:str, gen, opt, completer):
|
||||
config = gen.model_cache.config
|
||||
|
||||
@ -467,11 +489,11 @@ def edit_config(model_name:str, gen, opt, completer):
|
||||
new_value = input(f'{field}: ')
|
||||
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
||||
completer.complete_extensions(None)
|
||||
|
||||
if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
|
||||
gen.set_model(model_name)
|
||||
write_config_file(opt.conf, gen, model_name, new_config, clobber=True)
|
||||
|
||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
||||
current_model = gen.model_name
|
||||
|
||||
op = 'modify' if clobber else 'import'
|
||||
print('\n>> New configuration:')
|
||||
print(yaml.dump({model_name:new_config}))
|
||||
@ -479,15 +501,24 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
||||
return False
|
||||
|
||||
try:
|
||||
print('>> Verifying that new model loads...')
|
||||
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
|
||||
assert gen.set_model(model_name) is not None, 'model failed to load'
|
||||
except AssertionError as e:
|
||||
print(f'** configuration failed: {str(e)}')
|
||||
print(f'** aborting **')
|
||||
gen.model_cache.del_model(model_name)
|
||||
return False
|
||||
|
||||
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
|
||||
with open(tmpfile, 'w') as outfile:
|
||||
outfile.write(yaml_str)
|
||||
os.rename(tmpfile,conf_path)
|
||||
|
||||
do_switch = input(f'Keep model loaded? [y]')
|
||||
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||
pass
|
||||
else:
|
||||
gen.set_model(current_model)
|
||||
return True
|
||||
|
||||
def do_postprocess (gen, opt, callback):
|
||||
|
Reference in New Issue
Block a user