ask user for v2 variant when model manager can't infer it

This commit is contained in:
Lincoln Stein 2023-06-04 11:27:44 -04:00
parent 31e97ead2a
commit 1a7fb601dc
5 changed files with 225 additions and 48 deletions

View File

@ -9,7 +9,7 @@ import warnings
from dataclasses import dataclass,field
from pathlib import Path
from tempfile import TemporaryFile
from typing import List, Dict
from typing import List, Dict, Callable
import requests
from diffusers import AutoencoderKL
@ -95,6 +95,7 @@ def install_requested_models(
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
model_config_file_callback: Callable[[Path],Path] = None
):
"""
Entry point for installing/deleting starter models, or installing external models.
@ -118,19 +119,19 @@ def install_requested_models(
# TODO: Replace next three paragraphs with calls into new model manager
if diffusers.remove_models and len(diffusers.remove_models) > 0:
logger.info("DELETING UNCHECKED STARTER MODELS")
logger.info("Processing requested deletions")
for model in diffusers.remove_models:
logger.info(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path)
if diffusers.install_models and len(diffusers.install_models) > 0:
logger.info("INSTALLING SELECTED STARTER MODELS")
logger.info("Installing requested models")
downloaded_paths = download_weight_datasets(
models=diffusers.install_models,
access_token=None,
precision=precision,
) # FIX: for historical reasons, we don't use model manager here
)
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
if len(successful) > 0:
update_config_file(successful, config_file_path)
@ -153,6 +154,7 @@ def install_requested_models(
model_manager.heuristic_import(
path_url_or_repo,
commit_to_conf=config_file_path,
config_file_callback = model_config_file_callback,
)
except KeyboardInterrupt:
sys.exit(-1)

View File

@ -874,14 +874,12 @@ class ModelManager(object):
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
elif model_type == SDLegacyType.V2:
self.logger.warning(
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined."
)
return
else:
self.logger.warning(
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model."
)
return
if not model_config_file and config_file_callback:
model_config_file = config_file_callback(model_path)

View File

@ -48,7 +48,9 @@ from .widgets import (
OffsetButtonPress,
TextBox,
BufferBox,
FileBox,
set_min_terminal_size,
select_stable_diffusion_config_file,
)
from invokeai.app.services.config import get_invokeai_config
@ -363,7 +365,9 @@ class addModelsForm(npyscreen.FormMultiPage):
self.nextrely += 2
widgets.update(
autoload_directory = self.add_widget_intelligent(
npyscreen.TitleFilename,
# npyscreen.TitleFilename,
FileBox,
max_height=3,
name=label,
select_dir=True,
must_exist=True,
@ -485,39 +489,67 @@ class addModelsForm(npyscreen.FormMultiPage):
self.parentApp.user_cancelled = True
self.editing = False
########## This routine monitors the child process that is performing model installation and removal #####
def while_waiting(self):
app = self.parentApp
'''Called during idle periods. Main task is to update the Log Messages box with messages
from the child process that does the actual installation/removal'''
c = self.subprocess_connection
if not c:
return
monitor_widget = self.monitor.entry_widget
if c := self.subprocess_connection:
while c.poll():
try:
data = c.recv_bytes().decode('utf-8')
data.strip('\n')
if data=='*done*':
self.subprocess_connection = None
monitor_widget.buffer(['** Action Complete **'])
self.display()
# rebuild the form, saving log messages
saved_messages = monitor_widget.values
app.main_form = app.addForm(
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
)
app.switchForm('MAIN')
app.main_form.monitor.entry_widget.values = saved_messages
app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
break
else:
monitor_widget.buffer(
textwrap.wrap(data,
width=monitor_widget.width,
subsequent_indent=' ',
),
scroll_end=True
)
self.display()
except (EOFError,OSError):
self.subprocess_connection = None
while c.poll():
try:
data = c.recv_bytes().decode('utf-8')
data.strip('\n')
# processing child is requesting user input to select the
# right configuration file
if data.startswith('*need v2 config'):
_,model_path,*_ = data.split(":",2)
self._return_v2_config(model_path)
# processing child is done
elif data=='*done*':
self._close_subprocess_and_regenerate_form()
break
# update the log message box
else:
monitor_widget.buffer(
textwrap.wrap(data,
width=monitor_widget.width,
subsequent_indent=' ',
),
scroll_end=True
)
self.display()
except (EOFError,OSError):
self.subprocess_connection = None
def _return_v2_config(self,model_path: str):
c = self.subprocess_connection
model_name = Path(model_path).name
message = select_stable_diffusion_config_file(model_name=model_name)
c.send_bytes(message.encode('utf-8'))
def _close_subprocess_and_regenerate_form(self):
app = self.parentApp
self.subprocess_connection.close()
self.subprocess_connection = None
self.monitor.entry_widget.buffer(['** Action Complete **'])
self.display()
# rebuild the form, saving log messages
saved_messages = self.monitor.entry_widget.values
app.main_form = app.addForm(
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
)
app.switchForm('MAIN')
app.main_form.monitor.entry_widget.values = saved_messages
app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
###############################################################
def list_additional_diffusers_models(self,
manager: ModelManager,
starters:dict
@ -628,7 +660,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
self.main_form = self.addForm(
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
"MAIN", addModelsForm, name="Install Stable Diffusion Models", cycle_widgets=True,
)
class StderrToMessage():
@ -640,6 +672,57 @@ class StderrToMessage():
def flush(self):
pass
# --------------------------------------------------------
def ask_user_for_config_file(model_path: Path,
tui_conn: Connection=None
)->Path:
logger.debug(f'Waiting for user action in dialog box (above).')
if tui_conn:
return _ask_user_for_cf_tui(model_path, tui_conn)
else:
return _ask_user_for_cf_cmdline(model_path)
def _ask_user_for_cf_cmdline(model_path):
choices = [
config.model_conf_path / 'stable-diffusion' / x
for x in ['v2-inference.yaml','v2-inference-v.yaml']
]
print(
f"""
Please select the type of the V2 checkpoint named {model_path.name}:
[1] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file)
[2] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file)
"""
)
choice = None
ok = False
while not ok:
try:
choice = input('select> ').strip()
choice = choices[int(choice)-1]
ok = True
except (ValueError, IndexError):
print(f'{choice} is not a valid choice')
except EOFError:
return
return choice
def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path:
try:
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
# note that we don't do any status checking here
response = tui_conn.recv_bytes().decode('utf-8')
if response is None:
return None
elif response == 'epsilon':
return config.legacy_conf_path / 'v2-inference.yaml'
elif response == 'v':
return config.legacy_conf_path / 'v2-inference-v.yaml'
else:
return Path(response)
except:
return None
# --------------------------------------------------------
def process_and_execute(opt: Namespace,
@ -651,13 +734,16 @@ def process_and_execute(opt: Namespace,
translator = StderrToMessage(conn_out)
sys.stderr = translator
sys.stdout = translator
InvokeAILogger.getLogger().handlers[0]=logging.StreamHandler(translator)
logger = InvokeAILogger.getLogger()
logger.handlers.clear()
logger.addHandler(logging.StreamHandler(translator))
models_to_install = selections.install_models
models_to_remove = selections.remove_models
directory_to_scan = selections.scan_directory
scan_at_startup = selections.autoscan_on_startup
potential_models_to_install = selections.import_model_paths
install_requested_models(
diffusers = ModelInstallList(models_to_install, models_to_remove),
controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
@ -671,6 +757,7 @@ def process_and_execute(opt: Namespace,
else choose_precision(torch.device(choose_torch_device())),
purge_deleted=selections.purge_deleted_models,
config_file_path=Path(opt.config_file) if opt.config_file else None,
model_config_file_callback = lambda x: ask_user_for_config_file(x,conn_out)
)
if conn_out:

View File

@ -8,9 +8,12 @@ import platform
import pyperclip
import struct
import sys
import npyscreen
import textwrap
import npyscreen.wgmultiline as wgmultiline
from npyscreen import fmPopup
from shutil import get_terminal_size
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
import npyscreen
# -------------------------------------
def set_terminal_size(columns: int, lines: int):
@ -151,8 +154,13 @@ class MultiSelectColumns( SelectColumnBase, npyscreen.MultiSelect):
self.rows = math.ceil(self.value_cnt / self.columns)
super().__init__(screen, values=values, **keywords)
class SingleSelectWithChanged(npyscreen.SelectOne):
def h_select(self,ch):
super().h_select(ch)
if self.on_changed:
self.on_changed(self.value)
class SingleSelectColumns(SelectColumnBase, npyscreen.SelectOne):
class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
self.columns = columns
self.value_cnt = len(values)
@ -160,11 +168,6 @@ class SingleSelectColumns(SelectColumnBase, npyscreen.SelectOne):
self.on_changed = None
super().__init__(screen, values=values, **keywords)
def h_select(self,ch):
super().h_select(ch)
if self.on_changed:
self.on_changed(self.value)
def when_value_edited(self):
self.h_select(self.cursor_line)
@ -271,4 +274,90 @@ class TextBox(npyscreen.MultiLineEdit):
class BufferBox(npyscreen.BoxTitle):
_contained_widget = npyscreen.BufferPager
class ConfirmCancelPopup(fmPopup.ActionPopup):
DEFAULT_COLUMNS = 100
def on_ok(self):
self.value = True
def on_cancel(self):
self.value = False
class FileBox(npyscreen.BoxTitle):
_contained_widget = npyscreen.Filename
def _wrap_message_lines(message, line_length):
lines = []
for line in message.split('\n'):
lines.extend(textwrap.wrap(line.rstrip(), line_length))
return lines
def _prepare_message(message):
if isinstance(message, list) or isinstance(message, tuple):
return "\n".join([ s.rstrip() for s in message])
#return "\n".join(message)
else:
return message
def select_stable_diffusion_config_file(
form_color: str='DANGER',
wrap:bool =True,
model_name:str='Unknown',
):
message = "Please select the correct base model for the V2 checkpoint named {model_name}. Press <CANCEL> to skip installation."
title = "CONFIG FILE SELECTION"
options=[
"An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)",
"An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)",
"Enter config file path manually",
]
F = ConfirmCancelPopup(
name=title,
color=form_color,
cycle_widgets=True,
lines=16,
)
F.preserve_selected_widget = True
mlw = F.add(
wgmultiline.Pager,
max_height=4,
editable=False,
)
mlw_width = mlw.width-1
if wrap:
message = _wrap_message_lines(message, mlw_width)
mlw.values = message
choice = F.add(
SingleSelectWithChanged,
values = options,
value = [0],
max_height = len(options)+1,
scroll_exit=True,
)
file = F.add(
FileBox,
name='Path to config file',
max_height=3,
hidden=True,
must_exist=True,
scroll_exit=True
)
def toggle_visible(value):
value = value[0]
if value==2:
file.hidden=False
else:
file.hidden=True
F.display()
choice.on_changed = toggle_visible
F.editw = 1
F.edit()
if not F.value:
return None
assert choice.value[0] in range(0,3),'invalid choice'
choices = ['epsilon','v',file.value]
return choices[choice.value[0]]

View File

@ -13,6 +13,7 @@ def main():
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
if '--web' in sys.argv:
sys.argv.remove('--web')
from invokeai.app.api_app import invoke_api
invoke_api()
else: