add ability to import and edit alternative models online

- !import_model <path/to/model/weights> will import a new model,
  prompt the user for its name and description, write it to the
  models.yaml file, and load it.

- !edit_model <model_name> will bring up a previously-defined model
  and prompt the user to edit its descriptive fields.

Example of !import_model

<pre>
invoke> <b>!import_model models/ldm/stable-diffusion-v1/model-epoch08-float16.ckpt</b>
>> Model import in process. Please enter the values needed to configure this model:

Name for this model: <b>waifu-diffusion</b>
Description of this model: <b>Waifu Diffusion v1.3</b>
Configuration file for this model: <b>configs/stable-diffusion/v1-inference.yaml</b>
Default image width: <b>512</b>
Default image height: <b>512</b>
>> New configuration:
waifu-diffusion:
  config: configs/stable-diffusion/v1-inference.yaml
  description: Waifu Diffusion v1.3
  height: 512
  weights: models/ldm/stable-diffusion-v1/model-epoch08-float16.ckpt
  width: 512
OK to import [n]? <b>y</b>
>> Caching model stable-diffusion-1.4 in system RAM
>> Loading waifu-diffusion from models/ldm/stable-diffusion-v1/model-epoch08-float16.ckpt
   | LatentDiffusion: Running in eps-prediction mode
   | DiffusionWrapper has 859.52 M params.
   | Making attention of type 'vanilla' with 512 in_channels
   | Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
   | Making attention of type 'vanilla' with 512 in_channels
   | Using faster float16 precision
</pre>

Example of !edit_model

<pre>
invoke> <b>!edit_model waifu-diffusion</b>
>> Editing model waifu-diffusion from configuration file ./configs/models.yaml
description: <b>Waifu diffusion v1.4beta</b>
weights: models/ldm/stable-diffusion-v1/<b>model-epoch10-float16.ckpt</b>
config: configs/stable-diffusion/v1-inference.yaml
width: 512
height: 512

>> New configuration:
waifu-diffusion:
  config: configs/stable-diffusion/v1-inference.yaml
  description: Waifu diffusion v1.4beta
  weights: models/ldm/stable-diffusion-v1/model-epoch10-float16.ckpt
  height: 512
  width: 512

OK to import [n]? y
>> Caching model stable-diffusion-1.4 in system RAM
>> Loading waifu-diffusion from models/ldm/stable-diffusion-v1/model-epoch10-float16.ckpt
...
</pre>
This commit is contained in:
Lincoln Stein
2022-10-13 23:48:07 -04:00
parent 916f5bfbb2
commit 6afc0f9b38
7 changed files with 380 additions and 75 deletions

View File

