implemented tabbed model selection; not wired to backend yet

This commit is contained in:
Lincoln Stein 2023-06-01 00:31:46 -04:00
parent d6530df635
commit e9821ab711
4 changed files with 269 additions and 149 deletions

View File

@ -6,6 +6,7 @@ import re
import shutil import shutil
import sys import sys
import warnings import warnings
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from tempfile import TemporaryFile from tempfile import TemporaryFile
from typing import List, Dict from typing import List, Dict
@ -20,9 +21,8 @@ from tqdm import tqdm
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import get_invokeai_config
from ..model_management import ModelManager
from ..stable_diffusion import StableDiffusionGeneratorPipeline from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -37,6 +37,9 @@ Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
# initial models omegaconf # initial models omegaconf
Datasets = None Datasets = None
# logger
logger = InvokeAILogger.getLogger(name='InvokeAI')
Config_preamble = """ Config_preamble = """
# This file describes the alternative machine learning models # This file describes the alternative machine learning models
# available to InvokeAI script. # available to InvokeAI script.
@ -47,6 +50,11 @@ Config_preamble = """
# was trained on. # was trained on.
""" """
@dataclass
class ModelInstallList:
'''Class for listing models to be installed/removed'''
install_models: List[str]
remove_models: List[str]
def default_config_file(): def default_config_file():
return config.model_conf_path return config.model_conf_path
@ -61,11 +69,11 @@ def initial_models():
return (Datasets := OmegaConf.load(Dataset_path)['diffusers']) return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
def install_requested_models( def install_requested_models(
install_initial_models: List[str] = None, diffusers: ModelInstallList = None,
remove_models: List[str] = None, controlnet: ModelInstallList = None,
install_cn_models: List[str] = None, lora: ModelInstallList = None,
remove_cn_models: List[str] = None, ti: ModelInstallList = None,
cn_model_map: Dict[str,str] = None, cn_model_map: Dict[str,str] = None, # temporary - move to model manager
scan_directory: Path = None, scan_directory: Path = None,
external_models: List[str] = None, external_models: List[str] = None,
scan_at_startup: bool = False, scan_at_startup: bool = False,
@ -81,33 +89,30 @@ def install_requested_models(
if not config_file_path.exists(): if not config_file_path.exists():
open(config_file_path, "w") open(config_file_path, "w")
install_controlnet_models( # prevent circular import here
install_cn_models, from ..model_management import ModelManager
short_name_map = cn_model_map,
precision=precision,
access_token=access_token,
)
delete_controlnet_models(remove_cn_models)
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision) model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
model_manager.delete_controlnet_models(controlnet.remove_models)
if remove_models and len(remove_models) > 0: # TODO: Replace next three paragraphs with calls into new model manager
print("== DELETING UNCHECKED STARTER MODELS ==") if diffusers.remove_models and len(diffusers.remove_models) > 0:
for model in remove_models: logger.info("DELETING UNCHECKED STARTER MODELS")
print(f"{model}...") for model in diffusers.remove_models:
logger.info(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted) model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path) model_manager.commit(config_file_path)
if install_initial_models and len(install_initial_models) > 0: if diffusers.install_models and len(diffusers.install_models) > 0:
print("== INSTALLING SELECTED STARTER MODELS ==") logger.info("INSTALLING SELECTED STARTER MODELS")
successfully_downloaded = download_weight_datasets( successfully_downloaded = download_weight_datasets(
models=install_initial_models, models=diffusers.install_initial_models,
access_token=None, access_token=None,
precision=precision, precision=precision,
) # FIX: for historical reasons, we don't use model manager here ) # FIX: for historical reasons, we don't use model manager here
update_config_file(successfully_downloaded, config_file_path) update_config_file(successfully_downloaded, config_file_path)
if len(successfully_downloaded) < len(install_initial_models): if len(successfully_downloaded) < len(diffusers.install_models):
print("** Some of the model downloads were not successful") logger.warning("Some of the model downloads were not successful")
# due to above, we have to reload the model manager because conf file # due to above, we have to reload the model manager because conf file
# was changed behind its back # was changed behind its back
@ -118,7 +123,7 @@ def install_requested_models(
external_models.append(str(scan_directory)) external_models.append(str(scan_directory))
if len(external_models) > 0: if len(external_models) > 0:
print("== INSTALLING EXTERNAL MODELS ==") logger.info("INSTALLING EXTERNAL MODELS")
for path_url_or_repo in external_models: for path_url_or_repo in external_models:
try: try:
model_manager.heuristic_import( model_manager.heuristic_import(
@ -190,10 +195,10 @@ def migrate_models_ckpt():
if not os.path.exists(os.path.join(model_path, "model.ckpt")): if not os.path.exists(os.path.join(model_path, "model.ckpt")):
return return
new_name = initial_models()["stable-diffusion-1.4"]["file"] new_name = initial_models()["stable-diffusion-1.4"]["file"]
print( logger.warning(
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.' 'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
) )
print(f"model.ckpt => {new_name}") logger.warning(f"model.ckpt => {new_name}")
os.replace( os.replace(
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name) os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
) )
@ -206,7 +211,7 @@ def download_weight_datasets(
migrate_models_ckpt() migrate_models_ckpt()
successful = dict() successful = dict()
for mod in models: for mod in models:
print(f"Downloading {mod}:") logger.info(f"Downloading {mod}:")
successful[mod] = _download_repo_or_file( successful[mod] = _download_repo_or_file(
initial_models()[mod], access_token, precision=precision initial_models()[mod], access_token, precision=precision
) )
@ -240,68 +245,6 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
) )
# ---------------------------------------------
def install_controlnet_models(
short_names: List[str],
short_name_map: Dict[str,str],
precision: str='float16',
access_token: str = None,
):
'''
Download list of controlnet models, using their HuggingFace
repo_ids.
'''
dest_dir = config.controlnet_path
if not dest_dir.exists():
dest_dir.mkdir(parents=True,exist_ok=False)
# The model file may be fp32 or fp16, and may be either a
# .bin file or a .safetensors. We try each until we get one,
# preferring 'fp16' if using half precision, and preferring
# safetensors over over bin.
precisions = ['.fp16',''] if precision=='float16' else ['']
formats = ['.safetensors','.bin']
possible_filenames = list()
for p in precisions:
for f in formats:
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
for directory_name in short_names:
repo_id = short_name_map[directory_name]
safe_name = directory_name.replace('/','--')
print(f'Downloading ControlNet model {directory_name} ({repo_id})')
hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = 'config.json',
access_token = access_token
)
path = None
for filename in possible_filenames:
suffix = filename.suffix
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
print(f'Probing {directory_name}/{filename}...')
path = hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = str(filename),
access_token = access_token,
model_dest = Path(dest_dir, safe_name, dest_filename),
)
if path:
(path.parent / '.download_complete').touch()
break
# ---------------------------------------------
def delete_controlnet_models(short_names: List[str]):
for name in short_names:
safe_name = name.replace('/','--')
directory = config.controlnet_path / safe_name
if directory.exists():
print(f'Purging controlnet model {name}')
shutil.rmtree(str(directory))
# --------------------------------------------- # ---------------------------------------------
def download_from_hf( def download_from_hf(
model_class: object, model_name: str, **kwargs model_class: object, model_name: str, **kwargs
@ -340,7 +283,7 @@ def _download_diffusion_weights(
if str(e).startswith("fp16 is not a valid"): if str(e).startswith("fp16 is not a valid"):
pass pass
else: else:
print(f"An unexpected error occurred while downloading the model: {e})") logger.error(f"An unexpected error occurred while downloading the model: {e})")
if path: if path:
break break
return path return path
@ -374,17 +317,17 @@ def hf_download_with_resume(
if ( if (
resp.status_code == 416 resp.status_code == 416
): # "range not satisfiable", which means nothing to return ): # "range not satisfiable", which means nothing to return
print(f"* {model_name}: complete file found. Skipping.") logger.info(f"{model_name}: complete file found. Skipping.")
return model_dest return model_dest
elif resp.status_code == 404: elif resp.status_code == 404:
print("** File not found") logger.warning("File not found")
return None return None
elif resp.status_code != 200: elif resp.status_code != 200:
print(f"** Warning: {model_name}: {resp.reason}") logger.warning(f"{model_name}: {resp.reason}")
elif exist_size > 0: elif exist_size > 0:
print(f"* {model_name}: partial file found. Resuming...") logger.info(f"{model_name}: partial file found. Resuming...")
else: else:
print(f"* {model_name}: Downloading...") logger.info(f"{model_name}: Downloading...")
try: try:
with open(model_dest, open_mode) as file, tqdm( with open(model_dest, open_mode) as file, tqdm(
@ -399,7 +342,7 @@ def hf_download_with_resume(
size = file.write(data) size = file.write(data)
bar.update(size) bar.update(size)
except Exception as e: except Exception as e:
print(f"An error occurred while downloading {model_name}: {str(e)}") logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
return None return None
return model_dest return model_dest
@ -424,8 +367,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
try: try:
backup = None backup = None
if os.path.exists(config_file): if os.path.exists(config_file):
print( logger.warning(
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig" f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
) )
backup = config_file.with_suffix(".yaml.orig") backup = config_file.with_suffix(".yaml.orig")
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183 ## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
@ -442,16 +385,16 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
new_config.write(tmp.read()) new_config.write(tmp.read())
except Exception as e: except Exception as e:
print(f"**Error creating config file {config_file}: {str(e)} **") logger.error(f"Error creating config file {config_file}: {str(e)}")
if backup is not None: if backup is not None:
print("restoring previous config file") logger.info("restoring previous config file")
## workaround, for WinError 183, see above ## workaround, for WinError 183, see above
if sys.platform == "win32" and config_file.is_file(): if sys.platform == "win32" and config_file.is_file():
config_file.unlink() config_file.unlink()
backup.rename(config_file) backup.rename(config_file)
return return
print(f"Successfully created new configuration file {config_file}") logger.info(f"Successfully created new configuration file {config_file}")
# --------------------------------------------- # ---------------------------------------------
@ -518,8 +461,8 @@ def delete_weights(model_name: str, conf_stanza: dict):
if re.match("/VAE/", conf_stanza.get("config")): if re.match("/VAE/", conf_stanza.get("config")):
return return
print( logger.warning(
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?" f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
) )
weights = Path(weights) weights = Path(weights)
@ -528,4 +471,4 @@ def delete_weights(model_name: str, conf_stanza: dict):
try: try:
weights.unlink() weights.unlink()
except OSError as e: except OSError as e:
print(str(e)) logger.error(str(e))

View File

@ -11,6 +11,7 @@ import gc
import hashlib import hashlib
import os import os
import re import re
import shutil
import sys import sys
import textwrap import textwrap
import time import time
@ -49,6 +50,10 @@ from ..stable_diffusion import (
StableDiffusionGeneratorPipeline, StableDiffusionGeneratorPipeline,
) )
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import get_invokeai_config
from ..install.model_install_backend import (
Dataset_path,
hf_download_with_resume,
)
from ..util import CUDA_DEVICE, ask_user, download_with_resume from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum): class SDLegacyType(Enum):
@ -1316,3 +1321,74 @@ class ModelManager(object):
return ( return (
os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None
) )
def list_controlnet_models(self)->Dict[str,bool]:
'''Return a dict of installed controlnet models; key is repo_id or short name
of model (defined in INITIAL_MODELS), and valule is True if installed'''
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
installed_models = {x: False for x in cn_models.keys()}
cn_dir = self.globals.controlnet_path
installed_cn_models = dict()
for root, dirs, files in os.walk(cn_dir):
for name in dirs:
if Path(root, name, '.download_complete').exists():
installed_models.update({name.replace('--','/'): True})
return installed_models
def install_controlnet_models(self, model_names: list[str], access_token: str=None):
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
short_names = OmegaConf.load(Dataset_path).get('controlnet') or {}
dest_dir = self.globals.controlnet_path
dest_dir.mkdir(parents=True,exist_ok=True)
# The model file may be fp32 or fp16, and may be either a
# .bin file or a .safetensors. We try each until we get one,
# preferring 'fp16' if using half precision, and preferring
# safetensors over over bin.
precisions = ['.fp16',''] if self.precision=='float16' else ['']
formats = ['.safetensors','.bin']
possible_filenames = list()
for p in precisions:
for f in formats:
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
for directory_name in model_names:
repo_id = short_names.get(directory_name) or directory_name
safe_name = directory_name.replace('/','--')
self.logger.info(f'Downloading ControlNet model {directory_name} ({repo_id})')
hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = 'config.json',
access_token = access_token
)
path = None
for filename in possible_filenames:
suffix = filename.suffix
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
self.logger.info(f'Checking availability of {directory_name}/{filename}...')
path = hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = str(filename),
access_token = access_token,
model_dest = Path(dest_dir, safe_name, dest_filename),
)
if path:
(path.parent / '.download_complete').touch()
break
def delete_controlnet_models(self, model_names: List[str]):
'''Remove the list of controlnet models'''
for name in model_names:
safe_name = name.replace('/','--')
directory = self.globals.controlnet_path / safe_name
if directory.exists():
self.logger.info(f'Purging controlnet model {name}')
shutil.rmtree(str(directory))

View File

@ -30,11 +30,15 @@ from ...backend.install.model_install_backend import (
default_dataset, default_dataset,
install_requested_models, install_requested_models,
recommended_datasets, recommended_datasets,
ModelInstallList,
dataclass,
) )
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device from ...backend.util import choose_precision, choose_torch_device
from .widgets import ( from .widgets import (
CenteredTitleText, CenteredTitleText,
MultiSelectColumns, MultiSelectColumns,
SingleSelectColumns,
OffsetButtonPress, OffsetButtonPress,
TextBox, TextBox,
set_min_terminal_size, set_min_terminal_size,
@ -53,12 +57,12 @@ class addModelsForm(npyscreen.FormMultiPage):
def __init__(self, parentApp, name, multipage=False, *args, **keywords): def __init__(self, parentApp, name, multipage=False, *args, **keywords):
self.multipage = multipage self.multipage = multipage
model_manager = ModelManager(config.model_conf_path)
self.initial_models = OmegaConf.load(Dataset_path)['diffusers'] self.initial_models = OmegaConf.load(Dataset_path)['diffusers']
self.control_net_models = OmegaConf.load(Dataset_path)['controlnet'] self.installed_cn_models = model_manager.list_controlnet_models()
self.installed_cn_models = self._get_installed_cn_models()
self._add_additional_cn_models(self.control_net_models,self.installed_cn_models)
try: try:
self.existing_models = OmegaConf.load(default_config_file()) self.existing_models = OmegaConf.load(default_config_file())
except: except:
@ -79,7 +83,7 @@ class addModelsForm(npyscreen.FormMultiPage):
[x for x in list(self.initial_models.keys()) if x in self.existing_models] [x for x in list(self.initial_models.keys()) if x in self.existing_models]
) )
cn_model_list = sorted(self.control_net_models.keys()) cn_model_list = sorted(self.installed_cn_models.keys())
self.nextrely -= 1 self.nextrely -= 1
self.add_widget_intelligent( self.add_widget_intelligent(
@ -180,16 +184,34 @@ class addModelsForm(npyscreen.FormMultiPage):
relx=4, relx=4,
scroll_exit=True, scroll_exit=True,
) )
self.add_widget_intelligent( self.add_widget_intelligent(
CenteredTitleText, CenteredTitleText,
name="== CONTROLNET MODELS ==", name='_' * (window_width-5),
editable=False, editable=False,
color="CONTROL", labelColor='CAUTION'
) )
self.nextrely -= 1
self.add_widget_intelligent( self.nextrely += 1
self.tabs = self.add_widget_intelligent(
SingleSelectColumns,
values=['ADD CONTROLNET MODELS','ADD LORA/LYCORIS MODELS', 'ADD TEXTUAL INVERSION MODELS'],
value=0,
columns = 4,
max_height = 2,
relx=8,
scroll_exit = True,
)
# self.add_widget_intelligent(
# CenteredTitleText,
# name="== CONTROLNET MODELS ==",
# editable=False,
# color="CONTROL",
# )
top_of_table = self.nextrely
self.cn_label_1 = self.add_widget_intelligent(
CenteredTitleText, CenteredTitleText,
name="Select the desired ControlNet models. Unchecked models will be purged from disk.", name="Select the desired models to install. Unchecked models will be purged from disk.",
editable=False, editable=False,
labelColor="CAUTION", labelColor="CAUTION",
) )
@ -202,14 +224,14 @@ class addModelsForm(npyscreen.FormMultiPage):
value=[ value=[
cn_model_list.index(x) cn_model_list.index(x)
for x in cn_model_list for x in cn_model_list
if x in self.installed_cn_models if self.installed_cn_models[x]
], ],
max_height=len(cn_model_list)//columns + 1, max_height=len(cn_model_list)//columns + 1,
relx=4, relx=4,
scroll_exit=True, scroll_exit=True,
) )
self.nextrely += 1 self.nextrely += 1
self.add_widget_intelligent( self.cn_label_2 = self.add_widget_intelligent(
npyscreen.TitleFixedText, npyscreen.TitleFixedText,
name='Additional ControlNet HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):', name='Additional ControlNet HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):',
relx=4, relx=4,
@ -221,6 +243,55 @@ class addModelsForm(npyscreen.FormMultiPage):
self.additional_controlnet_ids = self.add_widget_intelligent( self.additional_controlnet_ids = self.add_widget_intelligent(
TextBox, max_height=2, scroll_exit=True, editable=True, relx=4 TextBox, max_height=2, scroll_exit=True, editable=True, relx=4
) )
bottom_of_table = self.nextrely
self.nextrely = top_of_table
self.lora_label_1 = self.add_widget_intelligent(
npyscreen.TitleFixedText,
name='LoRA/LYCORIS models to download and install (Space separated. Use shift-control-V to paste):',
relx=4,
color='CONTROL',
editable=False,
hidden=True,
scroll_exit=True
)
self.nextrely -= 1
self.loras = self.add_widget_intelligent(
TextBox,
max_height=2,
scroll_exit=True,
editable=True,
relx=4,
hidden=True,
)
self.nextrely = top_of_table
self.ti_label_1 = self.add_widget_intelligent(
npyscreen.TitleFixedText,
name='Textual Inversion models to download and install (Space separated. Use shift-control-V to paste):',
relx=4,
color='CONTROL',
editable=False,
hidden=True,
scroll_exit=True
)
self.nextrely -= 1
self.tis = self.add_widget_intelligent(
TextBox,
max_height=2,
scroll_exit=True,
editable=True,
relx=4,
hidden=True,
)
self.nextrely = bottom_of_table
self.nextrely += 1
self.add_widget_intelligent(
CenteredTitleText,
name='_' * (window_width-5),
editable=False,
labelColor='CAUTION'
)
self.cancel = self.add_widget_intelligent( self.cancel = self.add_widget_intelligent(
npyscreen.ButtonPress, npyscreen.ButtonPress,
name="CANCEL", name="CANCEL",
@ -254,6 +325,9 @@ class addModelsForm(npyscreen.FormMultiPage):
for i in [self.autoload_directory, self.autoscan_on_startup]: for i in [self.autoload_directory, self.autoscan_on_startup]:
self.show_directory_fields.addVisibleWhenSelected(i) self.show_directory_fields.addVisibleWhenSelected(i)
# self.tabs.when_value_edited = self._toggle_tables
self.tabs.on_changed = self._toggle_tables
self.show_directory_fields.when_value_edited = self._clear_scan_directory self.show_directory_fields.when_value_edited = self._clear_scan_directory
def resize(self): def resize(self):
@ -261,6 +335,21 @@ class addModelsForm(npyscreen.FormMultiPage):
if hasattr(self, "models_selected"): if hasattr(self, "models_selected"):
self.models_selected.values = self._get_starter_model_labels() self.models_selected.values = self._get_starter_model_labels()
def _toggle_tables(self, value=None):
selected_tab = value[0] if value else self.tabs.value[0]
widgets = [
[self.cn_label_1, self.cn_models_selected, self.cn_label_2, self.additional_controlnet_ids],
[self.lora_label_1,self.loras],
[self.ti_label_1,self.tis],
]
for group in widgets:
for w in group:
w.hidden = True
for w in widgets[selected_tab]:
w.hidden = False
self.display()
def _clear_scan_directory(self): def _clear_scan_directory(self):
if not self.show_directory_fields.value: if not self.show_directory_fields.value:
self.autoload_directory.value = "" self.autoload_directory.value = ""
@ -361,20 +450,18 @@ class addModelsForm(npyscreen.FormMultiPage):
selections.install_models = [x for x in starter_models if x not in self.existing_models] selections.install_models = [x for x in starter_models if x not in self.existing_models]
selections.remove_models = [x for x in self.starter_model_list if x in self.existing_models and x not in starter_models] selections.remove_models = [x for x in self.starter_model_list if x in self.existing_models and x not in starter_models]
selections.control_net_map = self.control_net_models
selections.install_cn_models = [self.cn_models_selected.values[x] selections.install_cn_models = [self.cn_models_selected.values[x]
for x in self.cn_models_selected.value for x in self.cn_models_selected.value
if self.cn_models_selected.values[x] not in self.installed_cn_models if not self.installed_cn_models[self.cn_models_selected.values[x]]
] ]
selections.remove_cn_models = [x selections.remove_cn_models = [x
for x in self.cn_models_selected.values for x in self.cn_models_selected.values
if x in self.installed_cn_models if self.installed_cn_models[x]
and self.cn_models_selected.values.index(x) not in self.cn_models_selected.value and self.cn_models_selected.values.index(x) not in self.cn_models_selected.value
] ]
if (additional_cns := self.additional_controlnet_ids.value.split()): if (additional_cns := self.additional_controlnet_ids.value.split()):
valid_cns = [x for x in additional_cns if '/' in x] valid_cns = [x for x in additional_cns if '/' in x]
selections.install_cn_models.extend(valid_cns) selections.install_cn_models.extend(valid_cns)
selections.control_net_map.update({x: x for x in valid_cns})
# load directory and whether to scan on startup # load directory and whether to scan on startup
if self.show_directory_fields.value: if self.show_directory_fields.value:
@ -387,22 +474,22 @@ class addModelsForm(npyscreen.FormMultiPage):
# URLs and the like # URLs and the like
selections.import_model_paths = self.import_model_paths.value.split() selections.import_model_paths = self.import_model_paths.value.split()
@dataclass
class UserSelections():
install_models: List[str]=None
remove_models: List[str]=None
purge_deleted_models: bool=False,
install_cn_models: List[str] = None,
remove_cn_models: List[str] = None,
scan_directory: Path=None,
autoscan_on_startup: bool=False,
import_model_paths: str=None,
class AddModelApplication(npyscreen.NPSAppManaged): class AddModelApplication(npyscreen.NPSAppManaged):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.user_cancelled = False self.user_cancelled = False
self.user_selections = Namespace( self.user_selections = UserSelections()
install_models=None,
remove_models=None,
purge_deleted_models=False,
install_cn_models = None,
remove_cn_models = None,
control_net_map = None,
scan_directory=None,
autoscan_on_startup=None,
import_model_paths=None,
)
def onStart(self): def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme) npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@ -418,15 +505,9 @@ def process_and_execute(opt: Namespace, selections: Namespace):
directory_to_scan = selections.scan_directory directory_to_scan = selections.scan_directory
scan_at_startup = selections.autoscan_on_startup scan_at_startup = selections.autoscan_on_startup
potential_models_to_install = selections.import_model_paths potential_models_to_install = selections.import_model_paths
print(f'selections.install_cn_models={selections.install_cn_models}')
print(f'selections.remove_cn_models={selections.remove_cn_models}')
print(f'selections.cn_model_map={selections.control_net_map}')
install_requested_models( install_requested_models(
install_initial_models=models_to_install, diffusers = ModelInstallList(models_to_install, models_to_remove),
remove_models=models_to_remove, controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
install_cn_models=selections.install_cn_models,
remove_cn_models=selections.remove_cn_models,
cn_model_map=selections.control_net_map,
scan_directory=Path(directory_to_scan) if directory_to_scan else None, scan_directory=Path(directory_to_scan) if directory_to_scan else None,
external_models=potential_models_to_install, external_models=potential_models_to_install,
scan_at_startup=scan_at_startup, scan_at_startup=scan_at_startup,

View File

@ -94,13 +94,7 @@ class FloatTitleSlider(npyscreen.TitleText):
_entry_type = FloatSlider _entry_type = FloatSlider
class MultiSelectColumns(npyscreen.MultiSelect): class SelectColumnBase():
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
super().__init__(screen, values=values, **keywords)
def make_contained_widgets(self): def make_contained_widgets(self):
self._my_widgets = [] self._my_widgets = []
column_width = self.width // self.columns column_width = self.width // self.columns
@ -150,6 +144,32 @@ class MultiSelectColumns(npyscreen.MultiSelect):
def h_cursor_line_right(self, ch): def h_cursor_line_right(self, ch):
super().h_cursor_line_down(ch) super().h_cursor_line_down(ch)
class MultiSelectColumns( SelectColumnBase, npyscreen.MultiSelect):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
super().__init__(screen, values=values, **keywords)
class SingleSelectColumns(SelectColumnBase, npyscreen.SelectOne):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
self.on_changed = None
super().__init__(screen, values=values, **keywords)
def h_select(self,ch):
super().h_select(ch)
if self.on_changed:
self.on_changed(self.value)
def when_value_edited(self):
self.h_select(self.cursor_line)
def when_cursor_moved(self):
self.h_select(self.cursor_line)
class TextBox(npyscreen.MultiLineEdit): class TextBox(npyscreen.MultiLineEdit):
def update(self, clear=True): def update(self, clear=True):