Tidy names and locations of modules

- Rename old "model_management" directory to "model_management_OLD" in order to catch
  dangling references to original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplify the Union returned by the
  model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and
  a finalizer to the queue and installer fixtures that will stop the
  services and release threads.
This commit is contained in:
Lincoln Stein
2024-02-17 11:45:32 -05:00
committed by psychedelicious
parent 996eb96b4e
commit 5d612ec095
89 changed files with 355 additions and 1609 deletions

View File

@ -1,845 +0,0 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
# Before running stable-diffusion on an internet-isolated machine,
# run this script from one with internet connectivity. The
# two machines must share a common .cache directory.
"""
This is the npyscreen frontend to the model installation application.
The work is actually done in backend code in model_install_backend.py.
"""
import argparse
import curses
import logging
import sys
import textwrap
import traceback
from argparse import Namespace
from multiprocessing import Process
from multiprocessing.connection import Connection, Pipe
from pathlib import Path
from shutil import get_terminal_size
from typing import Optional
import npyscreen
import torch
from npyscreen import widget
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType
from invokeai.backend.model_management import ModelManager, ModelType
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.widgets import (
MIN_COLS,
MIN_LINES,
BufferBox,
CenteredTitleText,
CyclingForm,
MultiSelectColumns,
SingleSelectColumns,
TextBox,
WindowTooSmallException,
select_stable_diffusion_config_file,
set_min_terminal_size,
)
# Application configuration object, fetched once at import time and used by the form and the CLI helpers below.
config = InvokeAIAppConfig.get_config()
# Module-level logger obtained through the InvokeAI logging wrapper.
logger = InvokeAILogger.get_logger()
# Translation table mapping every non-printable code point to None, so that
# str.translate() drops control characters.
# Technique from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
NOPRINT_TRANS_TABLE = dict.fromkeys(
    code_point for code_point in range(sys.maxunicode + 1) if not chr(code_point).isprintable()
)

# maximum number of installed models we can display before overflowing vertically
MAX_OTHER_MODELS = 72


def make_printable(s: str) -> str:
    """Return *s* with every non-printable character removed."""
    cleaned = s.translate(NOPRINT_TRANS_TABLE)
    return cleaned
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
# for responsive resizing set to False, but this seems to cause a crash!
FIX_MINIMUM_SIZE_WHEN_CREATED = True
# for persistence
current_tab = 0
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
    """Record the form options, then hand over to the npyscreen form constructor.

    ``multipage`` is stored on the instance; ``create`` reads it when choosing
    between the BACK and CANCEL buttons.
    """
    # no installer child process exists until the user applies changes
    self.subprocess = None
    self.multipage = multipage
    super().__init__(parentApp=parentApp, name=name, *args, **keywords)  # noqa: B026  # TODO: maybe this is bad?
