enhance console gui for invokeai-merge

- Added modest adaptive behavior; if the screen is wide enough the three
  checklists of models will be arranged in a horizontal row.
- Added color support
This commit is contained in:
Lincoln Stein 2023-02-02 20:26:45 -05:00
parent 0642728484
commit b9aef33ae8

View File

@ -5,6 +5,7 @@ used to merge 2-3 models together and create a new InvokeAI-registered diffusion
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import os
import sys
from argparse import Namespace
@ -12,6 +13,7 @@ from pathlib import Path
from typing import List, Union
import npyscreen
import warnings
from diffusers import DiffusionPipeline
from omegaconf import OmegaConf
@ -26,7 +28,6 @@ from ldm.invoke.model_manager import ModelManager
DEST_MERGED_MODEL_DIR = "merged_models"
def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],
alpha: float = 0.5,
@ -185,6 +186,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
def __init__(self, parentApp, name):
self.parentApp = parentApp
self.ALLOW_RESIZE=True
self.FIX_MINIMUM_SIZE_WHEN_CREATED=False
super().__init__(parentApp, name)
@property
@ -195,29 +198,94 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.parentApp.setNextForm(None)
def create(self):
window_height,window_width=curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width*3 < window_width
self.add_widget_intelligent(
npyscreen.FixedText, name="Select up to three models to merge", value=""
npyscreen.FixedText,
color='CONTROL',
value=f"Select two models to merge and optionally a third.",
editable=False,
)
self.models = self.add_widget_intelligent(
npyscreen.TitleMultiSelect,
name="Select two to three models to merge:",
self.add_widget_intelligent(
npyscreen.FixedText,
color='CONTROL',
value=f"Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 1',
color='GOOD',
editable=False,
rely=4 if horizontal_layout else None,
)
self.model1 = self.add_widget_intelligent(
npyscreen.SelectOne,
values=self.model_names,
value=None,
max_height=len(self.model_names) + 1,
value=0,
max_height=len(self.model_names),
max_width=max_width,
scroll_exit=True,
rely=5,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 2',
color='GOOD',
editable=False,
relx=max_width+3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
)
self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne,
name='(2)',
values=self.model_names,
value=1,
max_height=len(self.model_names),
max_width=max_width,
relx=max_width+3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
scroll_exit=True,
)
self.models.when_value_edited = self.models_changed
self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 3',
color='GOOD',
editable=False,
relx=max_width*2+3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0,'None')
self.model3 = self.add_widget_intelligent(
npyscreen.SelectOne,
name='(3)',
values=models_plus_none,
value=0,
max_height=len(self.model_names)+1,
max_width=max_width,
scroll_exit=True,
relx=max_width*2+3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
)
for m in [self.model1,self.model2,self.model3]:
m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText,
name="Name for merged model:",
labelColor='CONTROL',
value="",
scroll_exit=True,
)
self.force = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force merge of incompatible models",
labelColor='CONTROL',
value=False,
scroll_exit=True,
)
@ -226,6 +294,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
name="Merge Method:",
values=self.interpolations,
value=0,
labelColor='CONTROL',
max_height=len(self.interpolations) + 1,
scroll_exit=True,
)
@ -236,47 +305,53 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
step=0.05,
lowest=0,
value=0.5,
labelColor='CONTROL',
scroll_exit=True,
)
self.models.editing = True
self.model1.editing = True
def models_changed(self):
model_names = self.models.values
selected_models = self.models.value
if len(selected_models) > 3:
npyscreen.notify_confirm(
"Too many models selected for merging. Select two to three."
)
return
elif len(selected_models) > 2:
self.merge_method.values = ["add_difference"]
self.merge_method.value = 0
models = self.model1.values
selected_model1 = self.model1.value[0]
selected_model2 = self.model2.value[0]
selected_model3 = self.model3.value[0]
merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}'
self.merged_model_name.value = merged_model_name
if selected_model3 > 0:
self.merge_method.values=['add_difference'],
self.merged_model_name.value += f'+{models[selected_model3]}'
else:
self.merge_method.values = self.interpolations
self.merged_model_name.value = "+".join(
[model_names[x] for x in selected_models]
)
self.merge_method.values=self.interpolations
self.merge_method.value=0
def on_ok(self):
if self.validate_field_values() and self.check_for_overwrite():
self.parentApp.setNextForm(None)
self.editing = False
self.parentApp.merge_arguments = self.marshall_arguments()
npyscreen.notify("Starting the merge...")
npyscreen.notify('Starting the merge...')
else:
self.editing = True
def on_cancel(self):
sys.exit(0)
def marshall_arguments(self) -> dict:
models = [self.models.values[x] for x in self.models.value]
def marshall_arguments(self)->dict:
model_names = self.model_names
models = [
model_names[self.model1.value[0]],
model_names[self.model2.value[0]],
]
if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0]-1])
args = dict(
models=models,
alpha=self.alpha.value,
interp=self.interpolations[self.merge_method.value[0]],
force=self.force.value,
merged_model_name=self.merged_model_name.value,
alpha = self.alpha.value,
interp = self.interpolations[self.merge_method.value[0]],
force = self.force.value,
merged_model_name = self.merged_model_name.value,
)
return args
@ -289,15 +364,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
)
def validate_field_values(self) -> bool:
def validate_field_values(self)->bool:
bad_fields = []
selected_models = self.models.value
if len(selected_models) < 2 or len(selected_models) > 3:
bad_fields.append("Please select two or three models to merge.")
model_names = self.model_names
selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]]))
if self.model3.value[0] > 0:
selected_models.add(model_names[self.model3.value[0]-1])
if len(selected_models) < 2:
bad_fields.append(f'Please select two or three DIFFERENT models to compare. You selected {selected_models}')
if len(bad_fields) > 0:
message = "The following problems were detected and must be corrected:"
message = 'The following problems were detected and must be corrected:'
for problem in bad_fields:
message += f"\n* {problem}"
message += f'\n* {problem}'
npyscreen.notify_confirm(message)
return False
else:
@ -322,10 +400,9 @@ class Mergeapp(npyscreen.NPSAppManaged):
) # precision doesn't really matter here
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
def run_gui(args: Namespace):
mergeapp = Mergeapp()
mergeapp.run()
@ -338,8 +415,8 @@ def run_gui(args: Namespace):
def run_cli(args: Namespace):
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
assert (
len(args.models) >= 1 and len(args.models) <= 3
), "provide 2 or 3 models to merge"
args.models and len(args.models) >= 1 and len(args.models) <= 3
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
if not args.merged_model_name:
args.merged_model_name = "+".join(args.models)
@ -353,6 +430,7 @@ def run_cli(args: Namespace):
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merge_diffusion_models_and_commit(**vars(args))
print(f'>> Models merged into new model: "{args.merged_model_name}".')
def main():
@ -365,17 +443,22 @@ def main():
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
args.cache_dir = cache_dir
try:
if args.front_end:
run_gui(args)
else:
run_cli(args)
print(f">> Conversion successful. New model is named {args.merged_model_name}")
except Exception as e:
print(f"** An error occurred while merging the pipelines: {str(e)}")
sys.exit(-1)
except KeyboardInterrupt:
sys.exit(-1)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
try:
if args.front_end:
run_gui(args)
else:
run_cli(args)
print(f'>> Conversion successful.')
except Exception as e:
if str(e).startswith('Not enough space'):
print('** Not enough horizontal space! Try making the window wider, or relaunch with a smaller starting size.')
else:
print(f"** An error occurred while merging the pipelines: {str(e)}")
sys.exit(-1)
except KeyboardInterrupt:
sys.exit(-1)
if __name__ == "__main__":
main()