mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
ask user for v2 variant when model manager can't infer it
This commit is contained in:
parent
31e97ead2a
commit
1a7fb601dc
@ -9,7 +9,7 @@ import warnings
|
||||
from dataclasses import dataclass,field
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Callable
|
||||
|
||||
import requests
|
||||
from diffusers import AutoencoderKL
|
||||
@ -95,6 +95,7 @@ def install_requested_models(
|
||||
precision: str = "float16",
|
||||
purge_deleted: bool = False,
|
||||
config_file_path: Path = None,
|
||||
model_config_file_callback: Callable[[Path],Path] = None
|
||||
):
|
||||
"""
|
||||
Entry point for installing/deleting starter models, or installing external models.
|
||||
@ -118,19 +119,19 @@ def install_requested_models(
|
||||
|
||||
# TODO: Replace next three paragraphs with calls into new model manager
|
||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||
logger.info("DELETING UNCHECKED STARTER MODELS")
|
||||
logger.info("Processing requested deletions")
|
||||
for model in diffusers.remove_models:
|
||||
logger.info(f"{model}...")
|
||||
model_manager.del_model(model, delete_files=purge_deleted)
|
||||
model_manager.commit(config_file_path)
|
||||
|
||||
if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||
logger.info("INSTALLING SELECTED STARTER MODELS")
|
||||
logger.info("Installing requested models")
|
||||
downloaded_paths = download_weight_datasets(
|
||||
models=diffusers.install_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
) # FIX: for historical reasons, we don't use model manager here
|
||||
)
|
||||
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
|
||||
if len(successful) > 0:
|
||||
update_config_file(successful, config_file_path)
|
||||
@ -153,6 +154,7 @@ def install_requested_models(
|
||||
model_manager.heuristic_import(
|
||||
path_url_or_repo,
|
||||
commit_to_conf=config_file_path,
|
||||
config_file_callback = model_config_file_callback,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
|
@ -874,14 +874,12 @@ class ModelManager(object):
|
||||
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
|
||||
elif model_type == SDLegacyType.V2:
|
||||
self.logger.warning(
|
||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined."
|
||||
)
|
||||
return
|
||||
else:
|
||||
self.logger.warning(
|
||||
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model."
|
||||
)
|
||||
return
|
||||
|
||||
if not model_config_file and config_file_callback:
|
||||
model_config_file = config_file_callback(model_path)
|
||||
|
@ -48,7 +48,9 @@ from .widgets import (
|
||||
OffsetButtonPress,
|
||||
TextBox,
|
||||
BufferBox,
|
||||
FileBox,
|
||||
set_min_terminal_size,
|
||||
select_stable_diffusion_config_file,
|
||||
)
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
@ -363,7 +365,9 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
self.nextrely += 2
|
||||
widgets.update(
|
||||
autoload_directory = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
# npyscreen.TitleFilename,
|
||||
FileBox,
|
||||
max_height=3,
|
||||
name=label,
|
||||
select_dir=True,
|
||||
must_exist=True,
|
||||
@ -485,39 +489,67 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
self.parentApp.user_cancelled = True
|
||||
self.editing = False
|
||||
|
||||
########## This routine monitors the child process that is performing model installation and removal #####
|
||||
def while_waiting(self):
|
||||
app = self.parentApp
|
||||
'''Called during idle periods. Main task is to update the Log Messages box with messages
|
||||
from the child process that does the actual installation/removal'''
|
||||
c = self.subprocess_connection
|
||||
if not c:
|
||||
return
|
||||
|
||||
monitor_widget = self.monitor.entry_widget
|
||||
if c := self.subprocess_connection:
|
||||
while c.poll():
|
||||
try:
|
||||
data = c.recv_bytes().decode('utf-8')
|
||||
data.strip('\n')
|
||||
if data=='*done*':
|
||||
self.subprocess_connection = None
|
||||
monitor_widget.buffer(['** Action Complete **'])
|
||||
self.display()
|
||||
# rebuild the form, saving log messages
|
||||
saved_messages = monitor_widget.values
|
||||
app.main_form = app.addForm(
|
||||
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
|
||||
)
|
||||
app.switchForm('MAIN')
|
||||
app.main_form.monitor.entry_widget.values = saved_messages
|
||||
app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
|
||||
break
|
||||
else:
|
||||
monitor_widget.buffer(
|
||||
textwrap.wrap(data,
|
||||
width=monitor_widget.width,
|
||||
subsequent_indent=' ',
|
||||
),
|
||||
scroll_end=True
|
||||
)
|
||||
self.display()
|
||||
except (EOFError,OSError):
|
||||
self.subprocess_connection = None
|
||||
while c.poll():
|
||||
try:
|
||||
data = c.recv_bytes().decode('utf-8')
|
||||
data.strip('\n')
|
||||
|
||||
# processing child is requesting user input to select the
|
||||
# right configuration file
|
||||
if data.startswith('*need v2 config'):
|
||||
_,model_path,*_ = data.split(":",2)
|
||||
self._return_v2_config(model_path)
|
||||
|
||||
# processing child is done
|
||||
elif data=='*done*':
|
||||
self._close_subprocess_and_regenerate_form()
|
||||
break
|
||||
|
||||
# update the log message box
|
||||
else:
|
||||
monitor_widget.buffer(
|
||||
textwrap.wrap(data,
|
||||
width=monitor_widget.width,
|
||||
subsequent_indent=' ',
|
||||
),
|
||||
scroll_end=True
|
||||
)
|
||||
self.display()
|
||||
except (EOFError,OSError):
|
||||
self.subprocess_connection = None
|
||||
|
||||
def _return_v2_config(self,model_path: str):
|
||||
c = self.subprocess_connection
|
||||
model_name = Path(model_path).name
|
||||
message = select_stable_diffusion_config_file(model_name=model_name)
|
||||
c.send_bytes(message.encode('utf-8'))
|
||||
|
||||
def _close_subprocess_and_regenerate_form(self):
|
||||
app = self.parentApp
|
||||
self.subprocess_connection.close()
|
||||
self.subprocess_connection = None
|
||||
self.monitor.entry_widget.buffer(['** Action Complete **'])
|
||||
self.display()
|
||||
# rebuild the form, saving log messages
|
||||
saved_messages = self.monitor.entry_widget.values
|
||||
app.main_form = app.addForm(
|
||||
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
|
||||
)
|
||||
app.switchForm('MAIN')
|
||||
app.main_form.monitor.entry_widget.values = saved_messages
|
||||
app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
|
||||
|
||||
###############################################################
|
||||
|
||||
def list_additional_diffusers_models(self,
|
||||
manager: ModelManager,
|
||||
starters:dict
|
||||
@ -628,7 +660,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main_form = self.addForm(
|
||||
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
|
||||
"MAIN", addModelsForm, name="Install Stable Diffusion Models", cycle_widgets=True,
|
||||
)
|
||||
|
||||
class StderrToMessage():
|
||||
@ -640,6 +672,57 @@ class StderrToMessage():
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
# --------------------------------------------------------
|
||||
def ask_user_for_config_file(model_path: Path,
|
||||
tui_conn: Connection=None
|
||||
)->Path:
|
||||
logger.debug(f'Waiting for user action in dialog box (above).')
|
||||
if tui_conn:
|
||||
return _ask_user_for_cf_tui(model_path, tui_conn)
|
||||
else:
|
||||
return _ask_user_for_cf_cmdline(model_path)
|
||||
|
||||
def _ask_user_for_cf_cmdline(model_path):
|
||||
choices = [
|
||||
config.model_conf_path / 'stable-diffusion' / x
|
||||
for x in ['v2-inference.yaml','v2-inference-v.yaml']
|
||||
]
|
||||
print(
|
||||
f"""
|
||||
Please select the type of the V2 checkpoint named {model_path.name}:
|
||||
[1] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file)
|
||||
[2] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file)
|
||||
"""
|
||||
)
|
||||
choice = None
|
||||
ok = False
|
||||
while not ok:
|
||||
try:
|
||||
choice = input('select> ').strip()
|
||||
choice = choices[int(choice)-1]
|
||||
ok = True
|
||||
except (ValueError, IndexError):
|
||||
print(f'{choice} is not a valid choice')
|
||||
except EOFError:
|
||||
return
|
||||
return choice
|
||||
|
||||
def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path:
|
||||
try:
|
||||
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
|
||||
# note that we don't do any status checking here
|
||||
response = tui_conn.recv_bytes().decode('utf-8')
|
||||
if response is None:
|
||||
return None
|
||||
elif response == 'epsilon':
|
||||
return config.legacy_conf_path / 'v2-inference.yaml'
|
||||
elif response == 'v':
|
||||
return config.legacy_conf_path / 'v2-inference-v.yaml'
|
||||
else:
|
||||
return Path(response)
|
||||
except:
|
||||
return None
|
||||
|
||||
# --------------------------------------------------------
|
||||
def process_and_execute(opt: Namespace,
|
||||
@ -651,13 +734,16 @@ def process_and_execute(opt: Namespace,
|
||||
translator = StderrToMessage(conn_out)
|
||||
sys.stderr = translator
|
||||
sys.stdout = translator
|
||||
InvokeAILogger.getLogger().handlers[0]=logging.StreamHandler(translator)
|
||||
logger = InvokeAILogger.getLogger()
|
||||
logger.handlers.clear()
|
||||
logger.addHandler(logging.StreamHandler(translator))
|
||||
|
||||
models_to_install = selections.install_models
|
||||
models_to_remove = selections.remove_models
|
||||
directory_to_scan = selections.scan_directory
|
||||
scan_at_startup = selections.autoscan_on_startup
|
||||
potential_models_to_install = selections.import_model_paths
|
||||
|
||||
install_requested_models(
|
||||
diffusers = ModelInstallList(models_to_install, models_to_remove),
|
||||
controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
|
||||
@ -671,6 +757,7 @@ def process_and_execute(opt: Namespace,
|
||||
else choose_precision(torch.device(choose_torch_device())),
|
||||
purge_deleted=selections.purge_deleted_models,
|
||||
config_file_path=Path(opt.config_file) if opt.config_file else None,
|
||||
model_config_file_callback = lambda x: ask_user_for_config_file(x,conn_out)
|
||||
)
|
||||
|
||||
if conn_out:
|
||||
|
@ -8,9 +8,12 @@ import platform
|
||||
import pyperclip
|
||||
import struct
|
||||
import sys
|
||||
import npyscreen
|
||||
import textwrap
|
||||
import npyscreen.wgmultiline as wgmultiline
|
||||
from npyscreen import fmPopup
|
||||
from shutil import get_terminal_size
|
||||
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
|
||||
import npyscreen
|
||||
|
||||
# -------------------------------------
|
||||
def set_terminal_size(columns: int, lines: int):
|
||||
@ -151,8 +154,13 @@ class MultiSelectColumns( SelectColumnBase, npyscreen.MultiSelect):
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
class SingleSelectWithChanged(npyscreen.SelectOne):
|
||||
def h_select(self,ch):
|
||||
super().h_select(ch)
|
||||
if self.on_changed:
|
||||
self.on_changed(self.value)
|
||||
|
||||
class SingleSelectColumns(SelectColumnBase, npyscreen.SelectOne):
|
||||
class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged):
|
||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
@ -160,11 +168,6 @@ class SingleSelectColumns(SelectColumnBase, npyscreen.SelectOne):
|
||||
self.on_changed = None
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
def h_select(self,ch):
|
||||
super().h_select(ch)
|
||||
if self.on_changed:
|
||||
self.on_changed(self.value)
|
||||
|
||||
def when_value_edited(self):
|
||||
self.h_select(self.cursor_line)
|
||||
|
||||
@ -271,4 +274,90 @@ class TextBox(npyscreen.MultiLineEdit):
|
||||
class BufferBox(npyscreen.BoxTitle):
|
||||
_contained_widget = npyscreen.BufferPager
|
||||
|
||||
class ConfirmCancelPopup(fmPopup.ActionPopup):
|
||||
DEFAULT_COLUMNS = 100
|
||||
def on_ok(self):
|
||||
self.value = True
|
||||
def on_cancel(self):
|
||||
self.value = False
|
||||
|
||||
class FileBox(npyscreen.BoxTitle):
|
||||
_contained_widget = npyscreen.Filename
|
||||
|
||||
def _wrap_message_lines(message, line_length):
|
||||
lines = []
|
||||
for line in message.split('\n'):
|
||||
lines.extend(textwrap.wrap(line.rstrip(), line_length))
|
||||
return lines
|
||||
|
||||
def _prepare_message(message):
|
||||
if isinstance(message, list) or isinstance(message, tuple):
|
||||
return "\n".join([ s.rstrip() for s in message])
|
||||
#return "\n".join(message)
|
||||
else:
|
||||
return message
|
||||
|
||||
def select_stable_diffusion_config_file(
|
||||
form_color: str='DANGER',
|
||||
wrap:bool =True,
|
||||
model_name:str='Unknown',
|
||||
):
|
||||
message = "Please select the correct base model for the V2 checkpoint named {model_name}. Press <CANCEL> to skip installation."
|
||||
title = "CONFIG FILE SELECTION"
|
||||
options=[
|
||||
"An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)",
|
||||
"An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)",
|
||||
"Enter config file path manually",
|
||||
]
|
||||
|
||||
F = ConfirmCancelPopup(
|
||||
name=title,
|
||||
color=form_color,
|
||||
cycle_widgets=True,
|
||||
lines=16,
|
||||
)
|
||||
F.preserve_selected_widget = True
|
||||
|
||||
mlw = F.add(
|
||||
wgmultiline.Pager,
|
||||
max_height=4,
|
||||
editable=False,
|
||||
)
|
||||
mlw_width = mlw.width-1
|
||||
if wrap:
|
||||
message = _wrap_message_lines(message, mlw_width)
|
||||
mlw.values = message
|
||||
|
||||
choice = F.add(
|
||||
SingleSelectWithChanged,
|
||||
values = options,
|
||||
value = [0],
|
||||
max_height = len(options)+1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
file = F.add(
|
||||
FileBox,
|
||||
name='Path to config file',
|
||||
max_height=3,
|
||||
hidden=True,
|
||||
must_exist=True,
|
||||
scroll_exit=True
|
||||
)
|
||||
|
||||
def toggle_visible(value):
|
||||
value = value[0]
|
||||
if value==2:
|
||||
file.hidden=False
|
||||
else:
|
||||
file.hidden=True
|
||||
F.display()
|
||||
|
||||
choice.on_changed = toggle_visible
|
||||
|
||||
F.editw = 1
|
||||
F.edit()
|
||||
if not F.value:
|
||||
return None
|
||||
assert choice.value[0] in range(0,3),'invalid choice'
|
||||
choices = ['epsilon','v',file.value]
|
||||
return choices[choice.value[0]]
|
||||
|
@ -13,6 +13,7 @@ def main():
|
||||
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
if '--web' in sys.argv:
|
||||
sys.argv.remove('--web')
|
||||
from invokeai.app.api_app import invoke_api
|
||||
invoke_api()
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user