def create(self):
"""Build the installer form: hint line, tab selector, one widget group per tab, log box and action buttons."""
# state read later by while_waiting(); no child process is connected yet
self.keypress_timeout = 10
self.counter = 0
self.subprocess_connection = None
# write a one-line stub models config file if none exists yet
if not config.model_conf_path.exists():
with open(config.model_conf_path, "w") as file:
print("# InvokeAI model configuration file", file=file)
self.installer = ModelInstall(config)
self.all_models = self.installer.all_models()
self.starter_models = self.installer.starter_models()
self.model_labels = self._get_model_labels()
window_width, window_height = get_terminal_size()
# move up one row before placing the navigation hint
self.nextrely -= 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields. Cursor keys navigate, and <space> selects.",
editable=False,
color="CAUTION",
)
self.nextrely += 1
# tab selector; the order of these names matches the widget-group list in _toggle_tables()
self.tabs = self.add_widget_intelligent(
SingleSelectColumns,
values=[
"STARTERS",
"MAINS",
"CONTROLNETS",
"T2I-ADAPTERS",
"IP-ADAPTERS",
"LORAS",
"TI EMBEDDINGS",
],
value=[self.current_tab],
columns=7,
max_height=2,
relx=8,
scroll_exit=True,
)
self.tabs.on_changed = self._toggle_tables
# Every tab's widgets are laid out starting at the same row (top_of_table);
# bottom_of_table tracks the lowest row reached by any of them.
top_of_table = self.nextrely
self.starter_pipelines = self.add_starter_pipelines()
bottom_of_table = self.nextrely
self.nextrely = top_of_table
self.pipeline_models = self.add_pipeline_widgets(
model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models
)
# self.pipeline_models['autoload_pending'] = True
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.controlnet_models = self.add_model_widgets(
model_type=ModelType.ControlNet,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.t2i_models = self.add_model_widgets(
model_type=ModelType.T2IAdapter,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.ipadapter_models = self.add_model_widgets(
model_type=ModelType.IPAdapter,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.lora_models = self.add_model_widgets(
model_type=ModelType.Lora,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.ti_models = self.add_model_widgets(
model_type=ModelType.TextualInversion,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)
# the log box goes one row below the tallest tab
self.nextrely = bottom_of_table + 1
self.monitor = self.add_widget_intelligent(
BufferBox,
name="Log Messages",
editable=False,
max_height=6,
)
self.nextrely += 1
done_label = "APPLY CHANGES"
back_label = "BACK"
cancel_label = "CANCEL"
# the buttons are placed on one row by resetting nextrely to current_position before adding each
current_position = self.nextrely
# NOTE(review): indentation is not visible in this view, so it cannot be determined here which of the
# statements after "else:" belong to that branch -- confirm against the repository before editing.
if self.multipage:
self.back_button = self.add_widget_intelligent(
npyscreen.ButtonPress,
name=back_label,
when_pressed_function=self.on_back,
)
else:
self.nextrely = current_position
self.cancel_button = self.add_widget_intelligent(
npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
)
self.nextrely = current_position
self.ok_button = self.add_widget_intelligent(
npyscreen.ButtonPress,
name=done_label,
relx=(window_width - len(done_label)) // 2,
when_pressed_function=self.on_execute,
)
label = "APPLY CHANGES & EXIT"
self.nextrely = current_position
self.done = self.add_widget_intelligent(
npyscreen.ButtonPress,
name=label,
relx=window_width - len(label) - 15,
when_pressed_function=self.on_done,
)
# This restores the selected page on return from an installation
for _i in range(1, self.current_tab + 1):
self.tabs.h_cursor_line_down(1)
self._toggle_tables([self.current_tab])
############# diffusers tab ##########
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
    """Create the starter-model tab: a heading plus a checklist of starter models.

    Returns a dict with the heading widget (``label1``), the checklist widget
    (``models_selected``) and the model keys in checklist order (``models``).
    As a side effect the sorted keys of already-installed starter models are
    stored on ``self.installed_models``.
    """
    catalog = self.all_models
    starter_keys = self.starter_models
    labels = self.model_labels
    self.installed_models = sorted(key for key in starter_keys if catalog[key].installed)

    heading = self.add_widget_intelligent(
        CenteredTitleText,
        name="Select from a starter set of Stable Diffusion models from HuggingFace.",
        editable=False,
        labelColor="CAUTION",
    )
    self.nextrely -= 1

    # if user has already installed some initial models, then don't patronize them
    # by showing more recommendations
    show_recommended = len(self.installed_models) == 0
    keys = [key for key in catalog.keys() if key in starter_keys]
    preselected = [
        position
        for position, key in enumerate(keys)
        if (show_recommended and catalog[key].recommended) or (key in self.installed_models)
    ]
    checklist = self.add_widget_intelligent(
        MultiSelectColumns,
        columns=1,
        name="Install Starter Models",
        values=[labels[key] for key in keys],
        value=preselected,
        max_height=len(starter_keys) + 1,
        relx=4,
        scroll_exit=True,
    )
    self.nextrely += 1
    return {"label1": heading, "models_selected": checklist, "models": keys}
############# Add a set of model install widgets ########
def add_model_widgets(
    self,
    model_type: ModelType,
    window_width: int = 120,
    install_prompt: str = None,
    exclude: set = None,
) -> dict[str, npyscreen.widget]:
    """Create the widgets of one model tab and return them keyed by role.

    The result always contains ``download_ids`` (a text box for extra sources).
    When at least one model of ``model_type`` is known it also contains
    ``label1`` (heading), ``models_selected`` (checklist) and ``models`` (the
    model keys in checklist order). ``warning_message`` is added when the
    displayed labels had to be cut to ``MAX_OTHER_MODELS`` entries. Keys listed
    in ``exclude`` are left out of the tab.
    """
    excluded = set() if exclude is None else exclude
    catalog = self.all_models
    model_list = [key for key in catalog if catalog[key].model_type == model_type and key not in excluded]
    model_labels = [self.model_labels[key] for key in model_list]
    show_recommended = len(self.installed_models) == 0
    widgets = {}
    truncated = False

    if len(model_list) > 0:
        widest_label = max(len(label) for label in model_labels)
        # 8 characters for "[x] " and padding
        columns = min(len(model_list), window_width // (widest_label + 8)) or 1
        prompt = (
            install_prompt
            or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk."
        )
        widgets["label1"] = self.add_widget_intelligent(
            CenteredTitleText,
            name=prompt,
            editable=False,
            labelColor="CAUTION",
        )

        truncated = len(model_labels) > MAX_OTHER_MODELS
        if truncated:
            model_labels = model_labels[0:MAX_OTHER_MODELS]

        preselected = [
            position
            for position, key in enumerate(model_list)
            if (show_recommended and catalog[key].recommended) or catalog[key].installed
        ]
        widgets["models_selected"] = self.add_widget_intelligent(
            MultiSelectColumns,
            columns=columns,
            name=f"Install {model_type} Models",
            values=model_labels,
            value=preselected,
            max_height=len(model_list) // columns + 1,
            relx=4,
            scroll_exit=True,
        )
        widgets["models"] = model_list

    if truncated:
        widgets["warning_message"] = self.add_widget_intelligent(
            npyscreen.FixedText,
            value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.",
            editable=False,
            color="CAUTION",
        )

    self.nextrely += 1
    widgets["download_ids"] = self.add_widget_intelligent(
        TextBox,
        name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
        max_height=4,
        scroll_exit=True,
        editable=True,
    )
    return widgets
### Tab for arbitrary diffusers widgets ###
def add_pipeline_widgets(
    self,
    model_type: ModelType = ModelType.Main,
    window_width: int = 120,
    **kwargs,
) -> dict[str, npyscreen.widget]:
    """Build a model tab through add_model_widgets() with a prompt about installed models.

    Extra keyword arguments are passed through to add_model_widgets().
    """
    prompt = f"Installed {model_type.value.title()} models. Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import."
    return self.add_model_widgets(
        model_type=model_type,
        window_width=window_width,
        install_prompt=prompt,
        **kwargs,
    )
def resize(self):
    """Let the base form resize, then re-apply the labels of the starter checklist."""
    super().resize()
    starter_list = self.starter_pipelines.get("models_selected")
    if starter_list:
        starter_keys = (key for key in self.all_models.keys() if key in self.starter_models)
        starter_list.values = [self.model_labels[key] for key in starter_keys]
def _toggle_tables(self, value=None):
    """Show the widgets of the selected tab and hide those of every other tab.

    ``value`` is the tab selection: a sequence whose first item is the tab
    index. When it is missing or empty (the declared default is ``None``) the
    last remembered tab, ``self.current_tab``, is used instead of failing on
    ``value[0]``.
    """
    selected_tab = value[0] if value else self.current_tab
    tab_groups = [
        self.starter_pipelines,
        self.pipeline_models,
        self.controlnet_models,
        self.t2i_models,
        self.ipadapter_models,
        self.lora_models,
        self.ti_models,
    ]

    # First hide everything. The groups also hold plain data (for example the
    # "models" key list) that cannot take widget attributes, so failures on
    # such entries are deliberately ignored.
    for group in tab_groups:
        for entry in group.values():
            try:
                entry.hidden = True
                entry.editable = False
            except Exception:
                pass

    # Then reveal the selected tab; static text widgets stay non-editable.
    for entry in tab_groups[selected_tab].values():
        try:
            entry.hidden = False
            if not isinstance(entry, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
                entry.editable = True
        except Exception:
            pass

    self.__class__.current_tab = selected_tab  # for persistence
    self.display()
def _get_model_labels(self) -> dict[str, str]:
    """Return a display label for every known model, keyed like ``self.all_models``.

    A label is the model name left-justified to the width of the longest name,
    followed by the model description. Descriptions longer than the space left
    on the terminal line are cut and end in "...". A missing description
    becomes an empty string.
    """
    window_width, _window_height = get_terminal_size()
    checkbox_width = 4
    spacing_width = 2
    models = self.all_models
    # default=0 keeps an empty model collection from raising ValueError
    label_width = max((len(models[key].name) for key in models), default=0)
    description_width = window_width - label_width - checkbox_width - spacing_width

    result = {}
    for key in models.keys():
        description = models[key].description or ""
        if description and len(description) > description_width:
            description = description[0 : description_width - 3] + "..."
        result[key] = f"%-{label_width}s %s" % (models[key].name, description)
    return result
def _get_columns(self) -> int:
    """Pick a column count from the terminal width, capped by the number of installed starter models."""
    window_width, _window_height = get_terminal_size()
    if window_width > 240:
        cols = 4
    elif window_width > 160:
        cols = 3
    elif window_width > 80:
        cols = 2
    else:
        cols = 1
    return min(cols, len(self.installed_models))
def confirm_deletions(self, selections: InstallSelections) -> bool:
    """Ask the user to confirm the pending removals.

    Returns True straight away when ``selections.remove_models`` is empty;
    otherwise returns the result of the npyscreen OK/Cancel dialog.
    """
    pending = selections.remove_models
    if not len(pending) > 0:
        return True
    names = "\n".join(ModelManager.parse_key(key)[0] for key in pending)
    return npyscreen.notify_ok_cancel(
        f"These unchecked models will be deleted from disk. Continue?\n---------\n{names}"
    )
def on_execute(self):
    """Apply the current selections in a child process while the form stays open.

    The child runs ``process_and_execute`` and reports back over a pipe that
    ``while_waiting`` reads. The application's selections are reset afterwards.
    """
    self.marshall_arguments()
    app = self.parentApp
    if not self.confirm_deletions(app.install_selections):
        return

    self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True)
    self.ok_button.hidden = True
    self.display()

    # TO DO: Spawn a worker thread, not a subprocess
    our_end, child_end = Pipe()
    worker_kwargs = {
        "opt": app.program_opts,
        "selections": app.install_selections,
        "conn_out": child_end,
    }
    worker = Process(target=process_and_execute, kwargs=worker_kwargs)
    worker.start()
    # close this process's handle on the child's end of the pipe
    child_end.close()

    self.subprocess_connection = our_end
    self.subprocess = worker
    app.install_selections = InstallSelections()
def on_back(self):
    """Switch the application to its previous form and stop editing this one."""
    app = self.parentApp
    app.switchFormPrevious()
    self.editing = False
def on_cancel(self):
    """Leave the form without applying anything and flag the cancellation on the application."""
    app = self.parentApp
    app.setNextForm(None)
    app.user_cancelled = True
    self.editing = False
def on_done(self):
    """Collect the selections and close the form, unless the user declines the deletions."""
    self.marshall_arguments()
    app = self.parentApp
    if not self.confirm_deletions(app.install_selections):
        return
    app.setNextForm(None)
    app.user_cancelled = False
    self.editing = False
########## This routine monitors the child process that is performing model installation and removal #####
def while_waiting(self):
    """Relay messages from the installer child process while the form is idle.

    Drains the pipe to the child and handles each message:

    * one starting with "*need v2 config": ask the user for a configuration
      file and send the answer back to the child;
    * "*done*": close the pipe, rebuild the form and stop;
    * anything else: append it to the Log Messages box.
    """
    conn = self.subprocess_connection
    if not conn:
        return

    monitor_widget = self.monitor.entry_widget
    while conn.poll():
        try:
            # Keep the result of strip(): a trailing newline would otherwise
            # defeat the exact "*done*" comparison below.
            data = conn.recv_bytes().decode("utf-8").strip("\n")

            # processing child is requesting user input to select the
            # right configuration file
            if data.startswith("*need v2 config"):
                _, model_path, *_ = data.split(":", 2)
                self._return_v2_config(model_path)

            # processing child is done
            elif data == "*done*":
                self._close_subprocess_and_regenerate_form()
                break

            # update the log message box
            else:
                data = make_printable(data).replace("[A", "")
                monitor_widget.buffer(
                    textwrap.wrap(
                        data,
                        width=monitor_widget.width,
                        subsequent_indent=" ",
                    ),
                    scroll_end=True,
                )
                self.display()
        except (EOFError, OSError):
            # The connection is unusable: forget it and stop polling it,
            # otherwise the loop would keep hitting the same error.
            self.subprocess_connection = None
            break
def _return_v2_config(self, model_path: str):
    """Let the user pick a config file for ``model_path`` and send the reply to the child process."""
    conn = self.subprocess_connection
    reply = select_stable_diffusion_config_file(model_name=Path(model_path).name)
    conn.send_bytes(reply.encode("utf-8"))
def _close_subprocess_and_regenerate_form(self):
    """Close the pipe to the child process and replace this form with a fresh one.

    The contents of the Log Messages box are carried over to the new form.
    """
    self.subprocess_connection.close()
    self.subprocess_connection = None

    log_box = self.monitor.entry_widget
    log_box.buffer(["** Action Complete **"])
    self.display()

    # rebuild the form, saving and restoring some of the fields that need to be preserved.
    saved_messages = log_box.values
    app = self.parentApp
    app.main_form = app.addForm(
        "MAIN",
        addModelsForm,
        name="Install Stable Diffusion Models",
        multipage=self.multipage,
    )
    app.switchForm("MAIN")

    new_log_box = app.main_form.monitor.entry_widget
    new_log_box.values = saved_messages
    app.main_form.monitor.entry_widget.buffer([""], scroll_end=True)
# app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
# app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan
def marshall_arguments(self):
    """Translate the state of every tab into ``self.parentApp.install_selections``.

    * ``remove_models`` receives the keys of installed models that are unchecked.
    * ``install_models`` receives the path (or, failing that, the repo_id) of
      checked models that are not installed yet, followed by every
      whitespace-separated entry typed into the tabs' download boxes.

    Both lists are rebuilt from scratch on each call. Without that, entries
    from an earlier call whose confirmation the user declined would linger, so
    a model that was unchecked and then re-checked would still be removed.
    """
    selections = self.parentApp.install_selections
    all_models = self.all_models

    selections.remove_models.clear()
    selections.install_models.clear()

    # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove
    ui_sections = [
        self.starter_pipelines,
        self.pipeline_models,
        self.controlnet_models,
        self.t2i_models,
        self.ipadapter_models,
        self.lora_models,
        self.ti_models,
    ]
    for section in ui_sections:
        # tabs without any known model have no checklist
        if "models_selected" not in section:
            continue
        selected = {section["models"][x] for x in section["models_selected"].value}
        models_to_install = [x for x in selected if not all_models[x].installed]
        models_to_remove = [x for x in section["models"] if x not in selected and all_models[x].installed]
        selections.remove_models.extend(models_to_remove)
        selections.install_models.extend(
            all_models[x].path or all_models[x].repo_id
            for x in models_to_install
            if all_models[x].path or all_models[x].repo_id
        )

    # models located in the 'download_ids" section
    for section in ui_sections:
        if downloads := section.get("download_ids"):
            selections.install_models.extend(downloads.value.split())

    # NOTE: image encoders for the selected IP adapters used to be added here;
    # per the note left in the original code this is now done in the backend.
class AddModelApplication(npyscreen.NPSAppManaged):
    """npyscreen application that hosts the model installer form."""

    def __init__(self, opt):
        """Keep the parsed program options and start with empty selections."""
        super().__init__()
        self.install_selections = InstallSelections()
        self.user_cancelled = False
        self.program_opts = opt
        # self.autoload_pending = True

    def onStart(self):
        """Apply the default theme and register the installer form as "MAIN"."""
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        main_form = self.addForm(
            "MAIN",
            addModelsForm,
            name="Install Stable Diffusion Models",
            cycle_widgets=False,
        )
        self.main_form = main_form
class StderrToMessage:
    """Minimal file-like object that forwards written text over a connection."""

    def __init__(self, connection: Connection):
        self.connection = connection

    def write(self, data: str):
        """Send ``data`` over the connection as UTF-8 encoded bytes."""
        payload = data.encode("utf-8")
        self.connection.send_bytes(payload)

    def flush(self):
        """Do nothing; every write is sent immediately."""
        return None
# --------------------------------------------------------
def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType:
    """Ask the user for the prediction type of ``model_path``.

    With a ``tui_conn`` the question is relayed over that connection;
    without one the user is prompted on the command line.
    """
    if not tui_conn:
        return _ask_user_for_pt_cmdline(model_path)
    logger.debug("Waiting for user response...")
    return _ask_user_for_pt_tui(model_path, tui_conn)
def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]:
    """Prompt on the terminal for the scheduler prediction type of a checkpoint.

    Returns the chosen prediction type, or None when the user accepts the best
    guess (choice 3 or an empty answer) or input ends (EOF). Only the numbers
    1-3 are accepted. Previously "0" and negative numbers were taken as indexes
    from the end of the choice list and silently accepted.
    """
    choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
    print(
        f"""
Please select the scheduler prediction type of the checkpoint named {model_path.name}:
[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images
[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models
[3] Accept the best guess; you can fix it in the Web UI later
"""
    )
    while True:
        try:
            answer = input("select [3]> ").strip()
        except EOFError:
            return None
        if not answer:
            return None
        try:
            number = int(answer)
        except ValueError:
            number = 0  # not a number: falls outside the valid range below
        if 1 <= number <= len(choices):
            return choices[number - 1]
        print(f"{answer} is not a valid choice")
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> Optional[SchedulerPredictionType]:
    """Ask the TUI process for the prediction type of ``model_path``.

    Sends a "*need v2 config for:<path>" request over ``tui_conn`` and blocks
    until the reply arrives. "epsilon" and "v" map to the matching prediction
    type; "guess" and any other reply yield None.
    """
    tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8"))
    # note that we don't do any status checking here
    response = tui_conn.recv_bytes().decode("utf-8")
    if response == "epsilon":
        # same member name as the command-line prompt uses (was misspelled ".epsilon")
        return SchedulerPredictionType.Epsilon
    if response == "v":
        return SchedulerPredictionType.VPrediction
    return None
# --------------------------------------------------------
def process_and_execute(
    opt: Namespace,
    selections: InstallSelections,
    conn_out: Connection = None,
):
    """Install and remove the models named in ``selections``.

    When ``conn_out`` is given, stdout, stderr and the logger are redirected
    to it, and "*done*" is sent before the connection is closed.
    """
    # need to reinitialize config in subprocess
    config = InvokeAIAppConfig.get_config()
    root_args = []
    if opt.root:
        root_args = ["--root", opt.root]
    config.parse_args(root_args)

    # set up so that stderr is sent to conn_out
    if conn_out:
        translator = StderrToMessage(conn_out)
        sys.stderr = sys.stdout = translator
        logger = InvokeAILogger.get_logger()
        logger.handlers.clear()
        logger.addHandler(logging.StreamHandler(translator))

    def prediction_type_helper(path):
        return ask_user_for_prediction_type(path, conn_out)

    installer = ModelInstall(config, prediction_type_helper=prediction_type_helper)
    installer.install(selections)

    if conn_out:
        conn_out.send_bytes("*done*".encode("utf-8"))
        conn_out.close()
# --------------------------------------------------------
def select_and_download_models(opt: Namespace):
    """Run the installer in the mode selected by the command-line options.

    ``--list-models``, ``--add``/``--delete``, ``--default_only`` and ``--yes``
    are handled directly; without any of them the npyscreen interface is run.

    Raises:
        WindowTooSmallException: the terminal could not be enlarged to the
            minimum size needed by the interface.
    """
    precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
    config.precision = precision
    installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
    if opt.list_models:
        installer.list_models(opt.list_models)
    elif opt.add or opt.delete:
        selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or [])
        installer.install(selections)
    elif opt.default_only:
        selections = InstallSelections(install_models=installer.default_model())
        installer.install(selections)
    elif opt.yes_to_all:
        selections = InstallSelections(install_models=installer.recommended_models())
        installer.install(selections)

    # this is where the TUI is called
    else:
        # needed to support the probe() method running under a subprocess
        torch.multiprocessing.set_start_method("spawn")

        if not set_min_terminal_size(MIN_COLS, MIN_LINES):
            raise WindowTooSmallException(
                "Could not increase terminal size. Try running again with a larger window or smaller font size."
            )

        installApp = AddModelApplication(opt)
        try:
            installApp.run()
        except KeyboardInterrupt:
            # do not leave an installer child process running behind us
            if hasattr(installApp, "main_form"):
                if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive():
                    logger.info("Terminating subprocesses")
                    installApp.main_form.subprocess.terminate()
                    installApp.main_form.subprocess = None
            raise

        # Honour CANCEL: selections gathered before a declined confirmation
        # may still be sitting in install_selections and must not be applied.
        if not installApp.user_cancelled:
            process_and_execute(opt, installApp.install_selections)
# -------------------------------------
def main():
    """Command-line entry point of the model installer.

    Parses the options, hands off to invokeai-configure when the root
    directory has no models config file yet, and otherwise runs
    ``select_and_download_models``, turning the known failure modes into log
    messages.
    """
    parser = argparse.ArgumentParser(description="InvokeAI model downloader")
    parser.add_argument(
        "--add",
        nargs="*",
        help="List of URLs, local paths or repo_ids of models to install",
    )
    parser.add_argument(
        "--delete",
        nargs="*",
        help="List of names of models to delete",
    )
    # BooleanOptionalAction takes no value, so no ``type`` is passed: the
    # argument is deprecated for this action and rejected by newer Pythons.
    parser.add_argument(
        "--full-precision",
        dest="full_precision",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="use 32-bit weights instead of faster 16-bit weights",
    )
    parser.add_argument(
        "--yes",
        "-y",
        dest="yes_to_all",
        action="store_true",
        help='answer "yes" to all prompts',
    )
    parser.add_argument(
        "--default_only",
        action="store_true",
        help="Only install the default model",
    )
    parser.add_argument(
        "--list-models",
        choices=[x.value for x in ModelType],
        help="list installed models",
    )
    parser.add_argument(
        "--config_file",
        "-c",
        dest="config_file",
        type=str,
        default=None,
        help="path to configuration file to create",
    )
    parser.add_argument(
        "--root_dir",
        dest="root",
        type=str,
        default=None,
        help="path to root of install directory",
    )
    opt = parser.parse_args()

    invoke_args = []
    if opt.root:
        invoke_args.extend(["--root", opt.root])
    if opt.full_precision:
        invoke_args.extend(["--precision", "float32"])
    config.parse_args(invoke_args)
    logger = InvokeAILogger().get_logger(config=config)

    if not config.model_conf_path.exists():
        logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
        from invokeai.frontend.install.invokeai_configure import invokeai_configure

        invokeai_configure()
        sys.exit(0)

    try:
        select_and_download_models(opt)
    except AssertionError as e:
        logger.error(e)
        sys.exit(-1)
    except KeyboardInterrupt:
        try:
            curses.nocbreak()
            curses.echo()
            curses.endwin()
        except curses.error:
            # Ctrl-C arrived while curses was not initialised (for example in
            # one of the non-interactive modes); there is nothing to restore.
            pass
        logger.info("Goodbye! Come back soon.")
    except WindowTooSmallException as e:
        logger.error(str(e))
    except widget.NotEnoughSpaceForWidget as e:
        if str(e).startswith("Height of 1 allocated"):
            logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")
        input("Press any key to continue...")
    except Exception as e:
        if str(e).startswith("addwstr"):
            logger.error(
                "Insufficient horizontal space for the interface. Please make your window wider and try again."
            )
        else:
            print(f"An exception has occurred: {str(e)} Details:")
            print(traceback.format_exc(), file=sys.stderr)
        input("Press any key to continue...")
# -------------------------------------
if __name__ == "__main__":
main()

View File

@ -1,438 +0,0 @@
"""
invokeai.frontend.merge exports a single function called merge_diffusion_models().
It merges 2-3 models together and create a new InvokeAI-registered diffusion model.
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import re
import sys
from argparse import Namespace
from pathlib import Path
from typing import List, Optional, Tuple
import npyscreen
from npyscreen import widget
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.install.install_helper import initialize_installer
from invokeai.backend.model_manager import (
BaseModelType,
ModelFormat,
ModelType,
ModelVariantType,
)
from invokeai.backend.model_manager.merge import ModelMerger
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
# Application configuration object, fetched once at import time; its ``root`` is the default for --root_dir below.
config = InvokeAIAppConfig.get_config()
# (base model type, label) pairs: the labels fill the base-model selector of the form,
# and the enum values are the accepted --base_model choices.
BASE_TYPES = [
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
]
def _parse_args() -> Namespace:
    """Parse the command-line options of the model merge script."""
    # (option strings, add_argument keyword arguments), in registration order
    option_table = [
        (
            ("--root_dir",),
            {
                "type": Path,
                "default": config.root,
                "help": "Path to the invokeai runtime directory",
            },
        ),
        (
            ("--front_end", "--gui"),
            {
                "dest": "front_end",
                "action": "store_true",
                "default": False,
                "help": "Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
            },
        ),
        (
            ("--models",),
            {
                "dest": "model_names",
                "type": str,
                "nargs": "+",
                "help": "Two to three model names to be merged",
            },
        ),
        (
            ("--base_model",),
            {
                "type": str,
                "choices": [x[0].value for x in BASE_TYPES],
                "help": "The base model shared by the models to be merged",
            },
        ),
        (
            ("--merged_model_name", "--destination"),
            {
                "dest": "merged_model_name",
                "type": str,
                "help": "Name of the output model. If not specified, will be the concatenation of the input model names.",
            },
        ),
        (
            ("--alpha",),
            {
                "type": float,
                "default": 0.5,
                "help": "The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
            },
        ),
        (
            ("--interpolation",),
            {
                "dest": "interp",
                "type": str,
                "choices": ["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
                "default": "weighted_sum",
                "help": 'Interpolation method to use. If three models are present, only "add_difference" will work.',
            },
        ),
        (
            ("--force",),
            {
                "action": "store_true",
                "help": "Try to merge models even if they are incompatible with each other",
            },
        ),
        (
            ("--clobber", "--overwrite"),
            {
                "dest": "clobber",
                "action": "store_true",
                "help": "Overwrite the merged model if --merged_model_name already exists",
            },
        ),
    ]

    parser = argparse.ArgumentParser(description="InvokeAI model merging")
    for option_strings, settings in option_table:
        parser.add_argument(*option_strings, **settings)
    return parser.parse_args()
# ------------------------- GUI HERE -------------------------
class mergeModelsForm(npyscreen.FormMultiPageAction):
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
def __init__(self, parentApp, name):
    """Set the resize flags and the parent application, then run the npyscreen constructor."""
    self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
    self.ALLOW_RESIZE = True
    # assigned before the base constructor runs
    self.parentApp = parentApp
    super().__init__(parentApp, name)
@property
def model_record_store(self) -> ModelRecordServiceBase:
    """The record store of the parent application's model installer."""
    return self.parentApp.installer.record_store
def afterEditing(self) -> None:
    """Tell the parent application that no form follows this one."""
    app = self.parentApp
    app.setNextForm(None)
def create(self) -> None:
    """Build the merge form.

    Widgets, top to bottom: two instruction lines, the base-model selector,
    three model lists (side by side when the terminal is wide enough), the
    output name box, the force checkbox, the merge-method selector and the
    alpha slider.
    """
    _window_height, window_width = curses.initscr().getmaxyx()
    self.current_base = 0
    self.models = self.get_models(BASE_TYPES[self.current_base][0])
    self.model_names = [x[1] for x in self.models]
    max_width = max(len(x) for x in self.model_names) + 6
    horizontal_layout = max_width * 3 < window_width

    # Explicit positions are used only in the side-by-side layout; None lets
    # npyscreen place the widget itself.
    label_row = 6 if horizontal_layout else None
    list_row = 7 if horizontal_layout else None
    second_column = max_width + 3 if horizontal_layout else None
    third_column = max_width * 2 + 3 if horizontal_layout else None

    for instruction in (
        "Select two models to merge and optionally a third.",
        "Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
    ):
        self.add_widget_intelligent(
            npyscreen.FixedText,
            color="CONTROL",
            value=instruction,
            editable=False,
        )
    self.nextrely += 1

    self.base_select = self.add_widget_intelligent(
        SingleSelectColumns,
        values=[x[1] for x in BASE_TYPES],
        value=[self.current_base],
        columns=4,
        max_height=2,
        relx=8,
        scroll_exit=True,
    )
    self.base_select.on_changed = self._populate_models

    # --- model 1 ---
    self.add_widget_intelligent(
        npyscreen.FixedText,
        value="MODEL 1",
        color="GOOD",
        editable=False,
        rely=label_row,
    )
    self.model1 = self.add_widget_intelligent(
        npyscreen.SelectOne,
        values=self.model_names,
        value=0,
        max_height=len(self.model_names),
        max_width=max_width,
        scroll_exit=True,
        rely=7,
    )

    # --- model 2 ---
    self.add_widget_intelligent(
        npyscreen.FixedText,
        value="MODEL 2",
        color="GOOD",
        editable=False,
        relx=second_column,
        rely=label_row,
    )
    self.model2 = self.add_widget_intelligent(
        npyscreen.SelectOne,
        name="(2)",
        values=self.model_names,
        value=1,
        max_height=len(self.model_names),
        max_width=max_width,
        relx=second_column,
        rely=list_row,
        scroll_exit=True,
    )

    # --- model 3 (optional, so its list starts with "None") ---
    self.add_widget_intelligent(
        npyscreen.FixedText,
        value="MODEL 3",
        color="GOOD",
        editable=False,
        relx=third_column,
        rely=label_row,
    )
    models_plus_none = ["None", *self.model_names]
    self.model3 = self.add_widget_intelligent(
        npyscreen.SelectOne,
        name="(3)",
        values=models_plus_none,
        value=0,
        max_height=len(self.model_names) + 1,
        max_width=max_width,
        scroll_exit=True,
        relx=third_column,
        rely=list_row,
    )

    for model_list in (self.model1, self.model2, self.model3):
        model_list.when_value_edited = self.models_changed

    self.merged_model_name = self.add_widget_intelligent(
        TextBox,
        name="Name for merged model:",
        labelColor="CONTROL",
        max_height=3,
        value="",
        scroll_exit=True,
    )
    self.force = self.add_widget_intelligent(
        npyscreen.Checkbox,
        name="Force merge of models created by different diffusers library versions",
        labelColor="CONTROL",
        value=True,
        scroll_exit=True,
    )
    self.nextrely += 1
    self.merge_method = self.add_widget_intelligent(
        npyscreen.TitleSelectOne,
        name="Merge Method:",
        values=self.interpolations,
        value=0,
        labelColor="CONTROL",
        max_height=len(self.interpolations) + 1,
        scroll_exit=True,
    )
    self.alpha = self.add_widget_intelligent(
        FloatTitleSlider,
        name="Weight (alpha) to assign to second and third models:",
        out_of=1.0,
        step=0.01,
        lowest=0,
        value=0.5,
        labelColor="CONTROL",
        scroll_exit=True,
    )
    self.model1.editing = True
def models_changed(self) -> None:
    """Refresh the suggested output name and the merge methods after a model selection changes."""
    names = self.model1.values
    first = self.model1.value[0]
    second = self.model2.value[0]
    third = self.model3.value[0]

    self.merged_model_name.value = f"{names[first]}+{names[second]}"
    if third > 0:
        self.merge_method.values = ["add_difference ( A+(B-C) )"]
        # the model3 list has one extra leading entry ("None"), hence the -1
        self.merged_model_name.value += f"+{names[third - 1]}"
    else:
        self.merge_method.values = self.interpolations
    self.merge_method.value = 0
def on_ok(self) -> None:
    """Handle the OK button.

    If the field values validate and any overwrite is confirmed, store the
    marshalled merge arguments on the parent application and leave the form.
    Otherwise keep the form in editing mode.
    """
    accepted = self.validate_field_values() and self.check_for_overwrite()
    if not accepted:
        self.editing = True
        return

    self.parentApp.setNextForm(None)
    self.editing = False
    self.parentApp.merge_arguments = self.marshall_arguments()
    npyscreen.notify("Starting the merge...")
def on_cancel(self) -> None:
    """Handle the Cancel button by ending the process with exit status 0."""
    raise SystemExit(0)
def marshall_arguments(self) -> dict:
    """Collect the form's current selections into a dict of merge arguments.

    Returns a dict with the keys ``model_keys``, ``alpha``, ``interp``,
    ``force`` and ``merged_model_name``.
    """
    keys = [entry[0] for entry in self.models]
    chosen = [keys[picker.value[0]] for picker in (self.model1, self.model2)]

    third = self.model3.value[0]
    if third > 0:
        # Index 0 of the third selector is "None", hence the shift by one.
        chosen.append(keys[third - 1])
        method = "add_difference"
    else:
        method = self.interpolations[self.merge_method.value[0]]

    return {
        "model_keys": chosen,
        "alpha": self.alpha.value,
        "interp": method,
        "force": self.force.value,
        "merged_model_name": self.merged_model_name.value,
    }
def check_for_overwrite(self) -> bool:
    """Return True when it is acceptable to write the merged model.

    If the chosen name is not among the known model names, no confirmation
    is needed. Otherwise the user is asked and their answer is returned.
    """
    destination = self.merged_model_name.value
    if destination in self.model_names:
        answer: bool = npyscreen.notify_yes_no(
            f"The chosen merged model destination, {destination}, is already in use. Overwrite?"
        )
        return answer
    return True
def validate_field_values(self) -> bool:
    """Check that every selected model is different from the others.

    Two models are always selected; a third is included when the third
    selector is set to anything other than its leading "None" entry. If any
    problems are found they are shown to the user in a confirmation dialog.

    Returns:
        True when the selection is valid, False otherwise.
    """
    model_names = self.model_names
    selections = [model_names[self.model1.value[0]], model_names[self.model2.value[0]]]
    if self.model3.value[0] > 0:
        # The third selector lists "None" first, so shift back by one.
        selections.append(model_names[self.model3.value[0] - 1])

    bad_fields = []
    # Compare against the number of selections rather than a fixed 2, so that
    # a three-way merge with one model repeated (e.g. A, B, A) is rejected too.
    if len(set(selections)) < len(selections):
        bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selections}")

    if not bad_fields:
        return True

    message = "\n* ".join(["The following problems were detected and must be corrected:", *bad_fields])
    npyscreen.notify_confirm(message)
    return False
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]:  # key to name
    """Return (key, name) pairs for the mergeable main models, sorted by name.

    Only records of type Main in diffusers format whose ``variant`` attribute
    exists and equals "normal" are included. ``base_model`` is passed through
    to the record store search as a filter.
    """
    records = self.model_record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
    pairs = []
    for record in records:
        is_diffusers = record.format == ModelFormat("diffusers")
        if is_diffusers and hasattr(record, "variant") and record.variant == ModelVariantType("normal"):
            pairs.append((record.key, record.name))
    pairs.sort(key=lambda pair: pair[1])
    return pairs
def _populate_models(self, value: List[int]) -> None:
    """Reload the model selectors for the base type at index ``value[0]``.

    Looks the base type up in BASE_TYPES, fetches the matching models, and
    refreshes the three selectors (the third gets an extra leading "None"
    entry) before redrawing the form.
    """
    selected_base = BASE_TYPES[value[0]][0]
    self.models = self.get_models(selected_base)

    names = [entry[1] for entry in self.models]
    self.model_names = names
    for picker in (self.model1, self.model2):
        picker.values = names
    self.model3.values = ["None", *names]

    self.display()
# npyscreen is untyped and causes mypy to get naggy
class Mergeapp(npyscreen.NPSAppManaged):  # type: ignore
    """npyscreen application that presents the merge-settings form."""

    def __init__(self, installer: ModelInstallServiceBase):
        """Create the application and keep a reference to the installer service."""
        super().__init__()
        self.installer = installer

    def onStart(self) -> None:
        """Select the theme and register the merge form under the "MAIN" id."""
        theme = npyscreen.Themes.ElegantTheme
        npyscreen.setTheme(theme)
        self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
def run_gui(args: Namespace) -> None:
    """Run the interactive front end, then perform the merge it requested.

    Args:
        args: Parsed command-line arguments. Not read here; accepted so that
            the signature matches run_cli.
    """
    installer = initialize_installer(config)
    mergeapp = Mergeapp(installer)
    mergeapp.run()

    # merge_arguments is the plain dict produced by the form's
    # marshall_arguments(); it is splatted straight into the merger.
    merge_args = mergeapp.merge_arguments
    merger = ModelMerger(installer)
    merger.merge_diffusion_models_and_save(**merge_args)
    # A dict has no attribute access, so the name must be read by key.
    logger.info(f'Models merged into new model: "{merge_args["merged_model_name"]}".')
def run_cli(args: Namespace) -> None:
    """Merge models non-interactively using command-line arguments.

    Args:
        args: Parsed command-line arguments. Reads ``alpha``, ``model_names``,
            ``merged_model_name``, ``base_model``, ``clobber``, ``interp`` and
            ``force``. ``merged_model_name`` is filled in with the model names
            joined by "+" when it was not provided.

    Raises:
        AssertionError: if alpha is outside [0, 1], if fewer than two or more
            than three models are given, if the destination name is already
            in use without --clobber, or if a model name is unknown or
            ambiguous.
    """
    assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
    # A merge needs at least two inputs, matching the wording of the message.
    assert (
        args.model_names and 2 <= len(args.model_names) <= 3
    ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."

    if not args.merged_model_name:
        args.merged_model_name = "+".join(args.model_names)
        logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')

    installer = initialize_installer(config)
    store = installer.record_store
    existing = store.search_by_attr(
        model_name=args.merged_model_name, base_model=args.base_model, model_type=ModelType.Main
    )
    assert (
        len(existing) == 0 or args.clobber
    ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'

    merger = ModelMerger(installer)
    model_keys = []
    for name in args.model_names:
        # A 32-character lowercase hex string is taken to be a model key and
        # used as-is; anything else is looked up by name.
        if re.fullmatch(r"[0-9a-f]{32}", name):
            model_keys.append(name)
        else:
            models = store.search_by_attr(
                model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
            )
            assert len(models) > 0, f"{name}: Unknown model"
            assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
            model_keys.append(models[0].key)

    merger.merge_diffusion_models_and_save(
        alpha=args.alpha,
        model_keys=model_keys,
        merged_model_name=args.merged_model_name,
        interp=args.interp,
        force=args.force,
    )
    logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def main() -> None:
    """Parse arguments, configure the app, and run the GUI or CLI merge.

    Any failure is logged and the process exits with status -1.
    """
    args = _parse_args()
    root_args = ["--root", str(args.root_dir)] if args.root_dir else []
    config.parse_args(root_args)

    try:
        launch = run_gui if args.front_end else run_cli
        launch(args)
    except widget.NotEnoughSpaceForWidget as e:
        too_few_models = str(e).startswith("Height of 1 allocated")
        logger.error(
            "You need to have at least two diffusers models defined in models.yaml in order to merge"
            if too_few_models
            else "Not enough room for the user interface. Try making this window larger."
        )
        sys.exit(-1)
    except Exception as e:
        logger.error(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
        sys.exit(-1)


if __name__ == "__main__":
    main()