@ -9,6 +9,7 @@ import copy
import warnings
import time
import traceback
import yaml
sys.path.append('.') # corrects a weird problem on Macs
from ldm.invoke.readline import get_completer
from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
@ -108,6 +109,7 @@ def main_loop(gen, opt, infile):
# output directory specified at the time of script launch. We do not currently support
# changing the history file midstream when the output directory is changed.
completer = get_completer(opt, models=list(model_config.keys()))
completer.set_default_dir(opt.outdir)
output_cntr = completer.get_current_history_length()+1
# os.pathconf is not available on Windows
@ -119,11 +121,9 @@ def main_loop(gen, opt, infile):
name_max = 255
while not done:
operation = 'generate' # default operation, alternative is 'postprocess'
if completer:
completer.set_default_dir(opt.outdir)
operation = 'generate'
try:
command = get_next_command(infile)
except EOFError:
@ -142,52 +142,10 @@ def main_loop(gen, opt, infile):
break
if command.startswith('!'):
subcommand = command[1:]
command, operation = do_command(command, gen, opt, completer)
if subcommand.startswith('dream'): # in case a stored prompt still contains the !dream command
command = command.replace('!dream ','',1)
elif subcommand.startswith('fix'):
command = command.replace('!fix ','',1)
operation = 'postprocess'
elif subcommand.startswith('switch'):
model_name = command.replace('!switch ','',1)
gen.set_model(model_name)
completer.add_history(command)
continue
elif subcommand.startswith('models'):
model_name = command.replace('!models ','',1)
gen.model_cache.print_models()
continue
elif subcommand.startswith('fetch'):
file_path = command.replace('!fetch ','',1)
retrieve_dream_command(opt,file_path,completer)
continue
elif subcommand.startswith('history'):
completer.show_history()
continue
elif subcommand.startswith('search'):
search_str = command.replace('!search ','',1)
completer.show_history(search_str)
continue
elif subcommand.startswith('clear'):
completer.clear_history()
continue
elif re.match('^(\d+)',subcommand):
command_no = re.match('^(\d+)',subcommand).groups()[0]
command = completer.get_line(int(command_no))
completer.set_line(command)
continue
else: # not a recognized subcommand, so give the --help text
command = '-h'
if operation is None:
continue
if opt.parse_cmd(command) is None:
continue
@ -381,6 +339,155 @@ def main_loop(gen, opt, infile):
print('goodbye!')
def do_command(command:str, gen, opt:Args, completer) -> tuple:
operation = 'generate' # default operation, alternative is 'postprocess'
if command.startswith('!dream'): # in case a stored prompt still contains the !dream command
command = command.replace('!dream ','',1)
elif command.startswith('!fix'):
command = command.replace('!fix ','',1)
operation = 'postprocess'
elif command.startswith('!switch'):
model_name = command.replace('!switch ','',1)
gen.set_model(model_name)
completer.add_history(command)
operation = None
elif command.startswith('!models'):
gen.model_cache.print_models()
operation = None
elif command.startswith('!import'):
path = shlex.split(command)
if len(path) < 2:
print('** please provide a path to a .ckpt or .vae model file')
elif not os.path.exists(path[1]):
print(f'** {path[1]}: file not found')
else:
add_weights_to_config(path[1], gen, opt, completer)
completer.add_history(command)
operation = None
elif command.startswith('!edit'):
path = shlex.split(command)
if len(path) < 2:
print('** please provide the name of a model')
else:
edit_config(path[1], gen, opt, completer)
completer.add_history(command)
operation = None
elif command.startswith('!fetch'):
file_path = command.replace('!fetch ','',1)
retrieve_dream_command(opt,file_path,completer)
operation = None
elif command.startswith('!history'):
completer.show_history()
operation = None
elif command.startswith('!search'):
search_str = command.replace('!search ','',1)
completer.show_history(search_str)
operation = None
elif command.startswith('!clear'):
completer.clear_history()
operation = None
elif re.match('^!(\d+)',command):
command_no = re.match('^!(\d+)',command).groups()[0]
command = completer.get_line(int(command_no))
completer.set_line(command)
operation = None
else: # not a recognized command, so give the --help text
command = '-h'
return command, operation
def add_weights_to_config(model_path:str, gen, opt, completer):
print(f'>> Model import in process. Please enter the values needed to configure this model:')
print()
new_config = {}
new_config['weights'] = model_path
done = False
while not done:
model_name = input('Name for this model: ')
if not re.match('^[\w._-]+$',model_name):
print('** model name must contain only words, digits and the characters [._-] **')
else:
done = True
new_config['description'] = input('Description of this model: ')
completer.complete_extensions(('.yaml','.yml'))
completer.linebuffer = 'configs/stable-diffusion/v1-inference.yaml'
done = False
while not done:
new_config['config'] = input('Configuration file for this model: ')
done = os.path.exists(new_config['config'])
completer.complete_extensions(None)
for field in ('width','height'):
done = False
while not done:
try:
completer.linebuffer = '512'
value = int(input(f'Default image {field}: '))
assert value >= 64 and value <= 2048
new_config[field] = value
done = True
except:
print('** Please enter a valid integer between 64 and 2048')
if write_config_file(opt.conf, gen, model_name, new_config):
gen.set_model(model_name)
def edit_config(model_name:str, gen, opt, completer):
config = gen.model_cache.config
if model_name not in config:
print(f'** Unknown model {model_name}')
return
print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')
conf = config[model_name]
new_config = {}
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
for field in ('description', 'weights', 'config', 'width','height'):
completer.linebuffer = str(conf[field]) if field in conf else ''
new_value = input(f'{field}: ')
new_config[field] = int(new_value) if field in ('width','height') else new_value
completer.complete_extensions(None)
if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
gen.set_model(model_name)
def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
op = 'modify' if clobber else 'import'
print('\n>> New configuration:')
print(yaml.dump({model_name:new_config}))
if input(f'OK to {op} [n]? ') not in ('y','Y'):
return False
try:
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
except AssertionError as e:
print(f'** configuration failed: {str(e)}')
return False
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
with open(tmpfile, 'w') as outfile:
outfile.write(yaml_str)
os.rename(tmpfile,conf_path)
return True
def do_postprocess (gen, opt, callback):
file_path = opt.prompt # treat the prompt as the file pathname
if os.path.dirname(file_path) == '': #basename given