mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix invokeai_configure script to work with new mm; rename CLIs
This commit is contained in:
committed by
psychedelicious
parent
78ef946e01
commit
db340bc253
@ -6,47 +6,45 @@
|
||||
|
||||
"""
|
||||
This is the npyscreen frontend to the model installation application.
|
||||
The work is actually done in backend code in model_install_backend.py.
|
||||
It is currently named model_install2.py, but will ultimately replace model_install.py.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import logging
|
||||
import sys
|
||||
import textwrap
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.connection import Connection, Pipe
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
from npyscreen import widget
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType
|
||||
from invokeai.backend.model_management import ModelManager, ModelType
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.widgets import (
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
BufferBox,
|
||||
CenteredTitleText,
|
||||
CyclingForm,
|
||||
MultiSelectColumns,
|
||||
SingleSelectColumns,
|
||||
TextBox,
|
||||
WindowTooSmallException,
|
||||
select_stable_diffusion_config_file,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.get_logger()
|
||||
logger = InvokeAILogger.get_logger("ModelInstallService")
|
||||
logger.setLevel("WARNING")
|
||||
# logger.setLevel('DEBUG')
|
||||
|
||||
# build a table mapping all non-printable characters to None
|
||||
# for stripping control characters
|
||||
@ -58,44 +56,42 @@ MAX_OTHER_MODELS = 72
|
||||
|
||||
|
||||
def make_printable(s: str) -> str:
|
||||
"""Replace non-printable characters in a string"""
|
||||
"""Replace non-printable characters in a string."""
|
||||
return s.translate(NOPRINT_TRANS_TABLE)
|
||||
|
||||
|
||||
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
"""Main form for interactive TUI."""
|
||||
|
||||
# for responsive resizing set to False, but this seems to cause a crash!
|
||||
FIX_MINIMUM_SIZE_WHEN_CREATED = True
|
||||
|
||||
# for persistence
|
||||
current_tab = 0
|
||||
|
||||
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
|
||||
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any):
|
||||
self.multipage = multipage
|
||||
self.subprocess = None
|
||||
super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad?
|
||||
super().__init__(parentApp=parentApp, name=name, **keywords)
|
||||
|
||||
def create(self):
|
||||
def create(self) -> None:
|
||||
self.installer = self.parentApp.install_helper.installer
|
||||
self.model_labels = self._get_model_labels()
|
||||
self.keypress_timeout = 10
|
||||
self.counter = 0
|
||||
self.subprocess_connection = None
|
||||
|
||||
if not config.model_conf_path.exists():
|
||||
with open(config.model_conf_path, "w") as file:
|
||||
print("# InvokeAI model configuration file", file=file)
|
||||
self.installer = ModelInstall(config)
|
||||
self.all_models = self.installer.all_models()
|
||||
self.starter_models = self.installer.starter_models()
|
||||
self.model_labels = self._get_model_labels()
|
||||
window_width, window_height = get_terminal_size()
|
||||
|
||||
self.nextrely -= 1
|
||||
# npyscreen has no typing hints
|
||||
self.nextrely -= 1 # type: ignore
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields. Cursor keys navigate, and <space> selects.",
|
||||
editable=False,
|
||||
color="CAUTION",
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.nextrely += 1 # type: ignore
|
||||
self.tabs = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
@ -115,9 +111,9 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
)
|
||||
self.tabs.on_changed = self._toggle_tables
|
||||
|
||||
top_of_table = self.nextrely
|
||||
top_of_table = self.nextrely # type: ignore
|
||||
self.starter_pipelines = self.add_starter_pipelines()
|
||||
bottom_of_table = self.nextrely
|
||||
bottom_of_table = self.nextrely # type: ignore
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.pipeline_models = self.add_pipeline_widgets(
|
||||
@ -162,15 +158,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
self.nextrely = bottom_of_table + 1
|
||||
|
||||
self.monitor = self.add_widget_intelligent(
|
||||
BufferBox,
|
||||
name="Log Messages",
|
||||
editable=False,
|
||||
max_height=6,
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
done_label = "APPLY CHANGES"
|
||||
back_label = "BACK"
|
||||
cancel_label = "CANCEL"
|
||||
current_position = self.nextrely
|
||||
@ -186,14 +174,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
|
||||
)
|
||||
self.nextrely = current_position
|
||||
self.ok_button = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name=done_label,
|
||||
relx=(window_width - len(done_label)) // 2,
|
||||
when_pressed_function=self.on_execute,
|
||||
)
|
||||
|
||||
label = "APPLY CHANGES & EXIT"
|
||||
label = "APPLY CHANGES"
|
||||
self.nextrely = current_position
|
||||
self.done = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
@ -210,17 +192,16 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
############# diffusers tab ##########
|
||||
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
|
||||
"""Add widgets responsible for selecting diffusers models"""
|
||||
widgets = {}
|
||||
models = self.all_models
|
||||
starters = self.starter_models
|
||||
starter_model_labels = self.model_labels
|
||||
widgets: Dict[str, npyscreen.widget] = {}
|
||||
|
||||
self.installed_models = sorted([x for x in starters if models[x].installed])
|
||||
all_models = self.all_models # master dict of all models, indexed by key
|
||||
model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]]
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
|
||||
widgets.update(
|
||||
label1=self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="Select from a starter set of Stable Diffusion models from HuggingFace.",
|
||||
name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.",
|
||||
editable=False,
|
||||
labelColor="CAUTION",
|
||||
)
|
||||
@ -230,23 +211,24 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
# if user has already installed some initial models, then don't patronize them
|
||||
# by showing more recommendations
|
||||
show_recommended = len(self.installed_models) == 0
|
||||
keys = [x for x in models.keys() if x in starters]
|
||||
|
||||
checked = [
|
||||
model_list.index(x)
|
||||
for x in model_list
|
||||
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||
]
|
||||
widgets.update(
|
||||
models_selected=self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=1,
|
||||
name="Install Starter Models",
|
||||
values=[starter_model_labels[x] for x in keys],
|
||||
value=[
|
||||
keys.index(x)
|
||||
for x in keys
|
||||
if (show_recommended and models[x].recommended) or (x in self.installed_models)
|
||||
],
|
||||
max_height=len(starters) + 1,
|
||||
values=model_labels,
|
||||
value=checked,
|
||||
max_height=len(model_list) + 1,
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
),
|
||||
models=keys,
|
||||
models=model_list,
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
@ -257,14 +239,18 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
self,
|
||||
model_type: ModelType,
|
||||
window_width: int = 120,
|
||||
install_prompt: str = None,
|
||||
exclude: set = None,
|
||||
install_prompt: Optional[str] = None,
|
||||
exclude: Optional[Set[str]] = None,
|
||||
) -> dict[str, npyscreen.widget]:
|
||||
"""Generic code to create model selection widgets"""
|
||||
if exclude is None:
|
||||
exclude = set()
|
||||
widgets = {}
|
||||
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
|
||||
widgets: Dict[str, npyscreen.widget] = {}
|
||||
all_models = self.all_models
|
||||
model_list = sorted(
|
||||
[x for x in all_models if all_models[x].type == model_type and x not in exclude],
|
||||
key=lambda x: all_models[x].name or "",
|
||||
)
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
|
||||
show_recommended = len(self.installed_models) == 0
|
||||
@ -300,7 +286,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
value=[
|
||||
model_list.index(x)
|
||||
for x in model_list
|
||||
if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed
|
||||
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||
],
|
||||
max_height=len(model_list) // columns + 1,
|
||||
relx=4,
|
||||
@ -324,7 +310,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
download_ids=self.add_widget_intelligent(
|
||||
TextBox,
|
||||
name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
|
||||
max_height=4,
|
||||
max_height=6,
|
||||
scroll_exit=True,
|
||||
editable=True,
|
||||
)
|
||||
@ -349,13 +335,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
return widgets
|
||||
|
||||
def resize(self):
|
||||
def resize(self) -> None:
|
||||
super().resize()
|
||||
if s := self.starter_pipelines.get("models_selected"):
|
||||
keys = [x for x in self.all_models.keys() if x in self.starter_models]
|
||||
s.values = [self.model_labels[x] for x in keys]
|
||||
if model_list := self.starter_pipelines.get("models"):
|
||||
s.values = [self.model_labels[x] for x in model_list]
|
||||
|
||||
def _toggle_tables(self, value=None):
|
||||
def _toggle_tables(self, value: List[int]) -> None:
|
||||
selected_tab = value[0]
|
||||
widgets = [
|
||||
self.starter_pipelines,
|
||||
@ -385,17 +371,18 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
self.display()
|
||||
|
||||
def _get_model_labels(self) -> dict[str, str]:
|
||||
"""Return a list of trimmed labels for all models."""
|
||||
window_width, window_height = get_terminal_size()
|
||||
checkbox_width = 4
|
||||
spacing_width = 2
|
||||
result = {}
|
||||
|
||||
models = self.all_models
|
||||
label_width = max([len(models[x].name) for x in models])
|
||||
label_width = max([len(models[x].name or "") for x in self.starter_models])
|
||||
description_width = window_width - label_width - checkbox_width - spacing_width
|
||||
|
||||
result = {}
|
||||
for x in models.keys():
|
||||
description = models[x].description
|
||||
for key in self.all_models:
|
||||
description = models[key].description
|
||||
description = (
|
||||
description[0 : description_width - 3] + "..."
|
||||
if description and len(description) > description_width
|
||||
@ -403,7 +390,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
if description
|
||||
else ""
|
||||
)
|
||||
result[x] = f"%-{label_width}s %s" % (models[x].name, description)
|
||||
result[key] = f"%-{label_width}s %s" % (models[key].name, description)
|
||||
|
||||
return result
|
||||
|
||||
def _get_columns(self) -> int:
|
||||
@ -413,50 +401,40 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
def confirm_deletions(self, selections: InstallSelections) -> bool:
|
||||
remove_models = selections.remove_models
|
||||
if len(remove_models) > 0:
|
||||
mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models])
|
||||
return npyscreen.notify_ok_cancel(
|
||||
if remove_models:
|
||||
model_names = [self.all_models[x].name or "" for x in remove_models]
|
||||
mods = "\n".join(model_names)
|
||||
is_ok = npyscreen.notify_ok_cancel(
|
||||
f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}"
|
||||
)
|
||||
assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations
|
||||
return is_ok
|
||||
else:
|
||||
return True
|
||||
|
||||
def on_execute(self):
|
||||
self.marshall_arguments()
|
||||
app = self.parentApp
|
||||
if not self.confirm_deletions(app.install_selections):
|
||||
return
|
||||
@property
|
||||
def all_models(self) -> Dict[str, UnifiedModelInfo]:
|
||||
# npyscreen doesn't having typing hints
|
||||
return self.parentApp.install_helper.all_models # type: ignore
|
||||
|
||||
self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True)
|
||||
self.ok_button.hidden = True
|
||||
self.display()
|
||||
@property
|
||||
def starter_models(self) -> List[str]:
|
||||
return self.parentApp.install_helper._starter_models # type: ignore
|
||||
|
||||
# TO DO: Spawn a worker thread, not a subprocess
|
||||
parent_conn, child_conn = Pipe()
|
||||
p = Process(
|
||||
target=process_and_execute,
|
||||
kwargs={
|
||||
"opt": app.program_opts,
|
||||
"selections": app.install_selections,
|
||||
"conn_out": child_conn,
|
||||
},
|
||||
)
|
||||
p.start()
|
||||
child_conn.close()
|
||||
self.subprocess_connection = parent_conn
|
||||
self.subprocess = p
|
||||
app.install_selections = InstallSelections()
|
||||
@property
|
||||
def installed_models(self) -> List[str]:
|
||||
return self.parentApp.install_helper._installed_models # type: ignore
|
||||
|
||||
def on_back(self):
|
||||
def on_back(self) -> None:
|
||||
self.parentApp.switchFormPrevious()
|
||||
self.editing = False
|
||||
|
||||
def on_cancel(self):
|
||||
def on_cancel(self) -> None:
|
||||
self.parentApp.setNextForm(None)
|
||||
self.parentApp.user_cancelled = True
|
||||
self.editing = False
|
||||
|
||||
def on_done(self):
|
||||
def on_done(self) -> None:
|
||||
self.marshall_arguments()
|
||||
if not self.confirm_deletions(self.parentApp.install_selections):
|
||||
return
|
||||
@ -464,77 +442,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
self.parentApp.user_cancelled = False
|
||||
self.editing = False
|
||||
|
||||
########## This routine monitors the child process that is performing model installation and removal #####
|
||||
def while_waiting(self):
|
||||
"""Called during idle periods. Main task is to update the Log Messages box with messages
|
||||
from the child process that does the actual installation/removal"""
|
||||
c = self.subprocess_connection
|
||||
if not c:
|
||||
return
|
||||
|
||||
monitor_widget = self.monitor.entry_widget
|
||||
while c.poll():
|
||||
try:
|
||||
data = c.recv_bytes().decode("utf-8")
|
||||
data.strip("\n")
|
||||
|
||||
# processing child is requesting user input to select the
|
||||
# right configuration file
|
||||
if data.startswith("*need v2 config"):
|
||||
_, model_path, *_ = data.split(":", 2)
|
||||
self._return_v2_config(model_path)
|
||||
|
||||
# processing child is done
|
||||
elif data == "*done*":
|
||||
self._close_subprocess_and_regenerate_form()
|
||||
break
|
||||
|
||||
# update the log message box
|
||||
else:
|
||||
data = make_printable(data)
|
||||
data = data.replace("[A", "")
|
||||
monitor_widget.buffer(
|
||||
textwrap.wrap(
|
||||
data,
|
||||
width=monitor_widget.width,
|
||||
subsequent_indent=" ",
|
||||
),
|
||||
scroll_end=True,
|
||||
)
|
||||
self.display()
|
||||
except (EOFError, OSError):
|
||||
self.subprocess_connection = None
|
||||
|
||||
def _return_v2_config(self, model_path: str):
|
||||
c = self.subprocess_connection
|
||||
model_name = Path(model_path).name
|
||||
message = select_stable_diffusion_config_file(model_name=model_name)
|
||||
c.send_bytes(message.encode("utf-8"))
|
||||
|
||||
def _close_subprocess_and_regenerate_form(self):
|
||||
app = self.parentApp
|
||||
self.subprocess_connection.close()
|
||||
self.subprocess_connection = None
|
||||
self.monitor.entry_widget.buffer(["** Action Complete **"])
|
||||
self.display()
|
||||
|
||||
# rebuild the form, saving and restoring some of the fields that need to be preserved.
|
||||
saved_messages = self.monitor.entry_widget.values
|
||||
|
||||
app.main_form = app.addForm(
|
||||
"MAIN",
|
||||
addModelsForm,
|
||||
name="Install Stable Diffusion Models",
|
||||
multipage=self.multipage,
|
||||
)
|
||||
app.switchForm("MAIN")
|
||||
|
||||
app.main_form.monitor.entry_widget.values = saved_messages
|
||||
app.main_form.monitor.entry_widget.buffer([""], scroll_end=True)
|
||||
# app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
|
||||
# app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan
|
||||
|
||||
def marshall_arguments(self):
|
||||
def marshall_arguments(self) -> None:
|
||||
"""
|
||||
Assemble arguments and store as attributes of the application:
|
||||
.starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
|
||||
@ -564,46 +472,24 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
||||
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
|
||||
selections.remove_models.extend(models_to_remove)
|
||||
selections.install_models.extend(
|
||||
all_models[x].path or all_models[x].repo_id
|
||||
for x in models_to_install
|
||||
if all_models[x].path or all_models[x].repo_id
|
||||
)
|
||||
selections.install_models.extend([all_models[x] for x in models_to_install])
|
||||
|
||||
# models located in the 'download_ids" section
|
||||
for section in ui_sections:
|
||||
if downloads := section.get("download_ids"):
|
||||
selections.install_models.extend(downloads.value.split())
|
||||
|
||||
# NOT NEEDED - DONE IN BACKEND NOW
|
||||
# # special case for the ipadapter_models. If any of the adapters are
|
||||
# # chosen, then we add the corresponding encoder(s) to the install list.
|
||||
# section = self.ipadapter_models
|
||||
# if section.get("models_selected"):
|
||||
# selected_adapters = [
|
||||
# self.all_models[section["models"][x]].name for x in section.get("models_selected").value
|
||||
# ]
|
||||
# encoders = []
|
||||
# if any(["sdxl" in x for x in selected_adapters]):
|
||||
# encoders.append("ip_adapter_sdxl_image_encoder")
|
||||
# if any(["sd15" in x for x in selected_adapters]):
|
||||
# encoders.append("ip_adapter_sd_image_encoder")
|
||||
# for encoder in encoders:
|
||||
# key = f"any/clip_vision/{encoder}"
|
||||
# repo_id = f"InvokeAI/{encoder}"
|
||||
# if key not in self.all_models:
|
||||
# selections.install_models.append(repo_id)
|
||||
models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
|
||||
selections.install_models.extend(models)
|
||||
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, opt):
|
||||
class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore
|
||||
def __init__(self, opt: Namespace, install_helper: InstallHelper):
|
||||
super().__init__()
|
||||
self.program_opts = opt
|
||||
self.user_cancelled = False
|
||||
# self.autoload_pending = True
|
||||
self.install_selections = InstallSelections()
|
||||
self.install_helper = install_helper
|
||||
|
||||
def onStart(self):
|
||||
def onStart(self) -> None:
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main_form = self.addForm(
|
||||
"MAIN",
|
||||
@ -613,138 +499,62 @@ class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
)
|
||||
|
||||
|
||||
class StderrToMessage:
|
||||
def __init__(self, connection: Connection):
|
||||
self.connection = connection
|
||||
|
||||
def write(self, data: str):
|
||||
self.connection.send_bytes(data.encode("utf-8"))
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
def list_models(installer: ModelInstallServiceBase, model_type: ModelType):
|
||||
"""Print out all models of type model_type."""
|
||||
models = installer.record_store.search_by_attr(model_type=model_type)
|
||||
print(f"Installed models of type `{model_type}`:")
|
||||
for model in models:
|
||||
path = (config.models_path / model.path).resolve()
|
||||
print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}")
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType:
|
||||
if tui_conn:
|
||||
logger.debug("Waiting for user response...")
|
||||
return _ask_user_for_pt_tui(model_path, tui_conn)
|
||||
else:
|
||||
return _ask_user_for_pt_cmdline(model_path)
|
||||
|
||||
|
||||
def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]:
|
||||
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
|
||||
print(
|
||||
f"""
|
||||
Please select the scheduler prediction type of the checkpoint named {model_path.name}:
|
||||
[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images
|
||||
[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models
|
||||
[3] Accept the best guess; you can fix it in the Web UI later
|
||||
"""
|
||||
)
|
||||
choice = None
|
||||
ok = False
|
||||
while not ok:
|
||||
try:
|
||||
choice = input("select [3]> ").strip()
|
||||
if not choice:
|
||||
return None
|
||||
choice = choices[int(choice) - 1]
|
||||
ok = True
|
||||
except (ValueError, IndexError):
|
||||
print(f"{choice} is not a valid choice")
|
||||
except EOFError:
|
||||
return
|
||||
return choice
|
||||
|
||||
|
||||
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType:
|
||||
tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8"))
|
||||
# note that we don't do any status checking here
|
||||
response = tui_conn.recv_bytes().decode("utf-8")
|
||||
if response is None:
|
||||
return None
|
||||
elif response == "epsilon":
|
||||
return SchedulerPredictionType.epsilon
|
||||
elif response == "v":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif response == "guess":
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def process_and_execute(
|
||||
opt: Namespace,
|
||||
selections: InstallSelections,
|
||||
conn_out: Connection = None,
|
||||
):
|
||||
# need to reinitialize config in subprocess
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
args = ["--root", opt.root] if opt.root else []
|
||||
config.parse_args(args)
|
||||
|
||||
# set up so that stderr is sent to conn_out
|
||||
if conn_out:
|
||||
translator = StderrToMessage(conn_out)
|
||||
sys.stderr = translator
|
||||
sys.stdout = translator
|
||||
logger = InvokeAILogger.get_logger()
|
||||
logger.handlers.clear()
|
||||
logger.addHandler(logging.StreamHandler(translator))
|
||||
|
||||
installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out))
|
||||
installer.install(selections)
|
||||
|
||||
if conn_out:
|
||||
conn_out.send_bytes("*done*".encode("utf-8"))
|
||||
conn_out.close()
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def select_and_download_models(opt: Namespace):
|
||||
def select_and_download_models(opt: Namespace) -> None:
|
||||
"""Prompt user for install/delete selections and execute."""
|
||||
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
config.precision = precision
|
||||
installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
|
||||
# unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal
|
||||
config.precision = precision # type: ignore
|
||||
install_helper = InstallHelper(config, logger)
|
||||
installer = install_helper.installer
|
||||
|
||||
if opt.list_models:
|
||||
installer.list_models(opt.list_models)
|
||||
list_models(installer, opt.list_models)
|
||||
|
||||
elif opt.add or opt.delete:
|
||||
selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or [])
|
||||
installer.install(selections)
|
||||
selections = InstallSelections(
|
||||
install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
|
||||
)
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
elif opt.default_only:
|
||||
selections = InstallSelections(install_models=installer.default_model())
|
||||
installer.install(selections)
|
||||
default_model = install_helper.default_model()
|
||||
assert default_model is not None
|
||||
selections = InstallSelections(install_models=[default_model])
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
elif opt.yes_to_all:
|
||||
selections = InstallSelections(install_models=installer.recommended_models())
|
||||
installer.install(selections)
|
||||
selections = InstallSelections(install_models=install_helper.recommended_models())
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
# this is where the TUI is called
|
||||
else:
|
||||
# needed to support the probe() method running under a subprocess
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
|
||||
raise WindowTooSmallException(
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
installApp = AddModelApplication(opt)
|
||||
installApp = AddModelApplication(opt, install_helper)
|
||||
try:
|
||||
installApp.run()
|
||||
except KeyboardInterrupt as e:
|
||||
if hasattr(installApp, "main_form"):
|
||||
if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive():
|
||||
logger.info("Terminating subprocesses")
|
||||
installApp.main_form.subprocess.terminate()
|
||||
installApp.main_form.subprocess = None
|
||||
raise e
|
||||
process_and_execute(opt, installApp.install_selections)
|
||||
except KeyboardInterrupt:
|
||||
print("Aborted...")
|
||||
sys.exit(-1)
|
||||
|
||||
install_helper.add_or_delete(installApp.install_selections)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main():
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
parser.add_argument(
|
||||
"--add",
|
||||
@ -754,7 +564,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--delete",
|
||||
nargs="*",
|
||||
help="List of names of models to idelete",
|
||||
help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full-precision",
|
||||
@ -781,14 +591,6 @@ def main():
|
||||
choices=[x.value for x in ModelType],
|
||||
help="list installed models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
"-c",
|
||||
dest="config_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to configuration file to create",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
dest="root",
|
||||
|
@ -6,45 +6,47 @@
|
||||
|
||||
"""
|
||||
This is the npyscreen frontend to the model installation application.
|
||||
It is currently named model_install2.py, but will ultimately replace model_install.py.
|
||||
The work is actually done in backend code in model_install_backend.py.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import logging
|
||||
import sys
|
||||
import textwrap
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.connection import Connection, Pipe
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Optional
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
from npyscreen import widget
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType
|
||||
from invokeai.backend.model_management import ModelManager, ModelType
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.widgets import (
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
BufferBox,
|
||||
CenteredTitleText,
|
||||
CyclingForm,
|
||||
MultiSelectColumns,
|
||||
SingleSelectColumns,
|
||||
TextBox,
|
||||
WindowTooSmallException,
|
||||
select_stable_diffusion_config_file,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.get_logger("ModelInstallService")
|
||||
logger.setLevel("WARNING")
|
||||
# logger.setLevel('DEBUG')
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
# build a table mapping all non-printable characters to None
|
||||
# for stripping control characters
|
||||
@ -56,42 +58,44 @@ MAX_OTHER_MODELS = 72
|
||||
|
||||
|
||||
def make_printable(s: str) -> str:
|
||||
"""Replace non-printable characters in a string."""
|
||||
"""Replace non-printable characters in a string"""
|
||||
return s.translate(NOPRINT_TRANS_TABLE)
|
||||
|
||||
|
||||
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
"""Main form for interactive TUI."""
|
||||
|
||||
# for responsive resizing set to False, but this seems to cause a crash!
|
||||
FIX_MINIMUM_SIZE_WHEN_CREATED = True
|
||||
|
||||
# for persistence
|
||||
current_tab = 0
|
||||
|
||||
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any):
|
||||
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
|
||||
self.multipage = multipage
|
||||
self.subprocess = None
|
||||
super().__init__(parentApp=parentApp, name=name, **keywords)
|
||||
super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad?
|
||||
|
||||
def create(self) -> None:
|
||||
self.installer = self.parentApp.install_helper.installer
|
||||
self.model_labels = self._get_model_labels()
|
||||
def create(self):
|
||||
self.keypress_timeout = 10
|
||||
self.counter = 0
|
||||
self.subprocess_connection = None
|
||||
|
||||
if not config.model_conf_path.exists():
|
||||
with open(config.model_conf_path, "w") as file:
|
||||
print("# InvokeAI model configuration file", file=file)
|
||||
self.installer = ModelInstall(config)
|
||||
self.all_models = self.installer.all_models()
|
||||
self.starter_models = self.installer.starter_models()
|
||||
self.model_labels = self._get_model_labels()
|
||||
window_width, window_height = get_terminal_size()
|
||||
|
||||
# npyscreen has no typing hints
|
||||
self.nextrely -= 1 # type: ignore
|
||||
self.nextrely -= 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields. Cursor keys navigate, and <space> selects.",
|
||||
editable=False,
|
||||
color="CAUTION",
|
||||
)
|
||||
self.nextrely += 1 # type: ignore
|
||||
self.nextrely += 1
|
||||
self.tabs = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
@ -111,9 +115,9 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
)
|
||||
self.tabs.on_changed = self._toggle_tables
|
||||
|
||||
top_of_table = self.nextrely # type: ignore
|
||||
top_of_table = self.nextrely
|
||||
self.starter_pipelines = self.add_starter_pipelines()
|
||||
bottom_of_table = self.nextrely # type: ignore
|
||||
bottom_of_table = self.nextrely
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.pipeline_models = self.add_pipeline_widgets(
|
||||
@ -158,7 +162,15 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
self.nextrely = bottom_of_table + 1
|
||||
|
||||
self.monitor = self.add_widget_intelligent(
|
||||
BufferBox,
|
||||
name="Log Messages",
|
||||
editable=False,
|
||||
max_height=6,
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
done_label = "APPLY CHANGES"
|
||||
back_label = "BACK"
|
||||
cancel_label = "CANCEL"
|
||||
current_position = self.nextrely
|
||||
@ -174,8 +186,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
|
||||
)
|
||||
self.nextrely = current_position
|
||||
self.ok_button = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name=done_label,
|
||||
relx=(window_width - len(done_label)) // 2,
|
||||
when_pressed_function=self.on_execute,
|
||||
)
|
||||
|
||||
label = "APPLY CHANGES"
|
||||
label = "APPLY CHANGES & EXIT"
|
||||
self.nextrely = current_position
|
||||
self.done = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
@ -192,16 +210,17 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
############# diffusers tab ##########
|
||||
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
|
||||
"""Add widgets responsible for selecting diffusers models"""
|
||||
widgets: Dict[str, npyscreen.widget] = {}
|
||||
widgets = {}
|
||||
models = self.all_models
|
||||
starters = self.starter_models
|
||||
starter_model_labels = self.model_labels
|
||||
|
||||
all_models = self.all_models # master dict of all models, indexed by key
|
||||
model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]]
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
self.installed_models = sorted([x for x in starters if models[x].installed])
|
||||
|
||||
widgets.update(
|
||||
label1=self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.",
|
||||
name="Select from a starter set of Stable Diffusion models from HuggingFace.",
|
||||
editable=False,
|
||||
labelColor="CAUTION",
|
||||
)
|
||||
@ -211,24 +230,23 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
# if user has already installed some initial models, then don't patronize them
|
||||
# by showing more recommendations
|
||||
show_recommended = len(self.installed_models) == 0
|
||||
|
||||
checked = [
|
||||
model_list.index(x)
|
||||
for x in model_list
|
||||
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||
]
|
||||
keys = [x for x in models.keys() if x in starters]
|
||||
widgets.update(
|
||||
models_selected=self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=1,
|
||||
name="Install Starter Models",
|
||||
values=model_labels,
|
||||
value=checked,
|
||||
max_height=len(model_list) + 1,
|
||||
values=[starter_model_labels[x] for x in keys],
|
||||
value=[
|
||||
keys.index(x)
|
||||
for x in keys
|
||||
if (show_recommended and models[x].recommended) or (x in self.installed_models)
|
||||
],
|
||||
max_height=len(starters) + 1,
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
),
|
||||
models=model_list,
|
||||
models=keys,
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
@ -239,18 +257,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
self,
|
||||
model_type: ModelType,
|
||||
window_width: int = 120,
|
||||
install_prompt: Optional[str] = None,
|
||||
exclude: Optional[Set[str]] = None,
|
||||
install_prompt: str = None,
|
||||
exclude: set = None,
|
||||
) -> dict[str, npyscreen.widget]:
|
||||
"""Generic code to create model selection widgets"""
|
||||
if exclude is None:
|
||||
exclude = set()
|
||||
widgets: Dict[str, npyscreen.widget] = {}
|
||||
all_models = self.all_models
|
||||
model_list = sorted(
|
||||
[x for x in all_models if all_models[x].type == model_type and x not in exclude],
|
||||
key=lambda x: all_models[x].name or "",
|
||||
)
|
||||
widgets = {}
|
||||
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
|
||||
show_recommended = len(self.installed_models) == 0
|
||||
@ -286,7 +300,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
value=[
|
||||
model_list.index(x)
|
||||
for x in model_list
|
||||
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||
if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed
|
||||
],
|
||||
max_height=len(model_list) // columns + 1,
|
||||
relx=4,
|
||||
@ -310,7 +324,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
download_ids=self.add_widget_intelligent(
|
||||
TextBox,
|
||||
name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
|
||||
max_height=6,
|
||||
max_height=4,
|
||||
scroll_exit=True,
|
||||
editable=True,
|
||||
)
|
||||
@ -335,13 +349,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
return widgets
|
||||
|
||||
def resize(self) -> None:
|
||||
def resize(self):
|
||||
super().resize()
|
||||
if s := self.starter_pipelines.get("models_selected"):
|
||||
if model_list := self.starter_pipelines.get("models"):
|
||||
s.values = [self.model_labels[x] for x in model_list]
|
||||
keys = [x for x in self.all_models.keys() if x in self.starter_models]
|
||||
s.values = [self.model_labels[x] for x in keys]
|
||||
|
||||
def _toggle_tables(self, value: List[int]) -> None:
|
||||
def _toggle_tables(self, value=None):
|
||||
selected_tab = value[0]
|
||||
widgets = [
|
||||
self.starter_pipelines,
|
||||
@ -371,18 +385,17 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
self.display()
|
||||
|
||||
def _get_model_labels(self) -> dict[str, str]:
|
||||
"""Return a list of trimmed labels for all models."""
|
||||
window_width, window_height = get_terminal_size()
|
||||
checkbox_width = 4
|
||||
spacing_width = 2
|
||||
result = {}
|
||||
|
||||
models = self.all_models
|
||||
label_width = max([len(models[x].name or "") for x in self.starter_models])
|
||||
label_width = max([len(models[x].name) for x in models])
|
||||
description_width = window_width - label_width - checkbox_width - spacing_width
|
||||
|
||||
for key in self.all_models:
|
||||
description = models[key].description
|
||||
result = {}
|
||||
for x in models.keys():
|
||||
description = models[x].description
|
||||
description = (
|
||||
description[0 : description_width - 3] + "..."
|
||||
if description and len(description) > description_width
|
||||
@ -390,8 +403,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
if description
|
||||
else ""
|
||||
)
|
||||
result[key] = f"%-{label_width}s %s" % (models[key].name, description)
|
||||
|
||||
result[x] = f"%-{label_width}s %s" % (models[x].name, description)
|
||||
return result
|
||||
|
||||
def _get_columns(self) -> int:
|
||||
@ -401,40 +413,50 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
def confirm_deletions(self, selections: InstallSelections) -> bool:
|
||||
remove_models = selections.remove_models
|
||||
if remove_models:
|
||||
model_names = [self.all_models[x].name or "" for x in remove_models]
|
||||
mods = "\n".join(model_names)
|
||||
is_ok = npyscreen.notify_ok_cancel(
|
||||
if len(remove_models) > 0:
|
||||
mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models])
|
||||
return npyscreen.notify_ok_cancel(
|
||||
f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}"
|
||||
)
|
||||
assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations
|
||||
return is_ok
|
||||
else:
|
||||
return True
|
||||
|
||||
@property
|
||||
def all_models(self) -> Dict[str, UnifiedModelInfo]:
|
||||
# npyscreen doesn't having typing hints
|
||||
return self.parentApp.install_helper.all_models # type: ignore
|
||||
def on_execute(self):
|
||||
self.marshall_arguments()
|
||||
app = self.parentApp
|
||||
if not self.confirm_deletions(app.install_selections):
|
||||
return
|
||||
|
||||
@property
|
||||
def starter_models(self) -> List[str]:
|
||||
return self.parentApp.install_helper._starter_models # type: ignore
|
||||
self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True)
|
||||
self.ok_button.hidden = True
|
||||
self.display()
|
||||
|
||||
@property
|
||||
def installed_models(self) -> List[str]:
|
||||
return self.parentApp.install_helper._installed_models # type: ignore
|
||||
# TO DO: Spawn a worker thread, not a subprocess
|
||||
parent_conn, child_conn = Pipe()
|
||||
p = Process(
|
||||
target=process_and_execute,
|
||||
kwargs={
|
||||
"opt": app.program_opts,
|
||||
"selections": app.install_selections,
|
||||
"conn_out": child_conn,
|
||||
},
|
||||
)
|
||||
p.start()
|
||||
child_conn.close()
|
||||
self.subprocess_connection = parent_conn
|
||||
self.subprocess = p
|
||||
app.install_selections = InstallSelections()
|
||||
|
||||
def on_back(self) -> None:
|
||||
def on_back(self):
|
||||
self.parentApp.switchFormPrevious()
|
||||
self.editing = False
|
||||
|
||||
def on_cancel(self) -> None:
|
||||
def on_cancel(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
self.parentApp.user_cancelled = True
|
||||
self.editing = False
|
||||
|
||||
def on_done(self) -> None:
|
||||
def on_done(self):
|
||||
self.marshall_arguments()
|
||||
if not self.confirm_deletions(self.parentApp.install_selections):
|
||||
return
|
||||
@ -442,7 +464,77 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
self.parentApp.user_cancelled = False
|
||||
self.editing = False
|
||||
|
||||
def marshall_arguments(self) -> None:
|
||||
########## This routine monitors the child process that is performing model installation and removal #####
|
||||
def while_waiting(self):
|
||||
"""Called during idle periods. Main task is to update the Log Messages box with messages
|
||||
from the child process that does the actual installation/removal"""
|
||||
c = self.subprocess_connection
|
||||
if not c:
|
||||
return
|
||||
|
||||
monitor_widget = self.monitor.entry_widget
|
||||
while c.poll():
|
||||
try:
|
||||
data = c.recv_bytes().decode("utf-8")
|
||||
data.strip("\n")
|
||||
|
||||
# processing child is requesting user input to select the
|
||||
# right configuration file
|
||||
if data.startswith("*need v2 config"):
|
||||
_, model_path, *_ = data.split(":", 2)
|
||||
self._return_v2_config(model_path)
|
||||
|
||||
# processing child is done
|
||||
elif data == "*done*":
|
||||
self._close_subprocess_and_regenerate_form()
|
||||
break
|
||||
|
||||
# update the log message box
|
||||
else:
|
||||
data = make_printable(data)
|
||||
data = data.replace("[A", "")
|
||||
monitor_widget.buffer(
|
||||
textwrap.wrap(
|
||||
data,
|
||||
width=monitor_widget.width,
|
||||
subsequent_indent=" ",
|
||||
),
|
||||
scroll_end=True,
|
||||
)
|
||||
self.display()
|
||||
except (EOFError, OSError):
|
||||
self.subprocess_connection = None
|
||||
|
||||
def _return_v2_config(self, model_path: str):
|
||||
c = self.subprocess_connection
|
||||
model_name = Path(model_path).name
|
||||
message = select_stable_diffusion_config_file(model_name=model_name)
|
||||
c.send_bytes(message.encode("utf-8"))
|
||||
|
||||
def _close_subprocess_and_regenerate_form(self):
|
||||
app = self.parentApp
|
||||
self.subprocess_connection.close()
|
||||
self.subprocess_connection = None
|
||||
self.monitor.entry_widget.buffer(["** Action Complete **"])
|
||||
self.display()
|
||||
|
||||
# rebuild the form, saving and restoring some of the fields that need to be preserved.
|
||||
saved_messages = self.monitor.entry_widget.values
|
||||
|
||||
app.main_form = app.addForm(
|
||||
"MAIN",
|
||||
addModelsForm,
|
||||
name="Install Stable Diffusion Models",
|
||||
multipage=self.multipage,
|
||||
)
|
||||
app.switchForm("MAIN")
|
||||
|
||||
app.main_form.monitor.entry_widget.values = saved_messages
|
||||
app.main_form.monitor.entry_widget.buffer([""], scroll_end=True)
|
||||
# app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
|
||||
# app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan
|
||||
|
||||
def marshall_arguments(self):
|
||||
"""
|
||||
Assemble arguments and store as attributes of the application:
|
||||
.starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
|
||||
@ -472,24 +564,46 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
||||
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
|
||||
selections.remove_models.extend(models_to_remove)
|
||||
selections.install_models.extend([all_models[x] for x in models_to_install])
|
||||
selections.install_models.extend(
|
||||
all_models[x].path or all_models[x].repo_id
|
||||
for x in models_to_install
|
||||
if all_models[x].path or all_models[x].repo_id
|
||||
)
|
||||
|
||||
# models located in the 'download_ids" section
|
||||
for section in ui_sections:
|
||||
if downloads := section.get("download_ids"):
|
||||
models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
|
||||
selections.install_models.extend(models)
|
||||
selections.install_models.extend(downloads.value.split())
|
||||
|
||||
# NOT NEEDED - DONE IN BACKEND NOW
|
||||
# # special case for the ipadapter_models. If any of the adapters are
|
||||
# # chosen, then we add the corresponding encoder(s) to the install list.
|
||||
# section = self.ipadapter_models
|
||||
# if section.get("models_selected"):
|
||||
# selected_adapters = [
|
||||
# self.all_models[section["models"][x]].name for x in section.get("models_selected").value
|
||||
# ]
|
||||
# encoders = []
|
||||
# if any(["sdxl" in x for x in selected_adapters]):
|
||||
# encoders.append("ip_adapter_sdxl_image_encoder")
|
||||
# if any(["sd15" in x for x in selected_adapters]):
|
||||
# encoders.append("ip_adapter_sd_image_encoder")
|
||||
# for encoder in encoders:
|
||||
# key = f"any/clip_vision/{encoder}"
|
||||
# repo_id = f"InvokeAI/{encoder}"
|
||||
# if key not in self.all_models:
|
||||
# selections.install_models.append(repo_id)
|
||||
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore
|
||||
def __init__(self, opt: Namespace, install_helper: InstallHelper):
|
||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.program_opts = opt
|
||||
self.user_cancelled = False
|
||||
# self.autoload_pending = True
|
||||
self.install_selections = InstallSelections()
|
||||
self.install_helper = install_helper
|
||||
|
||||
def onStart(self) -> None:
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main_form = self.addForm(
|
||||
"MAIN",
|
||||
@ -499,62 +613,138 @@ class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def list_models(installer: ModelInstallServiceBase, model_type: ModelType):
|
||||
"""Print out all models of type model_type."""
|
||||
models = installer.record_store.search_by_attr(model_type=model_type)
|
||||
print(f"Installed models of type `{model_type}`:")
|
||||
for model in models:
|
||||
path = (config.models_path / model.path).resolve()
|
||||
print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}")
|
||||
class StderrToMessage:
|
||||
def __init__(self, connection: Connection):
|
||||
self.connection = connection
|
||||
|
||||
def write(self, data: str):
|
||||
self.connection.send_bytes(data.encode("utf-8"))
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def select_and_download_models(opt: Namespace) -> None:
|
||||
"""Prompt user for install/delete selections and execute."""
|
||||
def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType:
|
||||
if tui_conn:
|
||||
logger.debug("Waiting for user response...")
|
||||
return _ask_user_for_pt_tui(model_path, tui_conn)
|
||||
else:
|
||||
return _ask_user_for_pt_cmdline(model_path)
|
||||
|
||||
|
||||
def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]:
|
||||
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
|
||||
print(
|
||||
f"""
|
||||
Please select the scheduler prediction type of the checkpoint named {model_path.name}:
|
||||
[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images
|
||||
[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models
|
||||
[3] Accept the best guess; you can fix it in the Web UI later
|
||||
"""
|
||||
)
|
||||
choice = None
|
||||
ok = False
|
||||
while not ok:
|
||||
try:
|
||||
choice = input("select [3]> ").strip()
|
||||
if not choice:
|
||||
return None
|
||||
choice = choices[int(choice) - 1]
|
||||
ok = True
|
||||
except (ValueError, IndexError):
|
||||
print(f"{choice} is not a valid choice")
|
||||
except EOFError:
|
||||
return
|
||||
return choice
|
||||
|
||||
|
||||
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType:
|
||||
tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8"))
|
||||
# note that we don't do any status checking here
|
||||
response = tui_conn.recv_bytes().decode("utf-8")
|
||||
if response is None:
|
||||
return None
|
||||
elif response == "epsilon":
|
||||
return SchedulerPredictionType.epsilon
|
||||
elif response == "v":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif response == "guess":
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def process_and_execute(
|
||||
opt: Namespace,
|
||||
selections: InstallSelections,
|
||||
conn_out: Connection = None,
|
||||
):
|
||||
# need to reinitialize config in subprocess
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
args = ["--root", opt.root] if opt.root else []
|
||||
config.parse_args(args)
|
||||
|
||||
# set up so that stderr is sent to conn_out
|
||||
if conn_out:
|
||||
translator = StderrToMessage(conn_out)
|
||||
sys.stderr = translator
|
||||
sys.stdout = translator
|
||||
logger = InvokeAILogger.get_logger()
|
||||
logger.handlers.clear()
|
||||
logger.addHandler(logging.StreamHandler(translator))
|
||||
|
||||
installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out))
|
||||
installer.install(selections)
|
||||
|
||||
if conn_out:
|
||||
conn_out.send_bytes("*done*".encode("utf-8"))
|
||||
conn_out.close()
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def select_and_download_models(opt: Namespace):
|
||||
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
# unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal
|
||||
config.precision = precision # type: ignore
|
||||
install_helper = InstallHelper(config, logger)
|
||||
installer = install_helper.installer
|
||||
|
||||
config.precision = precision
|
||||
installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
|
||||
if opt.list_models:
|
||||
list_models(installer, opt.list_models)
|
||||
|
||||
installer.list_models(opt.list_models)
|
||||
elif opt.add or opt.delete:
|
||||
selections = InstallSelections(
|
||||
install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
|
||||
)
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or [])
|
||||
installer.install(selections)
|
||||
elif opt.default_only:
|
||||
default_model = install_helper.default_model()
|
||||
assert default_model is not None
|
||||
selections = InstallSelections(install_models=[default_model])
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
selections = InstallSelections(install_models=installer.default_model())
|
||||
installer.install(selections)
|
||||
elif opt.yes_to_all:
|
||||
selections = InstallSelections(install_models=install_helper.recommended_models())
|
||||
install_helper.add_or_delete(selections)
|
||||
selections = InstallSelections(install_models=installer.recommended_models())
|
||||
installer.install(selections)
|
||||
|
||||
# this is where the TUI is called
|
||||
else:
|
||||
# needed to support the probe() method running under a subprocess
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
|
||||
raise WindowTooSmallException(
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
installApp = AddModelApplication(opt, install_helper)
|
||||
installApp = AddModelApplication(opt)
|
||||
try:
|
||||
installApp.run()
|
||||
except KeyboardInterrupt:
|
||||
print("Aborted...")
|
||||
sys.exit(-1)
|
||||
|
||||
install_helper.add_or_delete(installApp.install_selections)
|
||||
except KeyboardInterrupt as e:
|
||||
if hasattr(installApp, "main_form"):
|
||||
if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive():
|
||||
logger.info("Terminating subprocesses")
|
||||
installApp.main_form.subprocess.terminate()
|
||||
installApp.main_form.subprocess = None
|
||||
raise e
|
||||
process_and_execute(opt, installApp.install_selections)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main() -> None:
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
parser.add_argument(
|
||||
"--add",
|
||||
@ -564,7 +754,7 @@ def main() -> None:
|
||||
parser.add_argument(
|
||||
"--delete",
|
||||
nargs="*",
|
||||
help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
|
||||
help="List of names of models to idelete",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full-precision",
|
||||
@ -591,6 +781,14 @@ def main() -> None:
|
||||
choices=[x.value for x in ModelType],
|
||||
help="list installed models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
"-c",
|
||||
dest="config_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to configuration file to create",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
dest="root",
|
@ -267,6 +267,17 @@ class SingleSelectWithChanged(npyscreen.SelectOne):
|
||||
self.on_changed(self.value)
|
||||
|
||||
|
||||
class CheckboxWithChanged(npyscreen.Checkbox):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.on_changed = None
|
||||
|
||||
def whenToggled(self):
|
||||
super().whenToggled()
|
||||
if self.on_changed:
|
||||
self.on_changed(self.value)
|
||||
|
||||
|
||||
class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
|
||||
"""Row of radio buttons. Spacebar to select."""
|
||||
|
||||
|
Reference in New Issue
Block a user