""" invokeai.frontend.merge exports a single function call merge_diffusion_models() used to merge 2-3 models together and create a new InvokeAI-registered diffusion model. Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team """ import argparse import curses import re import sys from argparse import Namespace from pathlib import Path from typing import List, Optional, Tuple import npyscreen from npyscreen import widget from invokeai.app.services.config.config_default import get_config from invokeai.app.services.download import DownloadQueueService from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage from invokeai.app.services.model_install import ModelInstallService from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager import ( BaseModelType, ModelFormat, ModelType, ModelVariantType, ) from invokeai.backend.model_manager.merge import ModelMerger from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox config = get_config() logger = InvokeAILogger.get_logger() BASE_TYPES = [ (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"), (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"), (BaseModelType.StableDiffusionXL, "Models Built on SDXL"), ] def _parse_args() -> Namespace: parser = argparse.ArgumentParser(description="InvokeAI model merging") parser.add_argument( "--root_dir", type=Path, default=config.root_path, help="Path to the invokeai runtime directory", ) parser.add_argument( "--front_end", "--gui", dest="front_end", action="store_true", default=False, help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.", ) parser.add_argument( "--models", dest="model_names", type=str, nargs="+", help="Two to three model names to be merged", ) parser.add_argument( "--base_model", type=str, choices=[x[0].value for x in BASE_TYPES], help="The base model shared by the models to be merged", ) parser.add_argument( "--merged_model_name", "--destination", dest="merged_model_name", type=str, help="Name of the output model. If not specified, will be the concatenation of the input model names.", ) parser.add_argument( "--alpha", type=float, default=0.5, help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models", ) parser.add_argument( "--interpolation", dest="interp", type=str, choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"], default="weighted_sum", help='Interpolation method to use. If three models are present, only "add_difference" will work.', ) parser.add_argument( "--force", action="store_true", help="Try to merge models even if they are incompatible with each other", ) parser.add_argument( "--clobber", "--overwrite", dest="clobber", action="store_true", help="Overwrite the merged model if --merged_model_name already exists", ) return parser.parse_args() # ------------------------- GUI HERE ------------------------- class mergeModelsForm(npyscreen.FormMultiPageAction): interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"] def __init__(self, parentApp, name): self.parentApp = parentApp self.ALLOW_RESIZE = True self.FIX_MINIMUM_SIZE_WHEN_CREATED = False super().__init__(parentApp, name) @property def record_store(self): return self.parentApp.record_store def afterEditing(self): self.parentApp.setNextForm(None) def create(self): window_height, window_width = curses.initscr().getmaxyx() self.current_base = 0 self.models = self.get_models(BASE_TYPES[self.current_base][0]) self.model_names = [x[1] for x in self.models] max_width = max([len(x) for x in self.model_names]) max_width += 6 horizontal_layout = max_width * 3 < window_width self.add_widget_intelligent( npyscreen.FixedText, color="CONTROL", value="Select two models to merge and optionally a third.", editable=False, ) self.add_widget_intelligent( npyscreen.FixedText, color="CONTROL", value="Use up and down arrows to move, to select an item, and to move from one field to the next.", editable=False, ) self.nextrely += 1 self.base_select = self.add_widget_intelligent( SingleSelectColumns, values=[x[1] for x in BASE_TYPES], value=[self.current_base], columns=4, max_height=2, relx=8, scroll_exit=True, ) self.base_select.on_changed = self._populate_models self.add_widget_intelligent( npyscreen.FixedText, value="MODEL 1", color="GOOD", editable=False, rely=6 if horizontal_layout else None, ) self.model1 = self.add_widget_intelligent( npyscreen.SelectOne, values=self.model_names, value=0, max_height=len(self.model_names), max_width=max_width, scroll_exit=True, rely=7, ) self.add_widget_intelligent( npyscreen.FixedText, value="MODEL 2", color="GOOD", editable=False, relx=max_width + 3 if horizontal_layout else None, rely=6 if horizontal_layout else None, ) self.model2 = self.add_widget_intelligent( npyscreen.SelectOne, name="(2)", values=self.model_names, value=1, max_height=len(self.model_names), max_width=max_width, relx=max_width + 3 if horizontal_layout else None, rely=7 if horizontal_layout else None, scroll_exit=True, ) self.add_widget_intelligent( npyscreen.FixedText, value="MODEL 3", color="GOOD", editable=False, relx=max_width * 2 + 3 if horizontal_layout else None, rely=6 if horizontal_layout else None, ) models_plus_none = self.model_names.copy() models_plus_none.insert(0, "None") self.model3 = self.add_widget_intelligent( npyscreen.SelectOne, name="(3)", values=models_plus_none, value=0, max_height=len(self.model_names) + 1, max_width=max_width, scroll_exit=True, relx=max_width * 2 + 3 if horizontal_layout else None, rely=7 if horizontal_layout else None, ) for m in [self.model1, self.model2, self.model3]: m.when_value_edited = self.models_changed self.merged_model_name = self.add_widget_intelligent( TextBox, name="Name for merged model:", labelColor="CONTROL", max_height=3, value="", scroll_exit=True, ) self.force = self.add_widget_intelligent( npyscreen.Checkbox, name="Force merge of models created by different diffusers library versions", labelColor="CONTROL", value=True, scroll_exit=True, ) self.nextrely += 1 self.merge_method = self.add_widget_intelligent( npyscreen.TitleSelectOne, name="Merge Method:", values=self.interpolations, value=0, labelColor="CONTROL", max_height=len(self.interpolations) + 1, scroll_exit=True, ) self.alpha = self.add_widget_intelligent( FloatTitleSlider, name="Weight (alpha) to assign to second and third models:", out_of=1.0, step=0.01, lowest=0, value=0.5, labelColor="CONTROL", scroll_exit=True, ) self.model1.editing = True def models_changed(self): models = self.model1.values selected_model1 = self.model1.value[0] selected_model2 = self.model2.value[0] selected_model3 = self.model3.value[0] merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}" self.merged_model_name.value = merged_model_name if selected_model3 > 0: self.merge_method.values = ["add_difference ( A+(B-C) )"] self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one. else: self.merge_method.values = self.interpolations self.merge_method.value = 0 def on_ok(self): if self.validate_field_values() and self.check_for_overwrite(): self.parentApp.setNextForm(None) self.editing = False self.parentApp.merge_arguments = self.marshall_arguments() npyscreen.notify("Starting the merge...") else: self.editing = True def on_cancel(self): sys.exit(0) def marshall_arguments(self) -> dict: model_keys = [x[0] for x in self.models] models = [ model_keys[self.model1.value[0]], model_keys[self.model2.value[0]], ] if self.model3.value[0] > 0: models.append(model_keys[self.model3.value[0] - 1]) interp = "add_difference" else: interp = self.interpolations[self.merge_method.value[0]] args = { "model_keys": models, "base_model": tuple(BaseModelType)[self.base_select.value[0]], "alpha": self.alpha.value, "interp": interp, "force": self.force.value, "merged_model_name": self.merged_model_name.value, } return args def check_for_overwrite(self) -> bool: model_out = self.merged_model_name.value if model_out not in self.model_names: return True else: return npyscreen.notify_yes_no( f"The chosen merged model destination, {model_out}, is already in use. Overwrite?" ) def validate_field_values(self) -> bool: bad_fields = [] model_names = self.model_names selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]} if self.model3.value[0] > 0: selected_models.add(model_names[self.model3.value[0] - 1]) if len(selected_models) < 2: bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}") if len(bad_fields) > 0: message = "The following problems were detected and must be corrected:" for problem in bad_fields: message += f"\n* {problem}" npyscreen.notify_confirm(message) return False else: return True def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name models = [ (x.key, x.name) for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model) if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal") ] return sorted(models, key=lambda x: x[1]) def _populate_models(self, value: List[int]): base_model = BASE_TYPES[value[0]][0] self.models = self.get_models(base_model) self.model_names = [x[1] for x in self.models] models_plus_none = self.model_names.copy() models_plus_none.insert(0, "None") self.model1.values = self.model_names self.model2.values = self.model_names self.model3.values = models_plus_none self.display() class Mergeapp(npyscreen.NPSAppManaged): def __init__(self, record_store: ModelRecordServiceBase): super().__init__() self.record_store = record_store def onStart(self): npyscreen.setTheme(npyscreen.Themes.ElegantTheme) self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") def run_gui(args: Namespace) -> None: record_store: ModelRecordServiceBase = get_config_store() mergeapp = Mergeapp(record_store) mergeapp.run() args = mergeapp.merge_arguments merger = get_model_merger(record_store) merger.merge_diffusion_models_and_save(**args) merged_model_name = args["merged_model_name"] logger.info(f'Models merged into new model: "{merged_model_name}".') def run_cli(args: Namespace): assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1" assert ( args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3 ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage." if not args.merged_model_name: args.merged_model_name = "+".join(args.model_names) logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') record_store: ModelRecordServiceBase = get_config_store() assert ( len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' merger = get_model_merger(record_store) model_keys = [] for name in args.model_names: if len(name) == 32 and re.match(r"^[0-9a-f]$", name): model_keys.append(name) else: models = record_store.search_by_attr( model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model) ) assert len(models) > 0, f"{name}: Unknown model" assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead." model_keys.append(models[0].key) merger.merge_diffusion_models_and_save( alpha=args.alpha, model_keys=model_keys, merged_model_name=args.merged_model_name, interp=args.interp, force=args.force, ) logger.info(f'Models merged into new model: "{args.merged_model_name}".') def get_config_store() -> ModelRecordServiceSQL: output_path = config.outputs_path assert output_path is not None image_files = DiskImageFileStorage(output_path / "images") db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files) return ModelRecordServiceSQL(db) def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger: installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService()) installer.start() return ModelMerger(installer) def main(): args = _parse_args() if args.root_dir: config.parse_args(["--root", str(args.root_dir)]) else: config.parse_args([]) try: if args.front_end: run_gui(args) else: run_cli(args) except widget.NotEnoughSpaceForWidget as e: if str(e).startswith("Height of 1 allocated"): logger.error("You need to have at least two diffusers models in order to merge") else: logger.error("Not enough room for the user interface. Try making this window larger.") sys.exit(-1) except Exception as e: logger.error(e) sys.exit(-1) except KeyboardInterrupt: sys.exit(-1) if __name__ == "__main__": main()