mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add ability to import and edit alternative models online
- !import_model <path/to/model/weights> will import a new model, prompt the user for its name and description, write it to the models.yaml file, and load it. - !edit_model <model_name> will bring up a previously-defined model and prompt the user to edit its descriptive fields. Example of !import_model <pre> invoke> <b>!import_model models/ldm/stable-diffusion-v1/model-epoch08-float16.ckpt</b> >> Model import in process. Please enter the values needed to configure this model: Name for this model: <b>waifu-diffusion</b> Description of this model: <b>Waifu Diffusion v1.3</b> Configuration file for this model: <b>configs/stable-diffusion/v1-inference.yaml</b> Default image width: <b>512</b> Default image height: <b>512</b> >> New configuration: waifu-diffusion: config: configs/stable-diffusion/v1-inference.yaml description: Waifu Diffusion v1.3 height: 512 weights: models/ldm/stable-diffusion-v1/model-epoch08-float16.ckpt width: 512 OK to import [n]? <b>y</b> >> Caching model stable-diffusion-1.4 in system RAM >> Loading waifu-diffusion from models/ldm/stable-diffusion-v1/model-epoch08-float16.ckpt | LatentDiffusion: Running in eps-prediction mode | DiffusionWrapper has 859.52 M params. | Making attention of type 'vanilla' with 512 in_channels | Working with z of shape (1, 4, 32, 32) = 4096 dimensions. | Making attention of type 'vanilla' with 512 in_channels | Using faster float16 precision </pre> Example of !edit_model <pre> invoke> <b>!edit_model waifu-diffusion</b> >> Editing model waifu-diffusion from configuration file ./configs/models.yaml description: <b>Waifu diffusion v1.4beta</b> weights: models/ldm/stable-diffusion-v1/<b>model-epoch10-float16.ckpt</b> config: configs/stable-diffusion/v1-inference.yaml width: 512 height: 512 >> New configuration: waifu-diffusion: config: configs/stable-diffusion/v1-inference.yaml description: Waifu diffusion v1.4beta weights: models/ldm/stable-diffusion-v1/model-epoch10-float16.ckpt height: 512 width: 512 OK to import [n]? y >> Caching model stable-diffusion-1.4 in system RAM >> Loading waifu-diffusion from models/ldm/stable-diffusion-v1/model-epoch10-float16.ckpt ... </pre>
This commit is contained in:
@ -9,6 +9,7 @@ import copy
|
||||
import warnings
|
||||
import time
|
||||
import traceback
|
||||
import yaml
|
||||
sys.path.append('.') # corrects a weird problem on Macs
|
||||
from ldm.invoke.readline import get_completer
|
||||
from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
|
||||
@ -108,6 +109,7 @@ def main_loop(gen, opt, infile):
|
||||
# output directory specified at the time of script launch. We do not currently support
|
||||
# changing the history file midstream when the output directory is changed.
|
||||
completer = get_completer(opt, models=list(model_config.keys()))
|
||||
completer.set_default_dir(opt.outdir)
|
||||
output_cntr = completer.get_current_history_length()+1
|
||||
|
||||
# os.pathconf is not available on Windows
|
||||
@ -119,11 +121,9 @@ def main_loop(gen, opt, infile):
|
||||
name_max = 255
|
||||
|
||||
while not done:
|
||||
operation = 'generate' # default operation, alternative is 'postprocess'
|
||||
|
||||
if completer:
|
||||
completer.set_default_dir(opt.outdir)
|
||||
|
||||
operation = 'generate'
|
||||
|
||||
try:
|
||||
command = get_next_command(infile)
|
||||
except EOFError:
|
||||
@ -142,52 +142,10 @@ def main_loop(gen, opt, infile):
|
||||
break
|
||||
|
||||
if command.startswith('!'):
|
||||
subcommand = command[1:]
|
||||
command, operation = do_command(command, gen, opt, completer)
|
||||
|
||||
if subcommand.startswith('dream'): # in case a stored prompt still contains the !dream command
|
||||
command = command.replace('!dream ','',1)
|
||||
|
||||
elif subcommand.startswith('fix'):
|
||||
command = command.replace('!fix ','',1)
|
||||
operation = 'postprocess'
|
||||
|
||||
elif subcommand.startswith('switch'):
|
||||
model_name = command.replace('!switch ','',1)
|
||||
gen.set_model(model_name)
|
||||
completer.add_history(command)
|
||||
continue
|
||||
|
||||
elif subcommand.startswith('models'):
|
||||
model_name = command.replace('!models ','',1)
|
||||
gen.model_cache.print_models()
|
||||
continue
|
||||
|
||||
elif subcommand.startswith('fetch'):
|
||||
file_path = command.replace('!fetch ','',1)
|
||||
retrieve_dream_command(opt,file_path,completer)
|
||||
continue
|
||||
|
||||
elif subcommand.startswith('history'):
|
||||
completer.show_history()
|
||||
continue
|
||||
|
||||
elif subcommand.startswith('search'):
|
||||
search_str = command.replace('!search ','',1)
|
||||
completer.show_history(search_str)
|
||||
continue
|
||||
|
||||
elif subcommand.startswith('clear'):
|
||||
completer.clear_history()
|
||||
continue
|
||||
|
||||
elif re.match('^(\d+)',subcommand):
|
||||
command_no = re.match('^(\d+)',subcommand).groups()[0]
|
||||
command = completer.get_line(int(command_no))
|
||||
completer.set_line(command)
|
||||
continue
|
||||
|
||||
else: # not a recognized subcommand, so give the --help text
|
||||
command = '-h'
|
||||
if operation is None:
|
||||
continue
|
||||
|
||||
if opt.parse_cmd(command) is None:
|
||||
continue
|
||||
@ -381,6 +339,155 @@ def main_loop(gen, opt, infile):
|
||||
|
||||
print('goodbye!')
|
||||
|
||||
def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
||||
operation = 'generate' # default operation, alternative is 'postprocess'
|
||||
|
||||
if command.startswith('!dream'): # in case a stored prompt still contains the !dream command
|
||||
command = command.replace('!dream ','',1)
|
||||
|
||||
elif command.startswith('!fix'):
|
||||
command = command.replace('!fix ','',1)
|
||||
operation = 'postprocess'
|
||||
|
||||
elif command.startswith('!switch'):
|
||||
model_name = command.replace('!switch ','',1)
|
||||
gen.set_model(model_name)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!models'):
|
||||
gen.model_cache.print_models()
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!import'):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print('** please provide a path to a .ckpt or .vae model file')
|
||||
elif not os.path.exists(path[1]):
|
||||
print(f'** {path[1]}: file not found')
|
||||
else:
|
||||
add_weights_to_config(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!edit'):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print('** please provide the name of a model')
|
||||
else:
|
||||
edit_config(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!fetch'):
|
||||
file_path = command.replace('!fetch ','',1)
|
||||
retrieve_dream_command(opt,file_path,completer)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!history'):
|
||||
completer.show_history()
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!search'):
|
||||
search_str = command.replace('!search ','',1)
|
||||
completer.show_history(search_str)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!clear'):
|
||||
completer.clear_history()
|
||||
operation = None
|
||||
|
||||
elif re.match('^!(\d+)',command):
|
||||
command_no = re.match('^!(\d+)',command).groups()[0]
|
||||
command = completer.get_line(int(command_no))
|
||||
completer.set_line(command)
|
||||
operation = None
|
||||
|
||||
else: # not a recognized command, so give the --help text
|
||||
command = '-h'
|
||||
return command, operation
|
||||
|
||||
def add_weights_to_config(model_path:str, gen, opt, completer):
|
||||
print(f'>> Model import in process. Please enter the values needed to configure this model:')
|
||||
print()
|
||||
|
||||
new_config = {}
|
||||
new_config['weights'] = model_path
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
model_name = input('Name for this model: ')
|
||||
if not re.match('^[\w._-]+$',model_name):
|
||||
print('** model name must contain only words, digits and the characters [._-] **')
|
||||
else:
|
||||
done = True
|
||||
new_config['description'] = input('Description of this model: ')
|
||||
|
||||
completer.complete_extensions(('.yaml','.yml'))
|
||||
completer.linebuffer = 'configs/stable-diffusion/v1-inference.yaml'
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
new_config['config'] = input('Configuration file for this model: ')
|
||||
done = os.path.exists(new_config['config'])
|
||||
|
||||
completer.complete_extensions(None)
|
||||
|
||||
for field in ('width','height'):
|
||||
done = False
|
||||
while not done:
|
||||
try:
|
||||
completer.linebuffer = '512'
|
||||
value = int(input(f'Default image {field}: '))
|
||||
assert value >= 64 and value <= 2048
|
||||
new_config[field] = value
|
||||
done = True
|
||||
except:
|
||||
print('** Please enter a valid integer between 64 and 2048')
|
||||
|
||||
if write_config_file(opt.conf, gen, model_name, new_config):
|
||||
gen.set_model(model_name)
|
||||
|
||||
def edit_config(model_name:str, gen, opt, completer):
|
||||
config = gen.model_cache.config
|
||||
|
||||
if model_name not in config:
|
||||
print(f'** Unknown model {model_name}')
|
||||
return
|
||||
|
||||
print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')
|
||||
|
||||
conf = config[model_name]
|
||||
new_config = {}
|
||||
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
|
||||
for field in ('description', 'weights', 'config', 'width','height'):
|
||||
completer.linebuffer = str(conf[field]) if field in conf else ''
|
||||
new_value = input(f'{field}: ')
|
||||
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
||||
completer.complete_extensions(None)
|
||||
|
||||
if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
|
||||
gen.set_model(model_name)
|
||||
|
||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
||||
op = 'modify' if clobber else 'import'
|
||||
print('\n>> New configuration:')
|
||||
print(yaml.dump({model_name:new_config}))
|
||||
if input(f'OK to {op} [n]? ') not in ('y','Y'):
|
||||
return False
|
||||
|
||||
try:
|
||||
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
|
||||
except AssertionError as e:
|
||||
print(f'** configuration failed: {str(e)}')
|
||||
return False
|
||||
|
||||
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
|
||||
with open(tmpfile, 'w') as outfile:
|
||||
outfile.write(yaml_str)
|
||||
os.rename(tmpfile,conf_path)
|
||||
return True
|
||||
|
||||
def do_postprocess (gen, opt, callback):
|
||||
file_path = opt.prompt # treat the prompt as the file pathname
|
||||
if os.path.dirname(file_path) == '': #basename given
|
||||
|
Reference in New Issue
Block a user