enhance console gui for invokeai-merge

- Added modest adaptive behavior; if the screen is wide enough the three
  checklists of models will be arranged in a horizontal row.
- Added color support
This commit is contained in:
Lincoln Stein 2023-02-02 20:26:45 -05:00
parent 0642728484
commit b9aef33ae8

View File

@ -5,6 +5,7 @@ used to merge 2-3 models together and create a new InvokeAI-registered diffusion
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
""" """
import argparse import argparse
import curses
import os import os
import sys import sys
from argparse import Namespace from argparse import Namespace
@ -12,6 +13,7 @@ from pathlib import Path
from typing import List, Union from typing import List, Union
import npyscreen import npyscreen
import warnings
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -26,7 +28,6 @@ from ldm.invoke.model_manager import ModelManager
DEST_MERGED_MODEL_DIR = "merged_models" DEST_MERGED_MODEL_DIR = "merged_models"
def merge_diffusion_models( def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]], model_ids_or_paths: List[Union[str, Path]],
alpha: float = 0.5, alpha: float = 0.5,
@ -185,6 +186,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
def __init__(self, parentApp, name): def __init__(self, parentApp, name):
self.parentApp = parentApp self.parentApp = parentApp
self.ALLOW_RESIZE=True
self.FIX_MINIMUM_SIZE_WHEN_CREATED=False
super().__init__(parentApp, name) super().__init__(parentApp, name)
@property @property
@ -195,29 +198,94 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.parentApp.setNextForm(None) self.parentApp.setNextForm(None)
def create(self): def create(self):
window_height,window_width=curses.initscr().getmaxyx()
self.model_names = self.get_model_names() self.model_names = self.get_model_names()
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width*3 < window_width
self.add_widget_intelligent( self.add_widget_intelligent(
npyscreen.FixedText, name="Select up to three models to merge", value="" npyscreen.FixedText,
color='CONTROL',
value=f"Select two models to merge and optionally a third.",
editable=False,
) )
self.models = self.add_widget_intelligent( self.add_widget_intelligent(
npyscreen.TitleMultiSelect, npyscreen.FixedText,
name="Select two to three models to merge:", color='CONTROL',
value=f"Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 1',
color='GOOD',
editable=False,
rely=4 if horizontal_layout else None,
)
self.model1 = self.add_widget_intelligent(
npyscreen.SelectOne,
values=self.model_names, values=self.model_names,
value=None, value=0,
max_height=len(self.model_names) + 1, max_height=len(self.model_names),
max_width=max_width,
scroll_exit=True,
rely=5,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 2',
color='GOOD',
editable=False,
relx=max_width+3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
)
self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne,
name='(2)',
values=self.model_names,
value=1,
max_height=len(self.model_names),
max_width=max_width,
relx=max_width+3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
scroll_exit=True, scroll_exit=True,
) )
self.models.when_value_edited = self.models_changed self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 3',
color='GOOD',
editable=False,
relx=max_width*2+3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0,'None')
self.model3 = self.add_widget_intelligent(
npyscreen.SelectOne,
name='(3)',
values=models_plus_none,
value=0,
max_height=len(self.model_names)+1,
max_width=max_width,
scroll_exit=True,
relx=max_width*2+3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
)
for m in [self.model1,self.model2,self.model3]:
m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent( self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText, npyscreen.TitleText,
name="Name for merged model:", name="Name for merged model:",
labelColor='CONTROL',
value="", value="",
scroll_exit=True, scroll_exit=True,
) )
self.force = self.add_widget_intelligent( self.force = self.add_widget_intelligent(
npyscreen.Checkbox, npyscreen.Checkbox,
name="Force merge of incompatible models", name="Force merge of incompatible models",
labelColor='CONTROL',
value=False, value=False,
scroll_exit=True, scroll_exit=True,
) )
@ -226,6 +294,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
name="Merge Method:", name="Merge Method:",
values=self.interpolations, values=self.interpolations,
value=0, value=0,
labelColor='CONTROL',
max_height=len(self.interpolations) + 1, max_height=len(self.interpolations) + 1,
scroll_exit=True, scroll_exit=True,
) )
@ -236,47 +305,53 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
step=0.05, step=0.05,
lowest=0, lowest=0,
value=0.5, value=0.5,
labelColor='CONTROL',
scroll_exit=True, scroll_exit=True,
) )
self.models.editing = True self.model1.editing = True
def models_changed(self): def models_changed(self):
model_names = self.models.values models = self.model1.values
selected_models = self.models.value selected_model1 = self.model1.value[0]
if len(selected_models) > 3: selected_model2 = self.model2.value[0]
npyscreen.notify_confirm( selected_model3 = self.model3.value[0]
"Too many models selected for merging. Select two to three." merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}'
) self.merged_model_name.value = merged_model_name
return
elif len(selected_models) > 2: if selected_model3 > 0:
self.merge_method.values = ["add_difference"] self.merge_method.values=['add_difference'],
self.merge_method.value = 0 self.merged_model_name.value += f'+{models[selected_model3]}'
else: else:
self.merge_method.values = self.interpolations self.merge_method.values=self.interpolations
self.merged_model_name.value = "+".join( self.merge_method.value=0
[model_names[x] for x in selected_models]
)
def on_ok(self): def on_ok(self):
if self.validate_field_values() and self.check_for_overwrite(): if self.validate_field_values() and self.check_for_overwrite():
self.parentApp.setNextForm(None) self.parentApp.setNextForm(None)
self.editing = False self.editing = False
self.parentApp.merge_arguments = self.marshall_arguments() self.parentApp.merge_arguments = self.marshall_arguments()
npyscreen.notify("Starting the merge...") npyscreen.notify('Starting the merge...')
else: else:
self.editing = True self.editing = True
def on_cancel(self): def on_cancel(self):
sys.exit(0) sys.exit(0)
def marshall_arguments(self) -> dict: def marshall_arguments(self)->dict:
models = [self.models.values[x] for x in self.models.value] model_names = self.model_names
models = [
model_names[self.model1.value[0]],
model_names[self.model2.value[0]],
]
if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0]-1])
args = dict( args = dict(
models=models, models=models,
alpha=self.alpha.value, alpha = self.alpha.value,
interp=self.interpolations[self.merge_method.value[0]], interp = self.interpolations[self.merge_method.value[0]],
force=self.force.value, force = self.force.value,
merged_model_name=self.merged_model_name.value, merged_model_name = self.merged_model_name.value,
) )
return args return args
@ -289,15 +364,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?" f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
) )
def validate_field_values(self) -> bool: def validate_field_values(self)->bool:
bad_fields = [] bad_fields = []
selected_models = self.models.value model_names = self.model_names
if len(selected_models) < 2 or len(selected_models) > 3: selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]]))
bad_fields.append("Please select two or three models to merge.") if self.model3.value[0] > 0:
selected_models.add(model_names[self.model3.value[0]-1])
if len(selected_models) < 2:
bad_fields.append(f'Please select two or three DIFFERENT models to compare. You selected {selected_models}')
if len(bad_fields) > 0: if len(bad_fields) > 0:
message = "The following problems were detected and must be corrected:" message = 'The following problems were detected and must be corrected:'
for problem in bad_fields: for problem in bad_fields:
message += f"\n* {problem}" message += f'\n* {problem}'
npyscreen.notify_confirm(message) npyscreen.notify_confirm(message)
return False return False
else: else:
@ -322,10 +400,9 @@ class Mergeapp(npyscreen.NPSAppManaged):
) # precision doesn't really matter here ) # precision doesn't really matter here
def onStart(self): def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme) npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
def run_gui(args: Namespace): def run_gui(args: Namespace):
mergeapp = Mergeapp() mergeapp = Mergeapp()
mergeapp.run() mergeapp.run()
@ -338,8 +415,8 @@ def run_gui(args: Namespace):
def run_cli(args: Namespace): def run_cli(args: Namespace):
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1" assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
assert ( assert (
len(args.models) >= 1 and len(args.models) <= 3 args.models and len(args.models) >= 1 and len(args.models) <= 3
), "provide 2 or 3 models to merge" ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
if not args.merged_model_name: if not args.merged_model_name:
args.merged_model_name = "+".join(args.models) args.merged_model_name = "+".join(args.models)
@ -353,6 +430,7 @@ def run_cli(args: Namespace):
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merge_diffusion_models_and_commit(**vars(args)) merge_diffusion_models_and_commit(**vars(args))
print(f'>> Models merged into new model: "{args.merged_model_name}".')
def main(): def main():
@ -365,17 +443,22 @@ def main():
] = cache_dir # because not clear the merge pipeline is honoring cache_dir ] = cache_dir # because not clear the merge pipeline is honoring cache_dir
args.cache_dir = cache_dir args.cache_dir = cache_dir
try: with warnings.catch_warnings():
if args.front_end: warnings.simplefilter('ignore')
run_gui(args) try:
else: if args.front_end:
run_cli(args) run_gui(args)
print(f">> Conversion successful. New model is named {args.merged_model_name}") else:
except Exception as e: run_cli(args)
print(f"** An error occurred while merging the pipelines: {str(e)}") print(f'>> Conversion successful.')
sys.exit(-1) except Exception as e:
except KeyboardInterrupt: if str(e).startswith('Not enough space'):
sys.exit(-1) print('** Not enough horizontal space! Try making the window wider, or relaunch with a smaller starting size.')
else:
print(f"** An error occurred while merging the pipelines: {str(e)}")
sys.exit(-1)
except KeyboardInterrupt:
sys.exit(-1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()