fix crash in both textual_inversion and merge front ends when not enough models defined (#2540)

- Issue is that if insufficient diffusers models are defined in
models.yaml the frontend would ungraciously crash.

- Now it emits appropriate error messages telling user what the problem
is.
This commit is contained in:
Lincoln Stein 2023-02-05 19:34:07 -05:00 committed by GitHub
commit 6ab03c4d08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 93 additions and 59 deletions

View File

@ -15,20 +15,18 @@ from pathlib import Path
from typing import List, Union
import npyscreen
from diffusers import DiffusionPipeline, logging as dlogging
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from npyscreen import widget
from omegaconf import OmegaConf
from ldm.invoke.globals import (
Globals,
global_cache_dir,
global_config_file,
global_models_dir,
global_set_root,
)
from ldm.invoke.globals import (Globals, global_cache_dir, global_config_file,
global_models_dir, global_set_root)
from ldm.invoke.model_manager import ModelManager
DEST_MERGED_MODEL_DIR = "merged_models"
def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],
alpha: float = 0.5,
@ -48,7 +46,7 @@ def merge_diffusion_models(
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter('ignore')
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
@ -188,13 +186,12 @@ class FloatTitleSlider(npyscreen.TitleText):
class mergeModelsForm(npyscreen.FormMultiPageAction):
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"]
def __init__(self, parentApp, name):
self.parentApp = parentApp
self.ALLOW_RESIZE=True
self.FIX_MINIMUM_SIZE_WHEN_CREATED=False
self.ALLOW_RESIZE = True
self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
super().__init__(parentApp, name)
@property
@ -205,29 +202,29 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.parentApp.setNextForm(None)
def create(self):
window_height,window_width=curses.initscr().getmaxyx()
window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width*3 < window_width
horizontal_layout = max_width * 3 < window_width
self.add_widget_intelligent(
npyscreen.FixedText,
color='CONTROL',
color="CONTROL",
value=f"Select two models to merge and optionally a third.",
editable=False,
)
self.add_widget_intelligent(
npyscreen.FixedText,
color='CONTROL',
color="CONTROL",
value=f"Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 1',
color='GOOD',
value="MODEL 1",
color="GOOD",
editable=False,
rely=4 if horizontal_layout else None,
)
@ -242,57 +239,57 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
)
self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 2',
color='GOOD',
value="MODEL 2",
color="GOOD",
editable=False,
relx=max_width+3 if horizontal_layout else None,
relx=max_width + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
)
self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne,
name='(2)',
name="(2)",
values=self.model_names,
value=1,
max_height=len(self.model_names),
max_width=max_width,
relx=max_width+3 if horizontal_layout else None,
relx=max_width + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
scroll_exit=True,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 3',
color='GOOD',
value="MODEL 3",
color="GOOD",
editable=False,
relx=max_width*2+3 if horizontal_layout else None,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0,'None')
models_plus_none.insert(0, "None")
self.model3 = self.add_widget_intelligent(
npyscreen.SelectOne,
name='(3)',
name="(3)",
values=models_plus_none,
value=0,
max_height=len(self.model_names)+1,
max_height=len(self.model_names) + 1,
max_width=max_width,
scroll_exit=True,
relx=max_width*2+3 if horizontal_layout else None,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
)
for m in [self.model1,self.model2,self.model3]:
for m in [self.model1, self.model2, self.model3]:
m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText,
name="Name for merged model:",
labelColor='CONTROL',
labelColor="CONTROL",
value="",
scroll_exit=True,
)
self.force = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force merge of incompatible models",
labelColor='CONTROL',
labelColor="CONTROL",
value=False,
scroll_exit=True,
)
@ -301,7 +298,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
name="Merge Method:",
values=self.interpolations,
value=0,
labelColor='CONTROL',
labelColor="CONTROL",
max_height=len(self.interpolations) + 1,
scroll_exit=True,
)
@ -312,7 +309,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
step=0.05,
lowest=0,
value=0.5,
labelColor='CONTROL',
labelColor="CONTROL",
scroll_exit=True,
)
self.model1.editing = True
@ -322,43 +319,43 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
selected_model1 = self.model1.value[0]
selected_model2 = self.model2.value[0]
selected_model3 = self.model3.value[0]
merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}'
merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
self.merged_model_name.value = merged_model_name
if selected_model3 > 0:
self.merge_method.values=['add_difference'],
self.merged_model_name.value += f'+{models[selected_model3]}'
self.merge_method.values = (["add_difference"],)
self.merged_model_name.value += f"+{models[selected_model3]}"
else:
self.merge_method.values=self.interpolations
self.merge_method.value=0
self.merge_method.values = self.interpolations
self.merge_method.value = 0
def on_ok(self):
if self.validate_field_values() and self.check_for_overwrite():
self.parentApp.setNextForm(None)
self.editing = False
self.parentApp.merge_arguments = self.marshall_arguments()
npyscreen.notify('Starting the merge...')
npyscreen.notify("Starting the merge...")
else:
self.editing = True
def on_cancel(self):
sys.exit(0)
def marshall_arguments(self)->dict:
def marshall_arguments(self) -> dict:
model_names = self.model_names
models = [
model_names[self.model1.value[0]],
model_names[self.model2.value[0]],
]
if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0]-1])
models.append(model_names[self.model3.value[0] - 1])
args = dict(
models=models,
alpha = self.alpha.value,
interp = self.interpolations[self.merge_method.value[0]],
force = self.force.value,
merged_model_name = self.merged_model_name.value,
alpha=self.alpha.value,
interp=self.interpolations[self.merge_method.value[0]],
force=self.force.value,
merged_model_name=self.merged_model_name.value,
)
return args
@ -371,18 +368,22 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
)
def validate_field_values(self)->bool:
def validate_field_values(self) -> bool:
bad_fields = []
model_names = self.model_names
selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]]))
selected_models = set(
(model_names[self.model1.value[0]], model_names[self.model2.value[0]])
)
if self.model3.value[0] > 0:
selected_models.add(model_names[self.model3.value[0]-1])
selected_models.add(model_names[self.model3.value[0] - 1])
if len(selected_models) < 2:
bad_fields.append(f'Please select two or three DIFFERENT models to compare. You selected {selected_models}')
bad_fields.append(
f"Please select two or three DIFFERENT models to compare. You selected {selected_models}"
)
if len(bad_fields) > 0:
message = 'The following problems were detected and must be corrected:'
message = "The following problems were detected and must be corrected:"
for problem in bad_fields:
message += f'\n* {problem}'
message += f"\n* {problem}"
npyscreen.notify_confirm(message)
return False
else:
@ -410,6 +411,7 @@ class Mergeapp(npyscreen.NPSAppManaged):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
def run_gui(args: Namespace):
mergeapp = Mergeapp()
mergeapp.run()
@ -450,5 +452,27 @@ def main():
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
args.cache_dir = cache_dir
try:
if args.front_end:
run_gui(args)
else:
run_cli(args)
print(f">> Conversion successful. New model is named {args.merged_model_name}")
except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"):
print(
"** You need to have at least two diffusers models defined in models.yaml in order to merge"
)
else:
print(f"** A layout error has occurred: {str(e)}")
sys.exit(-1)
except Exception as e:
print(">> An error occurred:")
traceback.print_exc()
sys.exit(-1)
except KeyboardInterrupt:
sys.exit(-1)
if __name__ == "__main__":
main()

View File

@ -17,6 +17,7 @@ from pathlib import Path
from typing import List, Tuple
import npyscreen
from npyscreen import widget
from omegaconf import OmegaConf
from ldm.invoke.globals import Globals, global_set_root
@ -295,7 +296,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
for idx in range(len(model_names))
if "default" in conf[model_names[idx]]
]
default = defaults[0] if len(defaults)>0 else 0
default = defaults[0] if len(defaults) > 0 else 0
return (model_names, default)
def marshall_arguments(self) -> dict:
@ -438,11 +439,20 @@ def main():
do_front_end(args)
else:
do_textual_inversion_training(**vars(args))
except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"):
print(
"** You need to have at least one diffusers models defined in models.yaml in order to train"
)
else:
print(f"** A layout error has occurred: {str(e)}")
sys.exit(-1)
except AssertionError as e:
print(str(e))
sys.exit(-1)
except KeyboardInterrupt:
pass
if __name__ == "__main__":
main()