mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cleanup: remove unused scripts, cruft
App runs & tests pass.
This commit is contained in:
@ -1,60 +0,0 @@
|
||||
"""
|
||||
Wrapper for invokeai.backend.configure.invokeai_configure
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def run_configure() -> None:
|
||||
# Before doing _anything_, parse CLI args!
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
parser.add_argument(
|
||||
"--skip-sd-weights",
|
||||
dest="skip_sd_weights",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
help="skip downloading the large Stable Diffusion weight files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-support-models",
|
||||
dest="skip_support_models",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
help="skip downloading the support models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full-precision",
|
||||
dest="full_precision",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
type=bool,
|
||||
default=False,
|
||||
help="use 32-bit weights instead of faster 16-bit weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--yes",
|
||||
"-y",
|
||||
dest="yes_to_all",
|
||||
action="store_true",
|
||||
help='answer "yes" to all prompts',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--default_only",
|
||||
action="store_true",
|
||||
help="when --yes specified, only install the default model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
dest="root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to root of install directory",
|
||||
)
|
||||
|
||||
opt = parser.parse_args()
|
||||
InvokeAIArgs.args = opt
|
||||
|
||||
from invokeai.backend.install.invokeai_configure import main as invokeai_configure
|
||||
|
||||
invokeai_configure(opt)
|
@ -1,652 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
# Before running stable-diffusion on an internet-isolated machine,
|
||||
# run this script from one with internet connectivity. The
|
||||
# two machines must share a common .cache directory.
|
||||
|
||||
"""
|
||||
This is the npyscreen frontend to the model installation application.
|
||||
It is currently named model_install2.py, but will ultimately replace model_install.py.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import pathlib
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from shutil import get_terminal_size
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
from npyscreen import widget
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.backend.install.check_directories import validate_directories
|
||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.widgets import (
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
CenteredTitleText,
|
||||
CyclingForm,
|
||||
MultiSelectColumns,
|
||||
SingleSelectColumns,
|
||||
TextBox,
|
||||
WindowTooSmallException,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger("ModelInstallService", config=config)
|
||||
# logger.setLevel("WARNING")
|
||||
# logger.setLevel('DEBUG')
|
||||
|
||||
# build a table mapping all non-printable characters to None
|
||||
# for stripping control characters
|
||||
# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
|
||||
NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()}
|
||||
|
||||
# maximum number of installed models we can display before overflowing vertically
|
||||
MAX_OTHER_MODELS = 72
|
||||
|
||||
|
||||
def make_printable(s: str) -> str:
|
||||
"""Replace non-printable characters in a string."""
|
||||
return s.translate(NOPRINT_TRANS_TABLE)
|
||||
|
||||
|
||||
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
"""Main form for interactive TUI."""
|
||||
|
||||
# for responsive resizing set to False, but this seems to cause a crash!
|
||||
FIX_MINIMUM_SIZE_WHEN_CREATED = True
|
||||
|
||||
# for persistence
|
||||
current_tab = 0
|
||||
|
||||
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any):
|
||||
self.multipage = multipage
|
||||
self.subprocess = None
|
||||
super().__init__(parentApp=parentApp, name=name, **keywords)
|
||||
|
||||
def create(self) -> None:
|
||||
self.installer = self.parentApp.install_helper.installer
|
||||
self.model_labels = self._get_model_labels()
|
||||
self.keypress_timeout = 10
|
||||
self.counter = 0
|
||||
self.subprocess_connection = None
|
||||
|
||||
window_width, window_height = get_terminal_size()
|
||||
|
||||
# npyscreen has no typing hints
|
||||
self.nextrely -= 1 # type: ignore
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields. Cursor keys navigate, and <space> selects.",
|
||||
editable=False,
|
||||
color="CAUTION",
|
||||
)
|
||||
self.nextrely += 1 # type: ignore
|
||||
self.tabs = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
"STARTERS",
|
||||
"MAINS",
|
||||
"CONTROLNETS",
|
||||
"T2I-ADAPTERS",
|
||||
"IP-ADAPTERS",
|
||||
"LORAS",
|
||||
"TI EMBEDDINGS",
|
||||
],
|
||||
value=[self.current_tab],
|
||||
columns=7,
|
||||
max_height=2,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.tabs.on_changed = self._toggle_tables
|
||||
|
||||
top_of_table = self.nextrely # type: ignore
|
||||
self.starter_pipelines = self.add_starter_pipelines()
|
||||
bottom_of_table = self.nextrely # type: ignore
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.pipeline_models = self.add_pipeline_widgets(
|
||||
model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models
|
||||
)
|
||||
# self.pipeline_models['autoload_pending'] = True
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.controlnet_models = self.add_model_widgets(
|
||||
model_type=ModelType.ControlNet,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.t2i_models = self.add_model_widgets(
|
||||
model_type=ModelType.T2IAdapter,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
self.nextrely = top_of_table
|
||||
self.ipadapter_models = self.add_model_widgets(
|
||||
model_type=ModelType.IPAdapter,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.lora_models = self.add_model_widgets(
|
||||
model_type=ModelType.LoRA,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.ti_models = self.add_model_widgets(
|
||||
model_type=ModelType.TextualInversion,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
||||
self.nextrely = bottom_of_table + 1
|
||||
|
||||
self.nextrely += 1
|
||||
back_label = "BACK"
|
||||
cancel_label = "CANCEL"
|
||||
current_position = self.nextrely
|
||||
if self.multipage:
|
||||
self.back_button = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name=back_label,
|
||||
when_pressed_function=self.on_back,
|
||||
)
|
||||
else:
|
||||
self.nextrely = current_position
|
||||
self.cancel_button = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
|
||||
)
|
||||
self.nextrely = current_position
|
||||
|
||||
label = "APPLY CHANGES"
|
||||
self.nextrely = current_position
|
||||
self.done = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name=label,
|
||||
relx=window_width - len(label) - 15,
|
||||
when_pressed_function=self.on_done,
|
||||
)
|
||||
|
||||
# This restores the selected page on return from an installation
|
||||
for _i in range(1, self.current_tab + 1):
|
||||
self.tabs.h_cursor_line_down(1)
|
||||
self._toggle_tables([self.current_tab])
|
||||
|
||||
############# diffusers tab ##########
|
||||
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
|
||||
"""Add widgets responsible for selecting diffusers models"""
|
||||
widgets: Dict[str, npyscreen.widget] = {}
|
||||
|
||||
all_models = self.all_models # master dict of all models, indexed by key
|
||||
model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]]
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
|
||||
widgets.update(
|
||||
label1=self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.",
|
||||
editable=False,
|
||||
labelColor="CAUTION",
|
||||
)
|
||||
)
|
||||
|
||||
self.nextrely -= 1
|
||||
# if user has already installed some initial models, then don't patronize them
|
||||
# by showing more recommendations
|
||||
show_recommended = len(self.installed_models) == 0
|
||||
|
||||
checked = [
|
||||
model_list.index(x)
|
||||
for x in model_list
|
||||
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||
]
|
||||
widgets.update(
|
||||
models_selected=self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=1,
|
||||
name="Install Starter Models",
|
||||
values=model_labels,
|
||||
value=checked,
|
||||
max_height=len(model_list) + 1,
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
),
|
||||
models=model_list,
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
return widgets
|
||||
|
||||
############# Add a set of model install widgets ########
|
||||
def add_model_widgets(
|
||||
self,
|
||||
model_type: ModelType,
|
||||
window_width: int = 120,
|
||||
install_prompt: Optional[str] = None,
|
||||
exclude: Optional[Set[str]] = None,
|
||||
) -> dict[str, npyscreen.widget]:
|
||||
"""Generic code to create model selection widgets"""
|
||||
if exclude is None:
|
||||
exclude = set()
|
||||
widgets: Dict[str, npyscreen.widget] = {}
|
||||
all_models = self.all_models
|
||||
model_list = sorted(
|
||||
[x for x in all_models if all_models[x].type == model_type and x not in exclude],
|
||||
key=lambda x: all_models[x].name or "",
|
||||
)
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
|
||||
show_recommended = len(self.installed_models) == 0
|
||||
truncated = False
|
||||
if len(model_list) > 0:
|
||||
max_width = max([len(x) for x in model_labels])
|
||||
columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding
|
||||
columns = min(len(model_list), columns) or 1
|
||||
prompt = (
|
||||
install_prompt
|
||||
or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk."
|
||||
)
|
||||
|
||||
widgets.update(
|
||||
label1=self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name=prompt,
|
||||
editable=False,
|
||||
labelColor="CAUTION",
|
||||
)
|
||||
)
|
||||
|
||||
if len(model_labels) > MAX_OTHER_MODELS:
|
||||
model_labels = model_labels[0:MAX_OTHER_MODELS]
|
||||
truncated = True
|
||||
|
||||
widgets.update(
|
||||
models_selected=self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=columns,
|
||||
name=f"Install {model_type} Models",
|
||||
values=model_labels,
|
||||
value=[
|
||||
model_list.index(x)
|
||||
for x in model_list
|
||||
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||
],
|
||||
max_height=len(model_list) // columns + 1,
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
),
|
||||
models=model_list,
|
||||
)
|
||||
|
||||
if truncated:
|
||||
widgets.update(
|
||||
warning_message=self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.",
|
||||
editable=False,
|
||||
color="CAUTION",
|
||||
)
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
widgets.update(
|
||||
download_ids=self.add_widget_intelligent(
|
||||
TextBox,
|
||||
name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
|
||||
max_height=6,
|
||||
scroll_exit=True,
|
||||
editable=True,
|
||||
)
|
||||
)
|
||||
return widgets
|
||||
|
||||
### Tab for arbitrary diffusers widgets ###
|
||||
def add_pipeline_widgets(
|
||||
self,
|
||||
model_type: ModelType = ModelType.Main,
|
||||
window_width: int = 120,
|
||||
**kwargs,
|
||||
) -> dict[str, npyscreen.widget]:
|
||||
"""Similar to add_model_widgets() but adds some additional widgets at the bottom
|
||||
to support the autoload directory"""
|
||||
widgets = self.add_model_widgets(
|
||||
model_type=model_type,
|
||||
window_width=window_width,
|
||||
install_prompt=f"Installed {model_type.value.title()} models. Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import.",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return widgets
|
||||
|
||||
def resize(self) -> None:
|
||||
super().resize()
|
||||
if s := self.starter_pipelines.get("models_selected"):
|
||||
if model_list := self.starter_pipelines.get("models"):
|
||||
s.values = [self.model_labels[x] for x in model_list]
|
||||
|
||||
def _toggle_tables(self, value: List[int]) -> None:
|
||||
selected_tab = value[0]
|
||||
widgets = [
|
||||
self.starter_pipelines,
|
||||
self.pipeline_models,
|
||||
self.controlnet_models,
|
||||
self.t2i_models,
|
||||
self.ipadapter_models,
|
||||
self.lora_models,
|
||||
self.ti_models,
|
||||
]
|
||||
|
||||
for group in widgets:
|
||||
for _k, v in group.items():
|
||||
try:
|
||||
v.hidden = True
|
||||
v.editable = False
|
||||
except Exception:
|
||||
pass
|
||||
for _k, v in widgets[selected_tab].items():
|
||||
try:
|
||||
v.hidden = False
|
||||
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
|
||||
v.editable = True
|
||||
except Exception:
|
||||
pass
|
||||
self.__class__.current_tab = selected_tab # for persistence
|
||||
self.display()
|
||||
|
||||
def _get_model_labels(self) -> dict[str, str]:
|
||||
"""Return a list of trimmed labels for all models."""
|
||||
window_width, window_height = get_terminal_size()
|
||||
checkbox_width = 4
|
||||
spacing_width = 2
|
||||
result = {}
|
||||
|
||||
models = self.all_models
|
||||
label_width = max([len(models[x].name or "") for x in self.starter_models])
|
||||
description_width = window_width - label_width - checkbox_width - spacing_width
|
||||
|
||||
for key in self.all_models:
|
||||
description = models[key].description
|
||||
description = (
|
||||
description[0 : description_width - 3] + "..."
|
||||
if description and len(description) > description_width
|
||||
else description
|
||||
if description
|
||||
else ""
|
||||
)
|
||||
result[key] = f"%-{label_width}s %s" % (models[key].name, description)
|
||||
|
||||
return result
|
||||
|
||||
def _get_columns(self) -> int:
|
||||
window_width, window_height = get_terminal_size()
|
||||
cols = 4 if window_width > 240 else 3 if window_width > 160 else 2 if window_width > 80 else 1
|
||||
return min(cols, len(self.installed_models))
|
||||
|
||||
def confirm_deletions(self, selections: InstallSelections) -> bool:
|
||||
remove_models = selections.remove_models
|
||||
if remove_models:
|
||||
model_names = [self.all_models[x].name or "" for x in remove_models]
|
||||
mods = "\n".join(model_names)
|
||||
is_ok = npyscreen.notify_ok_cancel(
|
||||
f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}"
|
||||
)
|
||||
assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations
|
||||
return is_ok
|
||||
else:
|
||||
return True
|
||||
|
||||
@property
|
||||
def all_models(self) -> Dict[str, UnifiedModelInfo]:
|
||||
# npyscreen doesn't having typing hints
|
||||
return self.parentApp.install_helper.all_models # type: ignore
|
||||
|
||||
@property
|
||||
def starter_models(self) -> List[str]:
|
||||
return self.parentApp.install_helper._starter_models # type: ignore
|
||||
|
||||
@property
|
||||
def installed_models(self) -> List[str]:
|
||||
return self.parentApp.install_helper._installed_models # type: ignore
|
||||
|
||||
def on_back(self) -> None:
|
||||
self.parentApp.switchFormPrevious()
|
||||
self.editing = False
|
||||
|
||||
def on_cancel(self) -> None:
|
||||
self.parentApp.setNextForm(None)
|
||||
self.parentApp.user_cancelled = True
|
||||
self.editing = False
|
||||
|
||||
def on_done(self) -> None:
|
||||
self.marshall_arguments()
|
||||
if not self.confirm_deletions(self.parentApp.install_selections):
|
||||
return
|
||||
self.parentApp.setNextForm(None)
|
||||
self.parentApp.user_cancelled = False
|
||||
self.editing = False
|
||||
|
||||
def marshall_arguments(self) -> None:
|
||||
"""
|
||||
Assemble arguments and store as attributes of the application:
|
||||
.starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
|
||||
True => Install
|
||||
False => Remove
|
||||
.scan_directory: Path to a directory of models to scan and import
|
||||
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
||||
.import_model_paths: list of URLs, repo_ids and file paths to import
|
||||
"""
|
||||
selections = self.parentApp.install_selections
|
||||
all_models = self.all_models
|
||||
|
||||
# Defined models (in INITIAL_CONFIG.yaml or invokeai.db) to add/remove
|
||||
ui_sections = [
|
||||
self.starter_pipelines,
|
||||
self.pipeline_models,
|
||||
self.controlnet_models,
|
||||
self.t2i_models,
|
||||
self.ipadapter_models,
|
||||
self.lora_models,
|
||||
self.ti_models,
|
||||
]
|
||||
for section in ui_sections:
|
||||
if "models_selected" not in section:
|
||||
continue
|
||||
selected = {section["models"][x] for x in section["models_selected"].value}
|
||||
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
||||
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
|
||||
selections.remove_models.extend(models_to_remove)
|
||||
selections.install_models.extend([all_models[x] for x in models_to_install])
|
||||
|
||||
# models located in the 'download_ids" section
|
||||
for section in ui_sections:
|
||||
if downloads := section.get("download_ids"):
|
||||
models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
|
||||
selections.install_models.extend(models)
|
||||
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore
|
||||
def __init__(self, opt: Namespace, install_helper: InstallHelper):
|
||||
super().__init__()
|
||||
self.program_opts = opt
|
||||
self.user_cancelled = False
|
||||
self.install_selections = InstallSelections()
|
||||
self.install_helper = install_helper
|
||||
|
||||
def onStart(self) -> None:
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main_form = self.addForm(
|
||||
"MAIN",
|
||||
addModelsForm,
|
||||
name="Install Stable Diffusion Models",
|
||||
cycle_widgets=False,
|
||||
)
|
||||
|
||||
|
||||
def list_models(installer: ModelInstallServiceBase, model_type: ModelType):
|
||||
"""Print out all models of type model_type."""
|
||||
models = installer.record_store.search_by_attr(model_type=model_type)
|
||||
print(f"Installed models of type `{model_type}`:")
|
||||
for model in models:
|
||||
path = (config.models_path / model.path).resolve()
|
||||
print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}")
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def select_and_download_models(opt: Namespace) -> None:
|
||||
"""Prompt user for install/delete selections and execute."""
|
||||
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
config.precision = precision
|
||||
install_helper = InstallHelper(config, logger)
|
||||
installer = install_helper.installer
|
||||
|
||||
if opt.list_models:
|
||||
list_models(installer, opt.list_models)
|
||||
|
||||
elif opt.add or opt.delete:
|
||||
selections = InstallSelections(
|
||||
install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
|
||||
)
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
elif opt.default_only:
|
||||
default_model = install_helper.default_model()
|
||||
assert default_model is not None
|
||||
selections = InstallSelections(install_models=[default_model])
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
elif opt.yes_to_all:
|
||||
selections = InstallSelections(install_models=install_helper.recommended_models())
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
# this is where the TUI is called
|
||||
else:
|
||||
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
|
||||
raise WindowTooSmallException(
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
installApp = AddModelApplication(opt, install_helper)
|
||||
try:
|
||||
installApp.run()
|
||||
except KeyboardInterrupt:
|
||||
print("Aborted...")
|
||||
sys.exit(-1)
|
||||
|
||||
install_helper.add_or_delete(installApp.install_selections)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
parser.add_argument(
|
||||
"--add",
|
||||
nargs="*",
|
||||
help="List of URLs, local paths or repo_ids of models to install",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delete",
|
||||
nargs="*",
|
||||
help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full-precision",
|
||||
dest="full_precision",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
type=bool,
|
||||
default=False,
|
||||
help="use 32-bit weights instead of faster 16-bit weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--yes",
|
||||
"-y",
|
||||
dest="yes_to_all",
|
||||
action="store_true",
|
||||
help='answer "yes" to all prompts',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--default_only",
|
||||
action="store_true",
|
||||
help="Only install the default model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-models",
|
||||
choices=[x.value for x in ModelType],
|
||||
help="list installed models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
dest="root",
|
||||
type=pathlib.Path,
|
||||
default=None,
|
||||
help="path to root of install directory",
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
invoke_args: dict[str, Any] = {}
|
||||
if opt.full_precision:
|
||||
invoke_args["precision"] = "float32"
|
||||
config.update_config(invoke_args)
|
||||
if opt.root:
|
||||
config.set_root(opt.root)
|
||||
|
||||
logger = InvokeAILogger().get_logger(config=config)
|
||||
|
||||
try:
|
||||
validate_directories(config)
|
||||
except AssertionError:
|
||||
logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
|
||||
sys.argv = ["invokeai_configure", "--yes", "--skip-sd-weights"]
|
||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||
|
||||
invokeai_configure()
|
||||
sys.exit(0)
|
||||
|
||||
try:
|
||||
select_and_download_models(opt)
|
||||
except AssertionError as e:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
curses.nocbreak()
|
||||
curses.echo()
|
||||
curses.endwin()
|
||||
logger.info("Goodbye! Come back soon.")
|
||||
except WindowTooSmallException as e:
|
||||
logger.error(str(e))
|
||||
except widget.NotEnoughSpaceForWidget as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")
|
||||
input("Press any key to continue...")
|
||||
except Exception as e:
|
||||
if str(e).startswith("addwstr"):
|
||||
logger.error(
|
||||
"Insufficient horizontal space for the interface. Please make your window wider and try again."
|
||||
)
|
||||
else:
|
||||
print(f"An exception has occurred: {str(e)} Details:")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
input("Press any key to continue...")
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,441 +0,0 @@
|
||||
"""
|
||||
Widget class definitions used by model_select.py, merge_diffusers.py and textual_inversion.py
|
||||
"""
|
||||
|
||||
import curses
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import struct
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
|
||||
from shutil import get_terminal_size
|
||||
from typing import Optional
|
||||
|
||||
import npyscreen
|
||||
import npyscreen.wgmultiline as wgmultiline
|
||||
import pyperclip
|
||||
from npyscreen import fmPopup
|
||||
|
||||
# minimum size for UIs
|
||||
MIN_COLS = 150
|
||||
MIN_LINES = 40
|
||||
|
||||
|
||||
class WindowTooSmallException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def set_terminal_size(columns: int, lines: int) -> bool:
|
||||
OS = platform.uname().system
|
||||
screen_ok = False
|
||||
while not screen_ok:
|
||||
ts = get_terminal_size()
|
||||
width = max(columns, ts.columns)
|
||||
height = max(lines, ts.lines)
|
||||
|
||||
if OS == "Windows":
|
||||
pass
|
||||
# not working reliably - ask user to adjust the window
|
||||
# _set_terminal_size_powershell(width,height)
|
||||
elif OS in ["Darwin", "Linux"]:
|
||||
_set_terminal_size_unix(width, height)
|
||||
|
||||
# check whether it worked....
|
||||
ts = get_terminal_size()
|
||||
if ts.columns < columns or ts.lines < lines:
|
||||
print(
|
||||
f"\033[1mThis window is too small for the interface. InvokeAI requires {columns}x{lines} (w x h) characters, but window is {ts.columns}x{ts.lines}\033[0m"
|
||||
)
|
||||
resp = input(
|
||||
"Maximize the window and/or decrease the font size then press any key to continue. Type [Q] to give up.."
|
||||
)
|
||||
if resp.upper().startswith("Q"):
|
||||
break
|
||||
else:
|
||||
screen_ok = True
|
||||
return screen_ok
|
||||
|
||||
|
||||
def _set_terminal_size_powershell(width: int, height: int):
|
||||
script = f"""
|
||||
$pshost = get-host
|
||||
$pswindow = $pshost.ui.rawui
|
||||
$newsize = $pswindow.buffersize
|
||||
$newsize.height = 3000
|
||||
$newsize.width = {width}
|
||||
$pswindow.buffersize = $newsize
|
||||
$newsize = $pswindow.windowsize
|
||||
$newsize.height = {height}
|
||||
$newsize.width = {width}
|
||||
$pswindow.windowsize = $newsize
|
||||
"""
|
||||
subprocess.run(["powershell", "-Command", "-"], input=script, text=True)
|
||||
|
||||
|
||||
def _set_terminal_size_unix(width: int, height: int):
|
||||
import fcntl
|
||||
import termios
|
||||
|
||||
# These terminals accept the size command and report that the
|
||||
# size changed, but they lie!!!
|
||||
for bad_terminal in ["TERMINATOR_UUID", "ALACRITTY_WINDOW_ID"]:
|
||||
if os.environ.get(bad_terminal):
|
||||
return
|
||||
|
||||
winsize = struct.pack("HHHH", height, width, 0, 0)
|
||||
fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
|
||||
sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
|
||||
# make sure there's enough room for the ui
|
||||
term_cols, term_lines = get_terminal_size()
|
||||
if term_cols >= min_cols and term_lines >= min_lines:
|
||||
return True
|
||||
cols = max(term_cols, min_cols)
|
||||
lines = max(term_lines, min_lines)
|
||||
return set_terminal_size(cols, lines)
|
||||
|
||||
|
||||
class IntSlider(npyscreen.Slider):
|
||||
def translate_value(self):
|
||||
stri = "%2d / %2d" % (self.value, self.out_of)
|
||||
length = (len(str(self.out_of))) * 2 + 4
|
||||
stri = stri.rjust(length)
|
||||
return stri
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
# fix npyscreen form so that cursor wraps both forward and backward
|
||||
class CyclingForm(object):
|
||||
def find_previous_editable(self, *args):
|
||||
done = False
|
||||
n = self.editw - 1
|
||||
while not done:
|
||||
if self._widgets__[n].editable and not self._widgets__[n].hidden:
|
||||
self.editw = n
|
||||
done = True
|
||||
n -= 1
|
||||
if n < 0:
|
||||
if self.cycle_widgets:
|
||||
n = len(self._widgets__) - 1
|
||||
else:
|
||||
done = True
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class CenteredTitleText(npyscreen.TitleText):
|
||||
def __init__(self, *args, **keywords):
|
||||
super().__init__(*args, **keywords)
|
||||
self.resize()
|
||||
|
||||
def resize(self):
|
||||
super().resize()
|
||||
maxy, maxx = self.parent.curses_pad.getmaxyx()
|
||||
label = self.name
|
||||
self.relx = (maxx - len(label)) // 2
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class CenteredButtonPress(npyscreen.ButtonPress):
|
||||
def resize(self):
|
||||
super().resize()
|
||||
maxy, maxx = self.parent.curses_pad.getmaxyx()
|
||||
label = self.name
|
||||
self.relx = (maxx - len(label)) // 2
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class OffsetButtonPress(npyscreen.ButtonPress):
|
||||
def __init__(self, screen, offset=0, *args, **keywords):
|
||||
super().__init__(screen, *args, **keywords)
|
||||
self.offset = offset
|
||||
|
||||
def resize(self):
|
||||
maxy, maxx = self.parent.curses_pad.getmaxyx()
|
||||
width = len(self.name)
|
||||
self.relx = self.offset + (maxx - width) // 2
|
||||
|
||||
|
||||
class IntTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = IntSlider
|
||||
|
||||
|
||||
class FloatSlider(npyscreen.Slider):
|
||||
# this is supposed to adjust display precision, but doesn't
|
||||
def translate_value(self):
|
||||
stri = "%3.2f / %3.2f" % (self.value, self.out_of)
|
||||
length = (len(str(self.out_of))) * 2 + 4
|
||||
stri = stri.rjust(length)
|
||||
return stri
|
||||
|
||||
|
||||
class FloatTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = npyscreen.Slider
|
||||
|
||||
|
||||
class SelectColumnBase:
|
||||
"""Base class for selection widget arranged in columns."""
|
||||
|
||||
def make_contained_widgets(self):
|
||||
self._my_widgets = []
|
||||
column_width = self.width // self.columns
|
||||
for h in range(self.value_cnt):
|
||||
self._my_widgets.append(
|
||||
self._contained_widgets(
|
||||
self.parent,
|
||||
rely=self.rely + (h % self.rows) * self._contained_widget_height,
|
||||
relx=self.relx + (h // self.rows) * column_width,
|
||||
max_width=column_width,
|
||||
max_height=self.__class__._contained_widget_height,
|
||||
)
|
||||
)
|
||||
|
||||
def set_up_handlers(self):
|
||||
super().set_up_handlers()
|
||||
self.handlers.update(
|
||||
{
|
||||
curses.KEY_UP: self.h_cursor_line_left,
|
||||
curses.KEY_DOWN: self.h_cursor_line_right,
|
||||
}
|
||||
)
|
||||
|
||||
def h_cursor_line_down(self, ch):
|
||||
self.cursor_line += self.rows
|
||||
if self.cursor_line >= len(self.values):
|
||||
if self.scroll_exit:
|
||||
self.cursor_line = len(self.values) - self.rows
|
||||
self.h_exit_down(ch)
|
||||
return True
|
||||
else:
|
||||
self.cursor_line -= self.rows
|
||||
return True
|
||||
|
||||
def h_cursor_line_up(self, ch):
|
||||
self.cursor_line -= self.rows
|
||||
if self.cursor_line < 0:
|
||||
if self.scroll_exit:
|
||||
self.cursor_line = 0
|
||||
self.h_exit_up(ch)
|
||||
else:
|
||||
self.cursor_line = 0
|
||||
|
||||
def h_cursor_line_left(self, ch):
|
||||
super().h_cursor_line_up(ch)
|
||||
|
||||
def h_cursor_line_right(self, ch):
|
||||
super().h_cursor_line_down(ch)
|
||||
|
||||
def handle_mouse_event(self, mouse_event):
|
||||
mouse_id, rel_x, rel_y, z, bstate = self.interpret_mouse_event(mouse_event)
|
||||
column_width = self.width // self.columns
|
||||
column_height = math.ceil(self.value_cnt / self.columns)
|
||||
column_no = rel_x // column_width
|
||||
row_no = rel_y // self._contained_widget_height
|
||||
self.cursor_line = column_no * column_height + row_no
|
||||
if bstate & curses.BUTTON1_DOUBLE_CLICKED:
|
||||
if hasattr(self, "on_mouse_double_click"):
|
||||
self.on_mouse_double_click(self.cursor_line)
|
||||
self.display()
|
||||
|
||||
|
||||
class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect):
|
||||
def __init__(self, screen, columns: int = 1, values: Optional[list] = None, **keywords):
|
||||
if values is None:
|
||||
values = []
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
def on_mouse_double_click(self, cursor_line):
|
||||
self.h_select_toggle(cursor_line)
|
||||
|
||||
|
||||
class SingleSelectWithChanged(npyscreen.SelectOne):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.on_changed = None
|
||||
|
||||
def h_select(self, ch):
|
||||
super().h_select(ch)
|
||||
if self.on_changed:
|
||||
self.on_changed(self.value)
|
||||
|
||||
|
||||
class CheckboxWithChanged(npyscreen.Checkbox):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.on_changed = None
|
||||
|
||||
def whenToggled(self):
|
||||
super().whenToggled()
|
||||
if self.on_changed:
|
||||
self.on_changed(self.value)
|
||||
|
||||
|
||||
class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
|
||||
"""Row of radio buttons. Spacebar to select."""
|
||||
|
||||
def __init__(self, screen, columns: int = 1, values: list = None, **keywords):
|
||||
if values is None:
|
||||
values = []
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
self.on_changed = None
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
def h_cursor_line_right(self, ch):
|
||||
self.h_exit_down("bye bye")
|
||||
|
||||
def h_cursor_line_left(self, ch):
|
||||
self.h_exit_up("bye bye")
|
||||
|
||||
|
||||
class SingleSelectColumns(SingleSelectColumnsSimple):
|
||||
"""Row of radio buttons. When tabbing over a selection, it is auto selected."""
|
||||
|
||||
def when_cursor_moved(self):
|
||||
self.h_select(self.cursor_line)
|
||||
|
||||
|
||||
class TextBoxInner(npyscreen.MultiLineEdit):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.yank = None
|
||||
self.handlers.update(
|
||||
{
|
||||
"^A": self.h_cursor_to_start,
|
||||
"^E": self.h_cursor_to_end,
|
||||
"^K": self.h_kill,
|
||||
"^F": self.h_cursor_right,
|
||||
"^B": self.h_cursor_left,
|
||||
"^Y": self.h_yank,
|
||||
"^V": self.h_paste,
|
||||
}
|
||||
)
|
||||
|
||||
def h_cursor_to_start(self, input):
|
||||
self.cursor_position = 0
|
||||
|
||||
def h_cursor_to_end(self, input):
|
||||
self.cursor_position = len(self.value)
|
||||
|
||||
def h_kill(self, input):
|
||||
self.yank = self.value[self.cursor_position :]
|
||||
self.value = self.value[: self.cursor_position]
|
||||
|
||||
def h_yank(self, input):
|
||||
if self.yank:
|
||||
self.paste(self.yank)
|
||||
|
||||
def paste(self, text: str):
|
||||
self.value = self.value[: self.cursor_position] + text + self.value[self.cursor_position :]
|
||||
self.cursor_position += len(text)
|
||||
|
||||
def h_paste(self, input: int = 0):
|
||||
try:
|
||||
text = pyperclip.paste()
|
||||
except ModuleNotFoundError:
|
||||
text = "To paste with the mouse on Linux, please install the 'xclip' program."
|
||||
self.paste(text)
|
||||
|
||||
def handle_mouse_event(self, mouse_event):
|
||||
mouse_id, rel_x, rel_y, z, bstate = self.interpret_mouse_event(mouse_event)
|
||||
if bstate & (BUTTON2_CLICKED | BUTTON3_CLICKED):
|
||||
self.h_paste()
|
||||
|
||||
|
||||
class TextBox(npyscreen.BoxTitle):
|
||||
_contained_widget = TextBoxInner
|
||||
|
||||
|
||||
class BufferBox(npyscreen.BoxTitle):
|
||||
_contained_widget = npyscreen.BufferPager
|
||||
|
||||
|
||||
class ConfirmCancelPopup(fmPopup.ActionPopup):
|
||||
DEFAULT_COLUMNS = 100
|
||||
|
||||
def on_ok(self):
|
||||
self.value = True
|
||||
|
||||
def on_cancel(self):
|
||||
self.value = False
|
||||
|
||||
|
||||
class FileBox(npyscreen.BoxTitle):
|
||||
_contained_widget = npyscreen.Filename
|
||||
|
||||
|
||||
class PrettyTextBox(npyscreen.BoxTitle):
|
||||
_contained_widget = TextBox
|
||||
|
||||
|
||||
def _wrap_message_lines(message, line_length):
|
||||
lines = []
|
||||
for line in message.split("\n"):
|
||||
lines.extend(textwrap.wrap(line.rstrip(), line_length))
|
||||
return lines
|
||||
|
||||
|
||||
def _prepare_message(message):
|
||||
if isinstance(message, list) or isinstance(message, tuple):
|
||||
return "\n".join([s.rstrip() for s in message])
|
||||
# return "\n".join(message)
|
||||
else:
|
||||
return message
|
||||
|
||||
|
||||
def select_stable_diffusion_config_file(
|
||||
form_color: str = "DANGER",
|
||||
wrap: bool = True,
|
||||
model_name: str = "Unknown",
|
||||
):
|
||||
message = f"Please select the correct prediction type for the checkpoint named '{model_name}'. Press <CANCEL> to skip installation."
|
||||
title = "CONFIG FILE SELECTION"
|
||||
options = [
|
||||
"'epsilon' - most v1.5 models and v2 models trained on 512 pixel images",
|
||||
"'vprediction' - v2 models trained on 768 pixel images and a few v1.5 models)",
|
||||
"Accept the best guess; you can fix it in the Web UI later",
|
||||
]
|
||||
|
||||
F = ConfirmCancelPopup(
|
||||
name=title,
|
||||
color=form_color,
|
||||
cycle_widgets=True,
|
||||
lines=16,
|
||||
)
|
||||
F.preserve_selected_widget = True
|
||||
|
||||
mlw = F.add(
|
||||
wgmultiline.Pager,
|
||||
max_height=4,
|
||||
editable=False,
|
||||
)
|
||||
mlw_width = mlw.width - 1
|
||||
if wrap:
|
||||
message = _wrap_message_lines(message, mlw_width)
|
||||
mlw.values = message
|
||||
|
||||
choice = F.add(
|
||||
npyscreen.SelectOne,
|
||||
values=options,
|
||||
value=[2],
|
||||
max_height=len(options) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
|
||||
F.editw = 1
|
||||
F.edit()
|
||||
if not F.value:
|
||||
return None
|
||||
assert choice.value[0] in range(0, 3), "invalid choice"
|
||||
choices = ["epsilon", "v", "guess"]
|
||||
return choices[choice.value[0]]
|
@ -1,22 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--web", action="store_true")
|
||||
opts, _ = parser.parse_known_args()
|
||||
|
||||
if opts.web:
|
||||
sys.argv.pop(sys.argv.index("--web"))
|
||||
from invokeai.app.api_app import invoke_api
|
||||
|
||||
invoke_api()
|
||||
else:
|
||||
from invokeai.app.cli_app import invoke_cli
|
||||
|
||||
invoke_cli()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,5 +0,0 @@
|
||||
"""
|
||||
Initialization file for invokeai.frontend.merge
|
||||
"""
|
||||
|
||||
from .merge_diffusers import main as invokeai_merge_diffusers # noqa: F401
|
@ -1,448 +0,0 @@
|
||||
"""
|
||||
invokeai.frontend.merge exports a single function call merge_diffusion_models()
|
||||
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import re
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import ModelMerger
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
|
||||
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
BASE_TYPES = [
|
||||
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
|
||||
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
|
||||
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
|
||||
]
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
type=Path,
|
||||
default=config.root_path,
|
||||
help="Path to the invokeai runtime directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--front_end",
|
||||
"--gui",
|
||||
dest="front_end",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
dest="model_names",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Two to three model names to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_model",
|
||||
type=str,
|
||||
choices=[x[0].value for x in BASE_TYPES],
|
||||
help="The base model shared by the models to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merged_model_name",
|
||||
"--destination",
|
||||
dest="merged_model_name",
|
||||
type=str,
|
||||
help="Name of the output model. If not specified, will be the concatenation of the input model names.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation",
|
||||
dest="interp",
|
||||
type=str,
|
||||
choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
|
||||
default="weighted_sum",
|
||||
help='Interpolation method to use. If three models are present, only "add_difference" will work.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Try to merge models even if they are incompatible with each other",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clobber",
|
||||
"--overwrite",
|
||||
dest="clobber",
|
||||
action="store_true",
|
||||
help="Overwrite the merged model if --merged_model_name already exists",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# ------------------------- GUI HERE -------------------------
|
||||
class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
|
||||
|
||||
def __init__(self, parentApp, name):
|
||||
self.parentApp = parentApp
|
||||
self.ALLOW_RESIZE = True
|
||||
self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
@property
|
||||
def record_store(self):
|
||||
return self.parentApp.record_store
|
||||
|
||||
def afterEditing(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
self.current_base = 0
|
||||
self.models = self.get_models(BASE_TYPES[self.current_base][0])
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
max_width = max([len(x) for x in self.model_names])
|
||||
max_width += 6
|
||||
horizontal_layout = max_width * 3 < window_width
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
color="CONTROL",
|
||||
value="Select two models to merge and optionally a third.",
|
||||
editable=False,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
color="CONTROL",
|
||||
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
||||
editable=False,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.base_select = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[x[1] for x in BASE_TYPES],
|
||||
value=[self.current_base],
|
||||
columns=4,
|
||||
max_height=2,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.base_select.on_changed = self._populate_models
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 1",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
self.model1 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
values=self.model_names,
|
||||
value=0,
|
||||
max_height=len(self.model_names),
|
||||
max_width=max_width,
|
||||
scroll_exit=True,
|
||||
rely=7,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 2",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
relx=max_width + 3 if horizontal_layout else None,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
self.model2 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
name="(2)",
|
||||
values=self.model_names,
|
||||
value=1,
|
||||
max_height=len(self.model_names),
|
||||
max_width=max_width,
|
||||
relx=max_width + 3 if horizontal_layout else None,
|
||||
rely=7 if horizontal_layout else None,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 3",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
self.model3 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
name="(3)",
|
||||
values=models_plus_none,
|
||||
value=0,
|
||||
max_height=len(self.model_names) + 1,
|
||||
max_width=max_width,
|
||||
scroll_exit=True,
|
||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||
rely=7 if horizontal_layout else None,
|
||||
)
|
||||
for m in [self.model1, self.model2, self.model3]:
|
||||
m.when_value_edited = self.models_changed
|
||||
self.merged_model_name = self.add_widget_intelligent(
|
||||
TextBox,
|
||||
name="Name for merged model:",
|
||||
labelColor="CONTROL",
|
||||
max_height=3,
|
||||
value="",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.force = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Force merge of models created by different diffusers library versions",
|
||||
labelColor="CONTROL",
|
||||
value=True,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.merge_method = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Merge Method:",
|
||||
values=self.interpolations,
|
||||
value=0,
|
||||
labelColor="CONTROL",
|
||||
max_height=len(self.interpolations) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.alpha = self.add_widget_intelligent(
|
||||
FloatTitleSlider,
|
||||
name="Weight (alpha) to assign to second and third models:",
|
||||
out_of=1.0,
|
||||
step=0.01,
|
||||
lowest=0,
|
||||
value=0.5,
|
||||
labelColor="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.model1.editing = True
|
||||
|
||||
def models_changed(self):
|
||||
models = self.model1.values
|
||||
selected_model1 = self.model1.value[0]
|
||||
selected_model2 = self.model2.value[0]
|
||||
selected_model3 = self.model3.value[0]
|
||||
merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
|
||||
self.merged_model_name.value = merged_model_name
|
||||
|
||||
if selected_model3 > 0:
|
||||
self.merge_method.values = ["add_difference ( A+(B-C) )"]
|
||||
self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one.
|
||||
else:
|
||||
self.merge_method.values = self.interpolations
|
||||
self.merge_method.value = 0
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values() and self.check_for_overwrite():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||
npyscreen.notify("Starting the merge...")
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def on_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
model_keys = [x[0] for x in self.models]
|
||||
models = [
|
||||
model_keys[self.model1.value[0]],
|
||||
model_keys[self.model2.value[0]],
|
||||
]
|
||||
if self.model3.value[0] > 0:
|
||||
models.append(model_keys[self.model3.value[0] - 1])
|
||||
interp = "add_difference"
|
||||
else:
|
||||
interp = self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
args = {
|
||||
"model_keys": models,
|
||||
"base_model": tuple(BaseModelType)[self.base_select.value[0]],
|
||||
"alpha": self.alpha.value,
|
||||
"interp": interp,
|
||||
"force": self.force.value,
|
||||
"merged_model_name": self.merged_model_name.value,
|
||||
}
|
||||
return args
|
||||
|
||||
def check_for_overwrite(self) -> bool:
|
||||
model_out = self.merged_model_name.value
|
||||
if model_out not in self.model_names:
|
||||
return True
|
||||
else:
|
||||
return npyscreen.notify_yes_no(
|
||||
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
|
||||
)
|
||||
|
||||
def validate_field_values(self) -> bool:
|
||||
bad_fields = []
|
||||
model_names = self.model_names
|
||||
selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
|
||||
if self.model3.value[0] > 0:
|
||||
selected_models.add(model_names[self.model3.value[0] - 1])
|
||||
if len(selected_models) < 2:
|
||||
bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}")
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:"
|
||||
for problem in bad_fields:
|
||||
message += f"\n* {problem}"
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
|
||||
models = [
|
||||
(x.key, x.name)
|
||||
for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
|
||||
if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
|
||||
]
|
||||
return sorted(models, key=lambda x: x[1])
|
||||
|
||||
def _populate_models(self, value: List[int]):
|
||||
base_model = BASE_TYPES[value[0]][0]
|
||||
self.models = self.get_models(base_model)
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
self.model1.values = self.model_names
|
||||
self.model2.values = self.model_names
|
||||
self.model3.values = models_plus_none
|
||||
|
||||
self.display()
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self, record_store: ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self.record_store = record_store
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
||||
|
||||
|
||||
def run_gui(args: Namespace) -> None:
|
||||
record_store: ModelRecordServiceBase = get_config_store()
|
||||
mergeapp = Mergeapp(record_store)
|
||||
mergeapp.run()
|
||||
args = mergeapp.merge_arguments
|
||||
merger = get_model_merger(record_store)
|
||||
merger.merge_diffusion_models_and_save(**args)
|
||||
merged_model_name = args["merged_model_name"]
|
||||
logger.info(f'Models merged into new model: "{merged_model_name}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
|
||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||
assert (
|
||||
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
|
||||
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
||||
|
||||
if not args.merged_model_name:
|
||||
args.merged_model_name = "+".join(args.model_names)
|
||||
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
||||
|
||||
record_store: ModelRecordServiceBase = get_config_store()
|
||||
assert (
|
||||
len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merger = get_model_merger(record_store)
|
||||
model_keys = []
|
||||
for name in args.model_names:
|
||||
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
|
||||
model_keys.append(name)
|
||||
else:
|
||||
models = record_store.search_by_attr(
|
||||
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
|
||||
)
|
||||
assert len(models) > 0, f"{name}: Unknown model"
|
||||
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
|
||||
model_keys.append(models[0].key)
|
||||
|
||||
merger.merge_diffusion_models_and_save(
|
||||
alpha=args.alpha,
|
||||
model_keys=model_keys,
|
||||
merged_model_name=args.merged_model_name,
|
||||
interp=args.interp,
|
||||
force=args.force,
|
||||
)
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
def get_config_store() -> ModelRecordServiceSQL:
|
||||
output_path = config.outputs_path
|
||||
assert output_path is not None
|
||||
image_files = DiskImageFileStorage(output_path / "images")
|
||||
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
|
||||
return ModelRecordServiceSQL(db)
|
||||
|
||||
|
||||
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
|
||||
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService())
|
||||
installer.start()
|
||||
return ModelMerger(installer)
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
if args.root_dir:
|
||||
config.set_root(Path(args.root_dir))
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
run_gui(args)
|
||||
else:
|
||||
run_cli(args)
|
||||
except widget.NotEnoughSpaceForWidget as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
logger.error("You need to have at least two diffusers models in order to merge")
|
||||
else:
|
||||
logger.error("Not enough room for the user interface. Try making this window larger.")
|
||||
sys.exit(-1)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,5 +0,0 @@
|
||||
"""
|
||||
Initialization file for invokeai.frontend.training
|
||||
"""
|
||||
|
||||
from .textual_inversion import main as invokeai_textual_inversion # noqa: F401
|
@ -1,452 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
This is the frontend to "textual_inversion_training.py".
|
||||
|
||||
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.install.install_helper import initialize_installer
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.training import do_textual_inversion_training, parse_args
|
||||
|
||||
TRAINING_DATA = "text-inversion-training-data"
|
||||
TRAINING_DIR = "text-inversion-output"
|
||||
CONF_FILE = "preferences.conf"
|
||||
config = None
|
||||
|
||||
|
||||
class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
resolutions = [512, 768, 1024]
|
||||
lr_schedulers = [
|
||||
"linear",
|
||||
"cosine",
|
||||
"cosine_with_restarts",
|
||||
"polynomial",
|
||||
"constant",
|
||||
"constant_with_warmup",
|
||||
]
|
||||
precisions = ["no", "fp16", "bf16"]
|
||||
learnable_properties = ["object", "style"]
|
||||
|
||||
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None):
|
||||
self.saved_args = saved_args or {}
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
def afterEditing(self) -> None:
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self) -> None:
|
||||
self.model_names, default = self.get_model_names()
|
||||
default_initializer_token = "★"
|
||||
default_placeholder_token = ""
|
||||
saved_args = self.saved_args
|
||||
|
||||
assert config is not None
|
||||
|
||||
try:
|
||||
default = self.model_names.index(saved_args["model"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
|
||||
editable=False,
|
||||
)
|
||||
|
||||
self.model = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Model Name:",
|
||||
values=sorted(self.model_names),
|
||||
value=default,
|
||||
max_height=len(self.model_names) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.placeholder_token = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Trigger Term:",
|
||||
value="", # saved_args.get('placeholder_token',''), # to restore previous term
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.placeholder_token.when_value_edited = self.initializer_changed
|
||||
self.nextrely -= 1
|
||||
self.nextrelx += 30
|
||||
self.prompt_token = self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
name="Trigger term for use in prompt",
|
||||
value="",
|
||||
editable=False,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrelx -= 30
|
||||
self.initializer_token = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Initializer:",
|
||||
value=saved_args.get("initializer_token", default_initializer_token),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.resume_from_checkpoint = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Resume from last saved checkpoint",
|
||||
value=False,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.learnable_property = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Learnable property:",
|
||||
values=self.learnable_properties,
|
||||
value=self.learnable_properties.index(saved_args.get("learnable_property", "object")),
|
||||
max_height=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.train_data_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="Data Training Directory:",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
value=str(
|
||||
saved_args.get(
|
||||
"train_data_dir",
|
||||
config.root_path / TRAINING_DATA / default_placeholder_token,
|
||||
)
|
||||
),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.output_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="Output Destination Directory:",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
value=str(
|
||||
saved_args.get(
|
||||
"output_dir",
|
||||
config.root_path / TRAINING_DIR / default_placeholder_token,
|
||||
)
|
||||
),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.resolution = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Image resolution (pixels):",
|
||||
values=self.resolutions,
|
||||
value=self.resolutions.index(saved_args.get("resolution", 512)),
|
||||
max_height=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.center_crop = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Center crop images before resizing to resolution",
|
||||
value=saved_args.get("center_crop", False),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.mixed_precision = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Mixed Precision:",
|
||||
values=self.precisions,
|
||||
value=self.precisions.index(saved_args.get("mixed_precision", "fp16")),
|
||||
max_height=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.num_train_epochs = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Number of training epochs:",
|
||||
out_of=1000,
|
||||
step=50,
|
||||
lowest=1,
|
||||
value=saved_args.get("num_train_epochs", 100),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.max_train_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Max Training Steps:",
|
||||
out_of=10000,
|
||||
step=500,
|
||||
lowest=1,
|
||||
value=saved_args.get("max_train_steps", 3000),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.train_batch_size = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Batch Size (reduce if you run out of memory):",
|
||||
out_of=50,
|
||||
step=1,
|
||||
lowest=1,
|
||||
value=saved_args.get("train_batch_size", 8),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.gradient_accumulation_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
|
||||
out_of=10,
|
||||
step=1,
|
||||
lowest=1,
|
||||
value=saved_args.get("gradient_accumulation_steps", 4),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.lr_warmup_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Warmup Steps:",
|
||||
out_of=100,
|
||||
step=1,
|
||||
lowest=0,
|
||||
value=saved_args.get("lr_warmup_steps", 0),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.learning_rate = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Learning Rate:",
|
||||
value=str(
|
||||
saved_args.get("learning_rate", "5.0e-04"),
|
||||
),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.scale_lr = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Scale learning rate by number GPUs, steps and batch size",
|
||||
value=saved_args.get("scale_lr", True),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Use xformers acceleration",
|
||||
value=saved_args.get("enable_xformers_memory_efficient_attention", False),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.lr_scheduler = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Learning rate scheduler:",
|
||||
values=self.lr_schedulers,
|
||||
max_height=7,
|
||||
value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.model.editing = True
|
||||
|
||||
def initializer_changed(self) -> None:
|
||||
placeholder = self.placeholder_token.value
|
||||
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
|
||||
self.train_data_dir.value = str(config.root_path / TRAINING_DATA / placeholder)
|
||||
self.output_dir.value = str(config.root_path / TRAINING_DIR / placeholder)
|
||||
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
self.parentApp.ti_arguments = self.marshall_arguments()
|
||||
npyscreen.notify("Launching textual inversion training. This will take a while...")
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def ok_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def validate_field_values(self) -> bool:
|
||||
bad_fields = []
|
||||
if self.model.value is None:
|
||||
bad_fields.append("Model Name must correspond to a known model in invokeai.db")
|
||||
if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value):
|
||||
bad_fields.append("Trigger term must only contain alphanumeric characters, the dot and hyphen")
|
||||
if self.train_data_dir.value is None:
|
||||
bad_fields.append("Data Training Directory cannot be empty")
|
||||
if self.output_dir.value is None:
|
||||
bad_fields.append("The Output Destination Directory cannot be empty")
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:"
|
||||
for problem in bad_fields:
|
||||
message += f"\n* {problem}"
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self) -> Tuple[List[str], int]:
|
||||
global config
|
||||
assert config is not None
|
||||
installer = initialize_installer(config)
|
||||
store = installer.record_store
|
||||
main_models = store.search_by_attr(model_type=ModelType.Main)
|
||||
model_names = [f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"]
|
||||
default = 0
|
||||
return (model_names, default)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
args = {}
|
||||
|
||||
# the choices
|
||||
args.update(
|
||||
model=self.model_names[self.model.value[0]],
|
||||
resolution=self.resolutions[self.resolution.value[0]],
|
||||
lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
|
||||
mixed_precision=self.precisions[self.mixed_precision.value[0]],
|
||||
learnable_property=self.learnable_properties[self.learnable_property.value[0]],
|
||||
)
|
||||
|
||||
# all the strings and booleans
|
||||
for attr in (
|
||||
"initializer_token",
|
||||
"placeholder_token",
|
||||
"train_data_dir",
|
||||
"output_dir",
|
||||
"scale_lr",
|
||||
"center_crop",
|
||||
"enable_xformers_memory_efficient_attention",
|
||||
):
|
||||
args[attr] = getattr(self, attr).value
|
||||
|
||||
# all the integers
|
||||
for attr in (
|
||||
"train_batch_size",
|
||||
"gradient_accumulation_steps",
|
||||
"num_train_epochs",
|
||||
"max_train_steps",
|
||||
"lr_warmup_steps",
|
||||
):
|
||||
args[attr] = int(getattr(self, attr).value)
|
||||
|
||||
# the floats (just one)
|
||||
args.update(learning_rate=float(self.learning_rate.value))
|
||||
|
||||
# a special case
|
||||
if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
|
||||
args["resume_from_checkpoint"] = "latest"
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class MyApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, saved_args: Optional[Dict[str, str]] = None):
|
||||
super().__init__()
|
||||
self.ti_arguments = None
|
||||
self.saved_args = saved_args
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main = self.addForm(
|
||||
"MAIN",
|
||||
textualInversionForm,
|
||||
name="Textual Inversion Settings",
|
||||
saved_args=self.saved_args,
|
||||
)
|
||||
|
||||
|
||||
def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
|
||||
"""
|
||||
Copy learned_embeds.bin into the embeddings folder, and offer to
|
||||
delete the full model and checkpoints.
|
||||
"""
|
||||
assert config is not None
|
||||
source = Path(args["output_dir"], "learned_embeds.bin")
|
||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||
destination = config.root_path / "embeddings" / dest_dir_name
|
||||
os.makedirs(destination, exist_ok=True)
|
||||
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||
shutil.copy(source, destination)
|
||||
if (input("Delete training logs and intermediate checkpoints? [y] ") or "y").startswith(("y", "Y")):
|
||||
shutil.rmtree(Path(args["output_dir"]))
|
||||
else:
|
||||
logger.info(f'Keeping {args["output_dir"]}')
|
||||
|
||||
|
||||
def save_args(args: dict) -> None:
|
||||
"""
|
||||
Save the current argument values to an omegaconf file
|
||||
"""
|
||||
assert config is not None
|
||||
dest_dir = config.root_path / TRAINING_DIR
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
conf_file = dest_dir / CONF_FILE
|
||||
conf = OmegaConf.create(args)
|
||||
OmegaConf.save(config=conf, f=conf_file)
|
||||
|
||||
|
||||
def previous_args() -> dict:
|
||||
"""
|
||||
Get the previous arguments used.
|
||||
"""
|
||||
assert config is not None
|
||||
conf_file = config.root_path / TRAINING_DIR / CONF_FILE
|
||||
try:
|
||||
conf = OmegaConf.load(conf_file)
|
||||
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
|
||||
except Exception:
|
||||
conf = None
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
def do_front_end() -> None:
|
||||
global config
|
||||
saved_args = previous_args()
|
||||
myapplication = MyApplication(saved_args=saved_args)
|
||||
myapplication.run()
|
||||
|
||||
if my_args := myapplication.ti_arguments:
|
||||
os.makedirs(my_args["output_dir"], exist_ok=True)
|
||||
|
||||
# Automatically add angle brackets around the trigger
|
||||
if not re.match("^<.+>$", my_args["placeholder_token"]):
|
||||
my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>"
|
||||
|
||||
my_args["only_save_embeds"] = True
|
||||
save_args(my_args)
|
||||
|
||||
try:
|
||||
print(my_args)
|
||||
do_textual_inversion_training(config, **my_args)
|
||||
copy_to_embeddings_folder(my_args)
|
||||
except Exception as e:
|
||||
logger.error("An exception occurred during training. The exception was:")
|
||||
logger.error(str(e))
|
||||
logger.error("DETAILS:")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def main() -> None:
|
||||
global config
|
||||
|
||||
args: Namespace = parse_args()
|
||||
config = get_config()
|
||||
|
||||
# change root if needed
|
||||
if args.root_dir:
|
||||
config.set_root(args.root_dir)
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
do_front_end()
|
||||
else:
|
||||
do_textual_inversion_training(config, **vars(args))
|
||||
except AssertionError as e:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except (widget.NotEnoughSpaceForWidget, Exception) as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
logger.error("You need to have at least one diffusers models defined in invokeai.db in order to train")
|
||||
elif str(e).startswith("addwstr"):
|
||||
logger.error("Not enough window space for the interface. Please make your window larger and try again.")
|
||||
else:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user