mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
enhance console gui for invokeai-merge
- Added modest adaptive behavior; if the screen is wide enough the three checklists of models will be arranged in a horizontal row. - Added color support
This commit is contained in:
parent
0642728484
commit
b9aef33ae8
@ -5,6 +5,7 @@ used to merge 2-3 models together and create a new InvokeAI-registered diffusion
|
|||||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
|
import curses
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
@ -12,6 +13,7 @@ from pathlib import Path
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import npyscreen
|
import npyscreen
|
||||||
|
import warnings
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
@ -26,7 +28,6 @@ from ldm.invoke.model_manager import ModelManager
|
|||||||
|
|
||||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||||
|
|
||||||
|
|
||||||
def merge_diffusion_models(
|
def merge_diffusion_models(
|
||||||
model_ids_or_paths: List[Union[str, Path]],
|
model_ids_or_paths: List[Union[str, Path]],
|
||||||
alpha: float = 0.5,
|
alpha: float = 0.5,
|
||||||
@ -185,6 +186,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
|
|
||||||
def __init__(self, parentApp, name):
|
def __init__(self, parentApp, name):
|
||||||
self.parentApp = parentApp
|
self.parentApp = parentApp
|
||||||
|
self.ALLOW_RESIZE=True
|
||||||
|
self.FIX_MINIMUM_SIZE_WHEN_CREATED=False
|
||||||
super().__init__(parentApp, name)
|
super().__init__(parentApp, name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -195,29 +198,94 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
self.parentApp.setNextForm(None)
|
self.parentApp.setNextForm(None)
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
|
window_height,window_width=curses.initscr().getmaxyx()
|
||||||
|
|
||||||
self.model_names = self.get_model_names()
|
self.model_names = self.get_model_names()
|
||||||
|
max_width = max([len(x) for x in self.model_names])
|
||||||
|
max_width += 6
|
||||||
|
horizontal_layout = max_width*3 < window_width
|
||||||
|
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText, name="Select up to three models to merge", value=""
|
npyscreen.FixedText,
|
||||||
|
color='CONTROL',
|
||||||
|
value=f"Select two models to merge and optionally a third.",
|
||||||
|
editable=False,
|
||||||
)
|
)
|
||||||
self.models = self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.TitleMultiSelect,
|
npyscreen.FixedText,
|
||||||
name="Select two to three models to merge:",
|
color='CONTROL',
|
||||||
|
value=f"Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
||||||
|
editable=False,
|
||||||
|
)
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
value='MODEL 1',
|
||||||
|
color='GOOD',
|
||||||
|
editable=False,
|
||||||
|
rely=4 if horizontal_layout else None,
|
||||||
|
)
|
||||||
|
self.model1 = self.add_widget_intelligent(
|
||||||
|
npyscreen.SelectOne,
|
||||||
values=self.model_names,
|
values=self.model_names,
|
||||||
value=None,
|
value=0,
|
||||||
max_height=len(self.model_names) + 1,
|
max_height=len(self.model_names),
|
||||||
|
max_width=max_width,
|
||||||
|
scroll_exit=True,
|
||||||
|
rely=5,
|
||||||
|
)
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
value='MODEL 2',
|
||||||
|
color='GOOD',
|
||||||
|
editable=False,
|
||||||
|
relx=max_width+3 if horizontal_layout else None,
|
||||||
|
rely=4 if horizontal_layout else None,
|
||||||
|
)
|
||||||
|
self.model2 = self.add_widget_intelligent(
|
||||||
|
npyscreen.SelectOne,
|
||||||
|
name='(2)',
|
||||||
|
values=self.model_names,
|
||||||
|
value=1,
|
||||||
|
max_height=len(self.model_names),
|
||||||
|
max_width=max_width,
|
||||||
|
relx=max_width+3 if horizontal_layout else None,
|
||||||
|
rely=5 if horizontal_layout else None,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.models.when_value_edited = self.models_changed
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
value='MODEL 3',
|
||||||
|
color='GOOD',
|
||||||
|
editable=False,
|
||||||
|
relx=max_width*2+3 if horizontal_layout else None,
|
||||||
|
rely=4 if horizontal_layout else None,
|
||||||
|
)
|
||||||
|
models_plus_none = self.model_names.copy()
|
||||||
|
models_plus_none.insert(0,'None')
|
||||||
|
self.model3 = self.add_widget_intelligent(
|
||||||
|
npyscreen.SelectOne,
|
||||||
|
name='(3)',
|
||||||
|
values=models_plus_none,
|
||||||
|
value=0,
|
||||||
|
max_height=len(self.model_names)+1,
|
||||||
|
max_width=max_width,
|
||||||
|
scroll_exit=True,
|
||||||
|
relx=max_width*2+3 if horizontal_layout else None,
|
||||||
|
rely=5 if horizontal_layout else None,
|
||||||
|
)
|
||||||
|
for m in [self.model1,self.model2,self.model3]:
|
||||||
|
m.when_value_edited = self.models_changed
|
||||||
self.merged_model_name = self.add_widget_intelligent(
|
self.merged_model_name = self.add_widget_intelligent(
|
||||||
npyscreen.TitleText,
|
npyscreen.TitleText,
|
||||||
name="Name for merged model:",
|
name="Name for merged model:",
|
||||||
|
labelColor='CONTROL',
|
||||||
value="",
|
value="",
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.force = self.add_widget_intelligent(
|
self.force = self.add_widget_intelligent(
|
||||||
npyscreen.Checkbox,
|
npyscreen.Checkbox,
|
||||||
name="Force merge of incompatible models",
|
name="Force merge of incompatible models",
|
||||||
|
labelColor='CONTROL',
|
||||||
value=False,
|
value=False,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
@ -226,6 +294,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
name="Merge Method:",
|
name="Merge Method:",
|
||||||
values=self.interpolations,
|
values=self.interpolations,
|
||||||
value=0,
|
value=0,
|
||||||
|
labelColor='CONTROL',
|
||||||
max_height=len(self.interpolations) + 1,
|
max_height=len(self.interpolations) + 1,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
@ -236,47 +305,53 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
step=0.05,
|
step=0.05,
|
||||||
lowest=0,
|
lowest=0,
|
||||||
value=0.5,
|
value=0.5,
|
||||||
|
labelColor='CONTROL',
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.models.editing = True
|
self.model1.editing = True
|
||||||
|
|
||||||
def models_changed(self):
|
def models_changed(self):
|
||||||
model_names = self.models.values
|
models = self.model1.values
|
||||||
selected_models = self.models.value
|
selected_model1 = self.model1.value[0]
|
||||||
if len(selected_models) > 3:
|
selected_model2 = self.model2.value[0]
|
||||||
npyscreen.notify_confirm(
|
selected_model3 = self.model3.value[0]
|
||||||
"Too many models selected for merging. Select two to three."
|
merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}'
|
||||||
)
|
self.merged_model_name.value = merged_model_name
|
||||||
return
|
|
||||||
elif len(selected_models) > 2:
|
if selected_model3 > 0:
|
||||||
self.merge_method.values = ["add_difference"]
|
self.merge_method.values=['add_difference'],
|
||||||
self.merge_method.value = 0
|
self.merged_model_name.value += f'+{models[selected_model3]}'
|
||||||
else:
|
else:
|
||||||
self.merge_method.values = self.interpolations
|
self.merge_method.values=self.interpolations
|
||||||
self.merged_model_name.value = "+".join(
|
self.merge_method.value=0
|
||||||
[model_names[x] for x in selected_models]
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_ok(self):
|
def on_ok(self):
|
||||||
if self.validate_field_values() and self.check_for_overwrite():
|
if self.validate_field_values() and self.check_for_overwrite():
|
||||||
self.parentApp.setNextForm(None)
|
self.parentApp.setNextForm(None)
|
||||||
self.editing = False
|
self.editing = False
|
||||||
self.parentApp.merge_arguments = self.marshall_arguments()
|
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||||
npyscreen.notify("Starting the merge...")
|
npyscreen.notify('Starting the merge...')
|
||||||
else:
|
else:
|
||||||
self.editing = True
|
self.editing = True
|
||||||
|
|
||||||
def on_cancel(self):
|
def on_cancel(self):
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
def marshall_arguments(self) -> dict:
|
def marshall_arguments(self)->dict:
|
||||||
models = [self.models.values[x] for x in self.models.value]
|
model_names = self.model_names
|
||||||
|
models = [
|
||||||
|
model_names[self.model1.value[0]],
|
||||||
|
model_names[self.model2.value[0]],
|
||||||
|
]
|
||||||
|
if self.model3.value[0] > 0:
|
||||||
|
models.append(model_names[self.model3.value[0]-1])
|
||||||
|
|
||||||
args = dict(
|
args = dict(
|
||||||
models=models,
|
models=models,
|
||||||
alpha=self.alpha.value,
|
alpha = self.alpha.value,
|
||||||
interp=self.interpolations[self.merge_method.value[0]],
|
interp = self.interpolations[self.merge_method.value[0]],
|
||||||
force=self.force.value,
|
force = self.force.value,
|
||||||
merged_model_name=self.merged_model_name.value,
|
merged_model_name = self.merged_model_name.value,
|
||||||
)
|
)
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@ -289,15 +364,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
|
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_field_values(self) -> bool:
|
def validate_field_values(self)->bool:
|
||||||
bad_fields = []
|
bad_fields = []
|
||||||
selected_models = self.models.value
|
model_names = self.model_names
|
||||||
if len(selected_models) < 2 or len(selected_models) > 3:
|
selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]]))
|
||||||
bad_fields.append("Please select two or three models to merge.")
|
if self.model3.value[0] > 0:
|
||||||
|
selected_models.add(model_names[self.model3.value[0]-1])
|
||||||
|
if len(selected_models) < 2:
|
||||||
|
bad_fields.append(f'Please select two or three DIFFERENT models to compare. You selected {selected_models}')
|
||||||
if len(bad_fields) > 0:
|
if len(bad_fields) > 0:
|
||||||
message = "The following problems were detected and must be corrected:"
|
message = 'The following problems were detected and must be corrected:'
|
||||||
for problem in bad_fields:
|
for problem in bad_fields:
|
||||||
message += f"\n* {problem}"
|
message += f'\n* {problem}'
|
||||||
npyscreen.notify_confirm(message)
|
npyscreen.notify_confirm(message)
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
@ -322,10 +400,9 @@ class Mergeapp(npyscreen.NPSAppManaged):
|
|||||||
) # precision doesn't really matter here
|
) # precision doesn't really matter here
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||||
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
||||||
|
|
||||||
|
|
||||||
def run_gui(args: Namespace):
|
def run_gui(args: Namespace):
|
||||||
mergeapp = Mergeapp()
|
mergeapp = Mergeapp()
|
||||||
mergeapp.run()
|
mergeapp.run()
|
||||||
@ -338,8 +415,8 @@ def run_gui(args: Namespace):
|
|||||||
def run_cli(args: Namespace):
|
def run_cli(args: Namespace):
|
||||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||||
assert (
|
assert (
|
||||||
len(args.models) >= 1 and len(args.models) <= 3
|
args.models and len(args.models) >= 1 and len(args.models) <= 3
|
||||||
), "provide 2 or 3 models to merge"
|
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
||||||
|
|
||||||
if not args.merged_model_name:
|
if not args.merged_model_name:
|
||||||
args.merged_model_name = "+".join(args.models)
|
args.merged_model_name = "+".join(args.models)
|
||||||
@ -353,6 +430,7 @@ def run_cli(args: Namespace):
|
|||||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||||
|
|
||||||
merge_diffusion_models_and_commit(**vars(args))
|
merge_diffusion_models_and_commit(**vars(args))
|
||||||
|
print(f'>> Models merged into new model: "{args.merged_model_name}".')
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -365,17 +443,22 @@ def main():
|
|||||||
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
|
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
|
||||||
args.cache_dir = cache_dir
|
args.cache_dir = cache_dir
|
||||||
|
|
||||||
try:
|
with warnings.catch_warnings():
|
||||||
if args.front_end:
|
warnings.simplefilter('ignore')
|
||||||
run_gui(args)
|
try:
|
||||||
else:
|
if args.front_end:
|
||||||
run_cli(args)
|
run_gui(args)
|
||||||
print(f">> Conversion successful. New model is named {args.merged_model_name}")
|
else:
|
||||||
except Exception as e:
|
run_cli(args)
|
||||||
print(f"** An error occurred while merging the pipelines: {str(e)}")
|
print(f'>> Conversion successful.')
|
||||||
sys.exit(-1)
|
except Exception as e:
|
||||||
except KeyboardInterrupt:
|
if str(e).startswith('Not enough space'):
|
||||||
sys.exit(-1)
|
print('** Not enough horizontal space! Try making the window wider, or relaunch with a smaller starting size.')
|
||||||
|
else:
|
||||||
|
print(f"** An error occurred while merging the pipelines: {str(e)}")
|
||||||
|
sys.exit(-1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
Loading…
Reference in New Issue
Block a user