mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
ask user for v2 variant when model manager can't infer it
This commit is contained in:
parent
31e97ead2a
commit
1a7fb601dc
@ -9,7 +9,7 @@ import warnings
|
|||||||
from dataclasses import dataclass,field
|
from dataclasses import dataclass,field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryFile
|
from tempfile import TemporaryFile
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Callable
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from diffusers import AutoencoderKL
|
from diffusers import AutoencoderKL
|
||||||
@ -95,6 +95,7 @@ def install_requested_models(
|
|||||||
precision: str = "float16",
|
precision: str = "float16",
|
||||||
purge_deleted: bool = False,
|
purge_deleted: bool = False,
|
||||||
config_file_path: Path = None,
|
config_file_path: Path = None,
|
||||||
|
model_config_file_callback: Callable[[Path],Path] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Entry point for installing/deleting starter models, or installing external models.
|
Entry point for installing/deleting starter models, or installing external models.
|
||||||
@ -118,19 +119,19 @@ def install_requested_models(
|
|||||||
|
|
||||||
# TODO: Replace next three paragraphs with calls into new model manager
|
# TODO: Replace next three paragraphs with calls into new model manager
|
||||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||||
logger.info("DELETING UNCHECKED STARTER MODELS")
|
logger.info("Processing requested deletions")
|
||||||
for model in diffusers.remove_models:
|
for model in diffusers.remove_models:
|
||||||
logger.info(f"{model}...")
|
logger.info(f"{model}...")
|
||||||
model_manager.del_model(model, delete_files=purge_deleted)
|
model_manager.del_model(model, delete_files=purge_deleted)
|
||||||
model_manager.commit(config_file_path)
|
model_manager.commit(config_file_path)
|
||||||
|
|
||||||
if diffusers.install_models and len(diffusers.install_models) > 0:
|
if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||||
logger.info("INSTALLING SELECTED STARTER MODELS")
|
logger.info("Installing requested models")
|
||||||
downloaded_paths = download_weight_datasets(
|
downloaded_paths = download_weight_datasets(
|
||||||
models=diffusers.install_models,
|
models=diffusers.install_models,
|
||||||
access_token=None,
|
access_token=None,
|
||||||
precision=precision,
|
precision=precision,
|
||||||
) # FIX: for historical reasons, we don't use model manager here
|
)
|
||||||
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
|
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
|
||||||
if len(successful) > 0:
|
if len(successful) > 0:
|
||||||
update_config_file(successful, config_file_path)
|
update_config_file(successful, config_file_path)
|
||||||
@ -153,6 +154,7 @@ def install_requested_models(
|
|||||||
model_manager.heuristic_import(
|
model_manager.heuristic_import(
|
||||||
path_url_or_repo,
|
path_url_or_repo,
|
||||||
commit_to_conf=config_file_path,
|
commit_to_conf=config_file_path,
|
||||||
|
config_file_callback = model_config_file_callback,
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
@ -874,14 +874,12 @@ class ModelManager(object):
|
|||||||
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
|
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
|
||||||
elif model_type == SDLegacyType.V2:
|
elif model_type == SDLegacyType.V2:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined."
|
||||||
)
|
)
|
||||||
return
|
|
||||||
else:
|
else:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model."
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
if not model_config_file and config_file_callback:
|
if not model_config_file and config_file_callback:
|
||||||
model_config_file = config_file_callback(model_path)
|
model_config_file = config_file_callback(model_path)
|
||||||
|
@ -48,7 +48,9 @@ from .widgets import (
|
|||||||
OffsetButtonPress,
|
OffsetButtonPress,
|
||||||
TextBox,
|
TextBox,
|
||||||
BufferBox,
|
BufferBox,
|
||||||
|
FileBox,
|
||||||
set_min_terminal_size,
|
set_min_terminal_size,
|
||||||
|
select_stable_diffusion_config_file,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
|
|
||||||
@ -363,7 +365,9 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
self.nextrely += 2
|
self.nextrely += 2
|
||||||
widgets.update(
|
widgets.update(
|
||||||
autoload_directory = self.add_widget_intelligent(
|
autoload_directory = self.add_widget_intelligent(
|
||||||
npyscreen.TitleFilename,
|
# npyscreen.TitleFilename,
|
||||||
|
FileBox,
|
||||||
|
max_height=3,
|
||||||
name=label,
|
name=label,
|
||||||
select_dir=True,
|
select_dir=True,
|
||||||
must_exist=True,
|
must_exist=True,
|
||||||
@ -485,39 +489,67 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
self.parentApp.user_cancelled = True
|
self.parentApp.user_cancelled = True
|
||||||
self.editing = False
|
self.editing = False
|
||||||
|
|
||||||
|
########## This routine monitors the child process that is performing model installation and removal #####
|
||||||
def while_waiting(self):
|
def while_waiting(self):
|
||||||
app = self.parentApp
|
'''Called during idle periods. Main task is to update the Log Messages box with messages
|
||||||
|
from the child process that does the actual installation/removal'''
|
||||||
|
c = self.subprocess_connection
|
||||||
|
if not c:
|
||||||
|
return
|
||||||
|
|
||||||
monitor_widget = self.monitor.entry_widget
|
monitor_widget = self.monitor.entry_widget
|
||||||
if c := self.subprocess_connection:
|
while c.poll():
|
||||||
while c.poll():
|
try:
|
||||||
try:
|
data = c.recv_bytes().decode('utf-8')
|
||||||
data = c.recv_bytes().decode('utf-8')
|
data.strip('\n')
|
||||||
data.strip('\n')
|
|
||||||
if data=='*done*':
|
|
||||||
self.subprocess_connection = None
|
|
||||||
monitor_widget.buffer(['** Action Complete **'])
|
|
||||||
self.display()
|
|
||||||
# rebuild the form, saving log messages
|
|
||||||
saved_messages = monitor_widget.values
|
|
||||||
app.main_form = app.addForm(
|
|
||||||
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
|
|
||||||
)
|
|
||||||
app.switchForm('MAIN')
|
|
||||||
app.main_form.monitor.entry_widget.values = saved_messages
|
|
||||||
app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
monitor_widget.buffer(
|
|
||||||
textwrap.wrap(data,
|
|
||||||
width=monitor_widget.width,
|
|
||||||
subsequent_indent=' ',
|
|
||||||
),
|
|
||||||
scroll_end=True
|
|
||||||
)
|
|
||||||
self.display()
|
|
||||||
except (EOFError,OSError):
|
|
||||||
self.subprocess_connection = None
|
|
||||||
|
|
||||||
|
# processing child is requesting user input to select the
|
||||||
|
# right configuration file
|
||||||
|
if data.startswith('*need v2 config'):
|
||||||
|
_,model_path,*_ = data.split(":",2)
|
||||||
|
self._return_v2_config(model_path)
|
||||||
|
|
||||||
|
# processing child is done
|
||||||
|
elif data=='*done*':
|
||||||
|
self._close_subprocess_and_regenerate_form()
|
||||||
|
break
|
||||||
|
|
||||||
|
# update the log message box
|
||||||
|
else:
|
||||||
|
monitor_widget.buffer(
|
||||||
|
textwrap.wrap(data,
|
||||||
|
width=monitor_widget.width,
|
||||||
|
subsequent_indent=' ',
|
||||||
|
),
|
||||||
|
scroll_end=True
|
||||||
|
)
|
||||||
|
self.display()
|
||||||
|
except (EOFError,OSError):
|
||||||
|
self.subprocess_connection = None
|
||||||
|
|
||||||
|
def _return_v2_config(self,model_path: str):
|
||||||
|
c = self.subprocess_connection
|
||||||
|
model_name = Path(model_path).name
|
||||||
|
message = select_stable_diffusion_config_file(model_name=model_name)
|
||||||
|
c.send_bytes(message.encode('utf-8'))
|
||||||
|
|
||||||
|
def _close_subprocess_and_regenerate_form(self):
|
||||||
|
app = self.parentApp
|
||||||
|
self.subprocess_connection.close()
|
||||||
|
self.subprocess_connection = None
|
||||||
|
self.monitor.entry_widget.buffer(['** Action Complete **'])
|
||||||
|
self.display()
|
||||||
|
# rebuild the form, saving log messages
|
||||||
|
saved_messages = self.monitor.entry_widget.values
|
||||||
|
app.main_form = app.addForm(
|
||||||
|
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
|
||||||
|
)
|
||||||
|
app.switchForm('MAIN')
|
||||||
|
app.main_form.monitor.entry_widget.values = saved_messages
|
||||||
|
app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
|
||||||
|
|
||||||
|
###############################################################
|
||||||
|
|
||||||
def list_additional_diffusers_models(self,
|
def list_additional_diffusers_models(self,
|
||||||
manager: ModelManager,
|
manager: ModelManager,
|
||||||
starters:dict
|
starters:dict
|
||||||
@ -628,7 +660,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
|
|||||||
def onStart(self):
|
def onStart(self):
|
||||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||||
self.main_form = self.addForm(
|
self.main_form = self.addForm(
|
||||||
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
|
"MAIN", addModelsForm, name="Install Stable Diffusion Models", cycle_widgets=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
class StderrToMessage():
|
class StderrToMessage():
|
||||||
@ -640,6 +672,57 @@ class StderrToMessage():
|
|||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# --------------------------------------------------------
|
||||||
|
def ask_user_for_config_file(model_path: Path,
|
||||||
|
tui_conn: Connection=None
|
||||||
|
)->Path:
|
||||||
|
logger.debug(f'Waiting for user action in dialog box (above).')
|
||||||
|
if tui_conn:
|
||||||
|
return _ask_user_for_cf_tui(model_path, tui_conn)
|
||||||
|
else:
|
||||||
|
return _ask_user_for_cf_cmdline(model_path)
|
||||||
|
|
||||||
|
def _ask_user_for_cf_cmdline(model_path):
|
||||||
|
choices = [
|
||||||
|
config.model_conf_path / 'stable-diffusion' / x
|
||||||
|
for x in ['v2-inference.yaml','v2-inference-v.yaml']
|
||||||
|
]
|
||||||
|
print(
|
||||||
|
f"""
|
||||||
|
Please select the type of the V2 checkpoint named {model_path.name}:
|
||||||
|
[1] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file)
|
||||||
|
[2] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
choice = None
|
||||||
|
ok = False
|
||||||
|
while not ok:
|
||||||
|
try:
|
||||||
|
choice = input('select> ').strip()
|
||||||
|
choice = choices[int(choice)-1]
|
||||||
|
ok = True
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
print(f'{choice} is not a valid choice')
|
||||||
|
except EOFError:
|
||||||
|
return
|
||||||
|
return choice
|
||||||
|
|
||||||
|
def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path:
|
||||||
|
try:
|
||||||
|
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
|
||||||
|
# note that we don't do any status checking here
|
||||||
|
response = tui_conn.recv_bytes().decode('utf-8')
|
||||||
|
if response is None:
|
||||||
|
return None
|
||||||
|
elif response == 'epsilon':
|
||||||
|
return config.legacy_conf_path / 'v2-inference.yaml'
|
||||||
|
elif response == 'v':
|
||||||
|
return config.legacy_conf_path / 'v2-inference-v.yaml'
|
||||||
|
else:
|
||||||
|
return Path(response)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
def process_and_execute(opt: Namespace,
|
def process_and_execute(opt: Namespace,
|
||||||
@ -651,13 +734,16 @@ def process_and_execute(opt: Namespace,
|
|||||||
translator = StderrToMessage(conn_out)
|
translator = StderrToMessage(conn_out)
|
||||||
sys.stderr = translator
|
sys.stderr = translator
|
||||||
sys.stdout = translator
|
sys.stdout = translator
|
||||||
InvokeAILogger.getLogger().handlers[0]=logging.StreamHandler(translator)
|
logger = InvokeAILogger.getLogger()
|
||||||
|
logger.handlers.clear()
|
||||||
|
logger.addHandler(logging.StreamHandler(translator))
|
||||||
|
|
||||||
models_to_install = selections.install_models
|
models_to_install = selections.install_models
|
||||||
models_to_remove = selections.remove_models
|
models_to_remove = selections.remove_models
|
||||||
directory_to_scan = selections.scan_directory
|
directory_to_scan = selections.scan_directory
|
||||||
scan_at_startup = selections.autoscan_on_startup
|
scan_at_startup = selections.autoscan_on_startup
|
||||||
potential_models_to_install = selections.import_model_paths
|
potential_models_to_install = selections.import_model_paths
|
||||||
|
|
||||||
install_requested_models(
|
install_requested_models(
|
||||||
diffusers = ModelInstallList(models_to_install, models_to_remove),
|
diffusers = ModelInstallList(models_to_install, models_to_remove),
|
||||||
controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
|
controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
|
||||||
@ -671,6 +757,7 @@ def process_and_execute(opt: Namespace,
|
|||||||
else choose_precision(torch.device(choose_torch_device())),
|
else choose_precision(torch.device(choose_torch_device())),
|
||||||
purge_deleted=selections.purge_deleted_models,
|
purge_deleted=selections.purge_deleted_models,
|
||||||
config_file_path=Path(opt.config_file) if opt.config_file else None,
|
config_file_path=Path(opt.config_file) if opt.config_file else None,
|
||||||
|
model_config_file_callback = lambda x: ask_user_for_config_file(x,conn_out)
|
||||||
)
|
)
|
||||||
|
|
||||||
if conn_out:
|
if conn_out:
|
||||||
|
@ -8,9 +8,12 @@ import platform
|
|||||||
import pyperclip
|
import pyperclip
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
|
import npyscreen
|
||||||
|
import textwrap
|
||||||
|
import npyscreen.wgmultiline as wgmultiline
|
||||||
|
from npyscreen import fmPopup
|
||||||
from shutil import get_terminal_size
|
from shutil import get_terminal_size
|
||||||
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
|
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
|
||||||
import npyscreen
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def set_terminal_size(columns: int, lines: int):
|
def set_terminal_size(columns: int, lines: int):
|
||||||
@ -151,8 +154,13 @@ class MultiSelectColumns( SelectColumnBase, npyscreen.MultiSelect):
|
|||||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||||
super().__init__(screen, values=values, **keywords)
|
super().__init__(screen, values=values, **keywords)
|
||||||
|
|
||||||
|
class SingleSelectWithChanged(npyscreen.SelectOne):
|
||||||
|
def h_select(self,ch):
|
||||||
|
super().h_select(ch)
|
||||||
|
if self.on_changed:
|
||||||
|
self.on_changed(self.value)
|
||||||
|
|
||||||
class SingleSelectColumns(SelectColumnBase, npyscreen.SelectOne):
|
class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged):
|
||||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||||
self.columns = columns
|
self.columns = columns
|
||||||
self.value_cnt = len(values)
|
self.value_cnt = len(values)
|
||||||
@ -160,11 +168,6 @@ class SingleSelectColumns(SelectColumnBase, npyscreen.SelectOne):
|
|||||||
self.on_changed = None
|
self.on_changed = None
|
||||||
super().__init__(screen, values=values, **keywords)
|
super().__init__(screen, values=values, **keywords)
|
||||||
|
|
||||||
def h_select(self,ch):
|
|
||||||
super().h_select(ch)
|
|
||||||
if self.on_changed:
|
|
||||||
self.on_changed(self.value)
|
|
||||||
|
|
||||||
def when_value_edited(self):
|
def when_value_edited(self):
|
||||||
self.h_select(self.cursor_line)
|
self.h_select(self.cursor_line)
|
||||||
|
|
||||||
@ -271,4 +274,90 @@ class TextBox(npyscreen.MultiLineEdit):
|
|||||||
class BufferBox(npyscreen.BoxTitle):
|
class BufferBox(npyscreen.BoxTitle):
|
||||||
_contained_widget = npyscreen.BufferPager
|
_contained_widget = npyscreen.BufferPager
|
||||||
|
|
||||||
|
class ConfirmCancelPopup(fmPopup.ActionPopup):
|
||||||
|
DEFAULT_COLUMNS = 100
|
||||||
|
def on_ok(self):
|
||||||
|
self.value = True
|
||||||
|
def on_cancel(self):
|
||||||
|
self.value = False
|
||||||
|
|
||||||
|
class FileBox(npyscreen.BoxTitle):
|
||||||
|
_contained_widget = npyscreen.Filename
|
||||||
|
|
||||||
|
def _wrap_message_lines(message, line_length):
|
||||||
|
lines = []
|
||||||
|
for line in message.split('\n'):
|
||||||
|
lines.extend(textwrap.wrap(line.rstrip(), line_length))
|
||||||
|
return lines
|
||||||
|
|
||||||
|
def _prepare_message(message):
|
||||||
|
if isinstance(message, list) or isinstance(message, tuple):
|
||||||
|
return "\n".join([ s.rstrip() for s in message])
|
||||||
|
#return "\n".join(message)
|
||||||
|
else:
|
||||||
|
return message
|
||||||
|
|
||||||
|
def select_stable_diffusion_config_file(
|
||||||
|
form_color: str='DANGER',
|
||||||
|
wrap:bool =True,
|
||||||
|
model_name:str='Unknown',
|
||||||
|
):
|
||||||
|
message = "Please select the correct base model for the V2 checkpoint named {model_name}. Press <CANCEL> to skip installation."
|
||||||
|
title = "CONFIG FILE SELECTION"
|
||||||
|
options=[
|
||||||
|
"An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)",
|
||||||
|
"An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)",
|
||||||
|
"Enter config file path manually",
|
||||||
|
]
|
||||||
|
|
||||||
|
F = ConfirmCancelPopup(
|
||||||
|
name=title,
|
||||||
|
color=form_color,
|
||||||
|
cycle_widgets=True,
|
||||||
|
lines=16,
|
||||||
|
)
|
||||||
|
F.preserve_selected_widget = True
|
||||||
|
|
||||||
|
mlw = F.add(
|
||||||
|
wgmultiline.Pager,
|
||||||
|
max_height=4,
|
||||||
|
editable=False,
|
||||||
|
)
|
||||||
|
mlw_width = mlw.width-1
|
||||||
|
if wrap:
|
||||||
|
message = _wrap_message_lines(message, mlw_width)
|
||||||
|
mlw.values = message
|
||||||
|
|
||||||
|
choice = F.add(
|
||||||
|
SingleSelectWithChanged,
|
||||||
|
values = options,
|
||||||
|
value = [0],
|
||||||
|
max_height = len(options)+1,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
file = F.add(
|
||||||
|
FileBox,
|
||||||
|
name='Path to config file',
|
||||||
|
max_height=3,
|
||||||
|
hidden=True,
|
||||||
|
must_exist=True,
|
||||||
|
scroll_exit=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def toggle_visible(value):
|
||||||
|
value = value[0]
|
||||||
|
if value==2:
|
||||||
|
file.hidden=False
|
||||||
|
else:
|
||||||
|
file.hidden=True
|
||||||
|
F.display()
|
||||||
|
|
||||||
|
choice.on_changed = toggle_visible
|
||||||
|
|
||||||
|
F.editw = 1
|
||||||
|
F.edit()
|
||||||
|
if not F.value:
|
||||||
|
return None
|
||||||
|
assert choice.value[0] in range(0,3),'invalid choice'
|
||||||
|
choices = ['epsilon','v',file.value]
|
||||||
|
return choices[choice.value[0]]
|
||||||
|
@ -13,6 +13,7 @@ def main():
|
|||||||
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
if '--web' in sys.argv:
|
if '--web' in sys.argv:
|
||||||
|
sys.argv.remove('--web')
|
||||||
from invokeai.app.api_app import invoke_api
|
from invokeai.app.api_app import invoke_api
|
||||||
invoke_api()
|
invoke_api()
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user