mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implemented tabbed model selection; not wired to backend yet
This commit is contained in:
parent
d6530df635
commit
e9821ab711
@ -6,6 +6,7 @@ import re
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
from typing import List, Dict
|
||||
@ -20,9 +21,8 @@ from tqdm import tqdm
|
||||
import invokeai.configs as configs
|
||||
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from ..model_management import ModelManager
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
|
||||
from ..util.logging import InvokeAILogger
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
@ -37,6 +37,9 @@ Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
# initial models omegaconf
|
||||
Datasets = None
|
||||
|
||||
# logger
|
||||
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
||||
|
||||
Config_preamble = """
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
@ -47,6 +50,11 @@ Config_preamble = """
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class ModelInstallList:
|
||||
'''Class for listing models to be installed/removed'''
|
||||
install_models: List[str]
|
||||
remove_models: List[str]
|
||||
|
||||
def default_config_file():
|
||||
return config.model_conf_path
|
||||
@ -61,11 +69,11 @@ def initial_models():
|
||||
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
|
||||
|
||||
def install_requested_models(
|
||||
install_initial_models: List[str] = None,
|
||||
remove_models: List[str] = None,
|
||||
install_cn_models: List[str] = None,
|
||||
remove_cn_models: List[str] = None,
|
||||
cn_model_map: Dict[str,str] = None,
|
||||
diffusers: ModelInstallList = None,
|
||||
controlnet: ModelInstallList = None,
|
||||
lora: ModelInstallList = None,
|
||||
ti: ModelInstallList = None,
|
||||
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
|
||||
scan_directory: Path = None,
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
@ -81,33 +89,30 @@ def install_requested_models(
|
||||
if not config_file_path.exists():
|
||||
open(config_file_path, "w")
|
||||
|
||||
install_controlnet_models(
|
||||
install_cn_models,
|
||||
short_name_map = cn_model_map,
|
||||
precision=precision,
|
||||
access_token=access_token,
|
||||
)
|
||||
delete_controlnet_models(remove_cn_models)
|
||||
|
||||
# prevent circular import here
|
||||
from ..model_management import ModelManager
|
||||
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
|
||||
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
|
||||
model_manager.delete_controlnet_models(controlnet.remove_models)
|
||||
|
||||
if remove_models and len(remove_models) > 0:
|
||||
print("== DELETING UNCHECKED STARTER MODELS ==")
|
||||
for model in remove_models:
|
||||
print(f"{model}...")
|
||||
# TODO: Replace next three paragraphs with calls into new model manager
|
||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||
logger.info("DELETING UNCHECKED STARTER MODELS")
|
||||
for model in diffusers.remove_models:
|
||||
logger.info(f"{model}...")
|
||||
model_manager.del_model(model, delete_files=purge_deleted)
|
||||
model_manager.commit(config_file_path)
|
||||
|
||||
if install_initial_models and len(install_initial_models) > 0:
|
||||
print("== INSTALLING SELECTED STARTER MODELS ==")
|
||||
if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||
logger.info("INSTALLING SELECTED STARTER MODELS")
|
||||
successfully_downloaded = download_weight_datasets(
|
||||
models=install_initial_models,
|
||||
models=diffusers.install_initial_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
) # FIX: for historical reasons, we don't use model manager here
|
||||
update_config_file(successfully_downloaded, config_file_path)
|
||||
if len(successfully_downloaded) < len(install_initial_models):
|
||||
print("** Some of the model downloads were not successful")
|
||||
if len(successfully_downloaded) < len(diffusers.install_models):
|
||||
logger.warning("Some of the model downloads were not successful")
|
||||
|
||||
# due to above, we have to reload the model manager because conf file
|
||||
# was changed behind its back
|
||||
@ -118,7 +123,7 @@ def install_requested_models(
|
||||
external_models.append(str(scan_directory))
|
||||
|
||||
if len(external_models) > 0:
|
||||
print("== INSTALLING EXTERNAL MODELS ==")
|
||||
logger.info("INSTALLING EXTERNAL MODELS")
|
||||
for path_url_or_repo in external_models:
|
||||
try:
|
||||
model_manager.heuristic_import(
|
||||
@ -190,10 +195,10 @@ def migrate_models_ckpt():
|
||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
||||
return
|
||||
new_name = initial_models()["stable-diffusion-1.4"]["file"]
|
||||
print(
|
||||
logger.warning(
|
||||
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
|
||||
)
|
||||
print(f"model.ckpt => {new_name}")
|
||||
logger.warning(f"model.ckpt => {new_name}")
|
||||
os.replace(
|
||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
||||
)
|
||||
@ -206,7 +211,7 @@ def download_weight_datasets(
|
||||
migrate_models_ckpt()
|
||||
successful = dict()
|
||||
for mod in models:
|
||||
print(f"Downloading {mod}:")
|
||||
logger.info(f"Downloading {mod}:")
|
||||
successful[mod] = _download_repo_or_file(
|
||||
initial_models()[mod], access_token, precision=precision
|
||||
)
|
||||
@ -240,68 +245,6 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def install_controlnet_models(
|
||||
short_names: List[str],
|
||||
short_name_map: Dict[str,str],
|
||||
precision: str='float16',
|
||||
access_token: str = None,
|
||||
):
|
||||
'''
|
||||
Download list of controlnet models, using their HuggingFace
|
||||
repo_ids.
|
||||
'''
|
||||
dest_dir = config.controlnet_path
|
||||
if not dest_dir.exists():
|
||||
dest_dir.mkdir(parents=True,exist_ok=False)
|
||||
|
||||
# The model file may be fp32 or fp16, and may be either a
|
||||
# .bin file or a .safetensors. We try each until we get one,
|
||||
# preferring 'fp16' if using half precision, and preferring
|
||||
# safetensors over over bin.
|
||||
precisions = ['.fp16',''] if precision=='float16' else ['']
|
||||
formats = ['.safetensors','.bin']
|
||||
possible_filenames = list()
|
||||
for p in precisions:
|
||||
for f in formats:
|
||||
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
|
||||
|
||||
for directory_name in short_names:
|
||||
repo_id = short_name_map[directory_name]
|
||||
safe_name = directory_name.replace('/','--')
|
||||
print(f'Downloading ControlNet model {directory_name} ({repo_id})')
|
||||
hf_download_with_resume(
|
||||
repo_id = repo_id,
|
||||
model_dir = dest_dir / safe_name,
|
||||
model_name = 'config.json',
|
||||
access_token = access_token
|
||||
)
|
||||
|
||||
path = None
|
||||
for filename in possible_filenames:
|
||||
suffix = filename.suffix
|
||||
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
|
||||
print(f'Probing {directory_name}/{filename}...')
|
||||
path = hf_download_with_resume(
|
||||
repo_id = repo_id,
|
||||
model_dir = dest_dir / safe_name,
|
||||
model_name = str(filename),
|
||||
access_token = access_token,
|
||||
model_dest = Path(dest_dir, safe_name, dest_filename),
|
||||
)
|
||||
if path:
|
||||
(path.parent / '.download_complete').touch()
|
||||
break
|
||||
|
||||
# ---------------------------------------------
|
||||
def delete_controlnet_models(short_names: List[str]):
|
||||
for name in short_names:
|
||||
safe_name = name.replace('/','--')
|
||||
directory = config.controlnet_path / safe_name
|
||||
if directory.exists():
|
||||
print(f'Purging controlnet model {name}')
|
||||
shutil.rmtree(str(directory))
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
|
||||
model_class: object, model_name: str, **kwargs
|
||||
@ -340,7 +283,7 @@ def _download_diffusion_weights(
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
pass
|
||||
else:
|
||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
||||
logger.error(f"An unexpected error occurred while downloading the model: {e})")
|
||||
if path:
|
||||
break
|
||||
return path
|
||||
@ -374,17 +317,17 @@ def hf_download_with_resume(
|
||||
if (
|
||||
resp.status_code == 416
|
||||
): # "range not satisfiable", which means nothing to return
|
||||
print(f"* {model_name}: complete file found. Skipping.")
|
||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code == 404:
|
||||
print("** File not found")
|
||||
logger.warning("File not found")
|
||||
return None
|
||||
elif resp.status_code != 200:
|
||||
print(f"** Warning: {model_name}: {resp.reason}")
|
||||
logger.warning(f"{model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
print(f"* {model_name}: partial file found. Resuming...")
|
||||
logger.info(f"{model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
print(f"* {model_name}: Downloading...")
|
||||
logger.info(f"{model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
@ -399,7 +342,7 @@ def hf_download_with_resume(
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
|
||||
@ -424,8 +367,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
||||
try:
|
||||
backup = None
|
||||
if os.path.exists(config_file):
|
||||
print(
|
||||
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
logger.warning(
|
||||
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
)
|
||||
backup = config_file.with_suffix(".yaml.orig")
|
||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
||||
@ -442,16 +385,16 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
||||
new_config.write(tmp.read())
|
||||
|
||||
except Exception as e:
|
||||
print(f"**Error creating config file {config_file}: {str(e)} **")
|
||||
logger.error(f"Error creating config file {config_file}: {str(e)}")
|
||||
if backup is not None:
|
||||
print("restoring previous config file")
|
||||
logger.info("restoring previous config file")
|
||||
## workaround, for WinError 183, see above
|
||||
if sys.platform == "win32" and config_file.is_file():
|
||||
config_file.unlink()
|
||||
backup.rename(config_file)
|
||||
return
|
||||
|
||||
print(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
logger.info(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@ -518,8 +461,8 @@ def delete_weights(model_name: str, conf_stanza: dict):
|
||||
if re.match("/VAE/", conf_stanza.get("config")):
|
||||
return
|
||||
|
||||
print(
|
||||
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||
logger.warning(
|
||||
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||
)
|
||||
|
||||
weights = Path(weights)
|
||||
@ -528,4 +471,4 @@ def delete_weights(model_name: str, conf_stanza: dict):
|
||||
try:
|
||||
weights.unlink()
|
||||
except OSError as e:
|
||||
print(str(e))
|
||||
logger.error(str(e))
|
||||
|
@ -11,6 +11,7 @@ import gc
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
@ -49,6 +50,10 @@ from ..stable_diffusion import (
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from ..install.model_install_backend import (
|
||||
Dataset_path,
|
||||
hf_download_with_resume,
|
||||
)
|
||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
@ -1316,3 +1321,74 @@ class ModelManager(object):
|
||||
return (
|
||||
os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None
|
||||
)
|
||||
|
||||
def list_controlnet_models(self)->Dict[str,bool]:
|
||||
'''Return a dict of installed controlnet models; key is repo_id or short name
|
||||
of model (defined in INITIAL_MODELS), and valule is True if installed'''
|
||||
|
||||
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
|
||||
installed_models = {x: False for x in cn_models.keys()}
|
||||
|
||||
cn_dir = self.globals.controlnet_path
|
||||
installed_cn_models = dict()
|
||||
for root, dirs, files in os.walk(cn_dir):
|
||||
for name in dirs:
|
||||
if Path(root, name, '.download_complete').exists():
|
||||
installed_models.update({name.replace('--','/'): True})
|
||||
return installed_models
|
||||
|
||||
def install_controlnet_models(self, model_names: list[str], access_token: str=None):
|
||||
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
|
||||
short_names = OmegaConf.load(Dataset_path).get('controlnet') or {}
|
||||
dest_dir = self.globals.controlnet_path
|
||||
dest_dir.mkdir(parents=True,exist_ok=True)
|
||||
|
||||
# The model file may be fp32 or fp16, and may be either a
|
||||
# .bin file or a .safetensors. We try each until we get one,
|
||||
# preferring 'fp16' if using half precision, and preferring
|
||||
# safetensors over over bin.
|
||||
precisions = ['.fp16',''] if self.precision=='float16' else ['']
|
||||
formats = ['.safetensors','.bin']
|
||||
possible_filenames = list()
|
||||
for p in precisions:
|
||||
for f in formats:
|
||||
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
|
||||
|
||||
for directory_name in model_names:
|
||||
repo_id = short_names.get(directory_name) or directory_name
|
||||
safe_name = directory_name.replace('/','--')
|
||||
self.logger.info(f'Downloading ControlNet model {directory_name} ({repo_id})')
|
||||
hf_download_with_resume(
|
||||
repo_id = repo_id,
|
||||
model_dir = dest_dir / safe_name,
|
||||
model_name = 'config.json',
|
||||
access_token = access_token
|
||||
)
|
||||
|
||||
path = None
|
||||
for filename in possible_filenames:
|
||||
suffix = filename.suffix
|
||||
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
|
||||
self.logger.info(f'Checking availability of {directory_name}/{filename}...')
|
||||
path = hf_download_with_resume(
|
||||
repo_id = repo_id,
|
||||
model_dir = dest_dir / safe_name,
|
||||
model_name = str(filename),
|
||||
access_token = access_token,
|
||||
model_dest = Path(dest_dir, safe_name, dest_filename),
|
||||
)
|
||||
if path:
|
||||
(path.parent / '.download_complete').touch()
|
||||
break
|
||||
|
||||
def delete_controlnet_models(self, model_names: List[str]):
|
||||
'''Remove the list of controlnet models'''
|
||||
for name in model_names:
|
||||
safe_name = name.replace('/','--')
|
||||
directory = self.globals.controlnet_path / safe_name
|
||||
if directory.exists():
|
||||
self.logger.info(f'Purging controlnet model {name}')
|
||||
shutil.rmtree(str(directory))
|
||||
|
||||
|
||||
|
||||
|
@ -30,11 +30,15 @@ from ...backend.install.model_install_backend import (
|
||||
default_dataset,
|
||||
install_requested_models,
|
||||
recommended_datasets,
|
||||
ModelInstallList,
|
||||
dataclass,
|
||||
)
|
||||
from ...backend import ModelManager
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from .widgets import (
|
||||
CenteredTitleText,
|
||||
MultiSelectColumns,
|
||||
SingleSelectColumns,
|
||||
OffsetButtonPress,
|
||||
TextBox,
|
||||
set_min_terminal_size,
|
||||
@ -53,12 +57,12 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
|
||||
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
|
||||
self.multipage = multipage
|
||||
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
|
||||
self.initial_models = OmegaConf.load(Dataset_path)['diffusers']
|
||||
self.control_net_models = OmegaConf.load(Dataset_path)['controlnet']
|
||||
self.installed_cn_models = self._get_installed_cn_models()
|
||||
self._add_additional_cn_models(self.control_net_models,self.installed_cn_models)
|
||||
|
||||
self.installed_cn_models = model_manager.list_controlnet_models()
|
||||
|
||||
try:
|
||||
self.existing_models = OmegaConf.load(default_config_file())
|
||||
except:
|
||||
@ -79,7 +83,7 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
[x for x in list(self.initial_models.keys()) if x in self.existing_models]
|
||||
)
|
||||
|
||||
cn_model_list = sorted(self.control_net_models.keys())
|
||||
cn_model_list = sorted(self.installed_cn_models.keys())
|
||||
|
||||
self.nextrely -= 1
|
||||
self.add_widget_intelligent(
|
||||
@ -180,16 +184,34 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
|
||||
self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="== CONTROLNET MODELS ==",
|
||||
name='_' * (window_width-5),
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
labelColor='CAUTION'
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.add_widget_intelligent(
|
||||
|
||||
self.nextrely += 1
|
||||
self.tabs = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=['ADD CONTROLNET MODELS','ADD LORA/LYCORIS MODELS', 'ADD TEXTUAL INVERSION MODELS'],
|
||||
value=0,
|
||||
columns = 4,
|
||||
max_height = 2,
|
||||
relx=8,
|
||||
scroll_exit = True,
|
||||
)
|
||||
# self.add_widget_intelligent(
|
||||
# CenteredTitleText,
|
||||
# name="== CONTROLNET MODELS ==",
|
||||
# editable=False,
|
||||
# color="CONTROL",
|
||||
# )
|
||||
top_of_table = self.nextrely
|
||||
self.cn_label_1 = self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="Select the desired ControlNet models. Unchecked models will be purged from disk.",
|
||||
name="Select the desired models to install. Unchecked models will be purged from disk.",
|
||||
editable=False,
|
||||
labelColor="CAUTION",
|
||||
)
|
||||
@ -202,14 +224,14 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
value=[
|
||||
cn_model_list.index(x)
|
||||
for x in cn_model_list
|
||||
if x in self.installed_cn_models
|
||||
if self.installed_cn_models[x]
|
||||
],
|
||||
max_height=len(cn_model_list)//columns + 1,
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
self.cn_label_2 = self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name='Additional ControlNet HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):',
|
||||
relx=4,
|
||||
@ -221,6 +243,55 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
self.additional_controlnet_ids = self.add_widget_intelligent(
|
||||
TextBox, max_height=2, scroll_exit=True, editable=True, relx=4
|
||||
)
|
||||
|
||||
bottom_of_table = self.nextrely
|
||||
self.nextrely = top_of_table
|
||||
self.lora_label_1 = self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name='LoRA/LYCORIS models to download and install (Space separated. Use shift-control-V to paste):',
|
||||
relx=4,
|
||||
color='CONTROL',
|
||||
editable=False,
|
||||
hidden=True,
|
||||
scroll_exit=True
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.loras = self.add_widget_intelligent(
|
||||
TextBox,
|
||||
max_height=2,
|
||||
scroll_exit=True,
|
||||
editable=True,
|
||||
relx=4,
|
||||
hidden=True,
|
||||
)
|
||||
self.nextrely = top_of_table
|
||||
self.ti_label_1 = self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name='Textual Inversion models to download and install (Space separated. Use shift-control-V to paste):',
|
||||
relx=4,
|
||||
color='CONTROL',
|
||||
editable=False,
|
||||
hidden=True,
|
||||
scroll_exit=True
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.tis = self.add_widget_intelligent(
|
||||
TextBox,
|
||||
max_height=2,
|
||||
scroll_exit=True,
|
||||
editable=True,
|
||||
relx=4,
|
||||
hidden=True,
|
||||
)
|
||||
self.nextrely = bottom_of_table
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name='_' * (window_width-5),
|
||||
editable=False,
|
||||
labelColor='CAUTION'
|
||||
)
|
||||
|
||||
self.cancel = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name="CANCEL",
|
||||
@ -254,6 +325,9 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
for i in [self.autoload_directory, self.autoscan_on_startup]:
|
||||
self.show_directory_fields.addVisibleWhenSelected(i)
|
||||
|
||||
# self.tabs.when_value_edited = self._toggle_tables
|
||||
self.tabs.on_changed = self._toggle_tables
|
||||
|
||||
self.show_directory_fields.when_value_edited = self._clear_scan_directory
|
||||
|
||||
def resize(self):
|
||||
@ -261,6 +335,21 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
if hasattr(self, "models_selected"):
|
||||
self.models_selected.values = self._get_starter_model_labels()
|
||||
|
||||
def _toggle_tables(self, value=None):
|
||||
selected_tab = value[0] if value else self.tabs.value[0]
|
||||
widgets = [
|
||||
[self.cn_label_1, self.cn_models_selected, self.cn_label_2, self.additional_controlnet_ids],
|
||||
[self.lora_label_1,self.loras],
|
||||
[self.ti_label_1,self.tis],
|
||||
]
|
||||
|
||||
for group in widgets:
|
||||
for w in group:
|
||||
w.hidden = True
|
||||
for w in widgets[selected_tab]:
|
||||
w.hidden = False
|
||||
self.display()
|
||||
|
||||
def _clear_scan_directory(self):
|
||||
if not self.show_directory_fields.value:
|
||||
self.autoload_directory.value = ""
|
||||
@ -361,20 +450,18 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
selections.install_models = [x for x in starter_models if x not in self.existing_models]
|
||||
selections.remove_models = [x for x in self.starter_model_list if x in self.existing_models and x not in starter_models]
|
||||
|
||||
selections.control_net_map = self.control_net_models
|
||||
selections.install_cn_models = [self.cn_models_selected.values[x]
|
||||
for x in self.cn_models_selected.value
|
||||
if self.cn_models_selected.values[x] not in self.installed_cn_models
|
||||
if not self.installed_cn_models[self.cn_models_selected.values[x]]
|
||||
]
|
||||
selections.remove_cn_models = [x
|
||||
for x in self.cn_models_selected.values
|
||||
if x in self.installed_cn_models
|
||||
if self.installed_cn_models[x]
|
||||
and self.cn_models_selected.values.index(x) not in self.cn_models_selected.value
|
||||
]
|
||||
if (additional_cns := self.additional_controlnet_ids.value.split()):
|
||||
valid_cns = [x for x in additional_cns if '/' in x]
|
||||
selections.install_cn_models.extend(valid_cns)
|
||||
selections.control_net_map.update({x: x for x in valid_cns})
|
||||
|
||||
# load directory and whether to scan on startup
|
||||
if self.show_directory_fields.value:
|
||||
@ -387,22 +474,22 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
# URLs and the like
|
||||
selections.import_model_paths = self.import_model_paths.value.split()
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserSelections():
|
||||
install_models: List[str]=None
|
||||
remove_models: List[str]=None
|
||||
purge_deleted_models: bool=False,
|
||||
install_cn_models: List[str] = None,
|
||||
remove_cn_models: List[str] = None,
|
||||
scan_directory: Path=None,
|
||||
autoscan_on_startup: bool=False,
|
||||
import_model_paths: str=None,
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.user_cancelled = False
|
||||
self.user_selections = Namespace(
|
||||
install_models=None,
|
||||
remove_models=None,
|
||||
purge_deleted_models=False,
|
||||
install_cn_models = None,
|
||||
remove_cn_models = None,
|
||||
control_net_map = None,
|
||||
scan_directory=None,
|
||||
autoscan_on_startup=None,
|
||||
import_model_paths=None,
|
||||
)
|
||||
self.user_selections = UserSelections()
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
@ -418,15 +505,9 @@ def process_and_execute(opt: Namespace, selections: Namespace):
|
||||
directory_to_scan = selections.scan_directory
|
||||
scan_at_startup = selections.autoscan_on_startup
|
||||
potential_models_to_install = selections.import_model_paths
|
||||
print(f'selections.install_cn_models={selections.install_cn_models}')
|
||||
print(f'selections.remove_cn_models={selections.remove_cn_models}')
|
||||
print(f'selections.cn_model_map={selections.control_net_map}')
|
||||
install_requested_models(
|
||||
install_initial_models=models_to_install,
|
||||
remove_models=models_to_remove,
|
||||
install_cn_models=selections.install_cn_models,
|
||||
remove_cn_models=selections.remove_cn_models,
|
||||
cn_model_map=selections.control_net_map,
|
||||
diffusers = ModelInstallList(models_to_install, models_to_remove),
|
||||
controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
|
||||
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
|
||||
external_models=potential_models_to_install,
|
||||
scan_at_startup=scan_at_startup,
|
||||
|
@ -94,13 +94,7 @@ class FloatTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = FloatSlider
|
||||
|
||||
|
||||
class MultiSelectColumns(npyscreen.MultiSelect):
|
||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
class SelectColumnBase():
|
||||
def make_contained_widgets(self):
|
||||
self._my_widgets = []
|
||||
column_width = self.width // self.columns
|
||||
@ -150,6 +144,32 @@ class MultiSelectColumns(npyscreen.MultiSelect):
|
||||
def h_cursor_line_right(self, ch):
|
||||
super().h_cursor_line_down(ch)
|
||||
|
||||
class MultiSelectColumns( SelectColumnBase, npyscreen.MultiSelect):
|
||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
|
||||
class SingleSelectColumns(SelectColumnBase, npyscreen.SelectOne):
|
||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
self.on_changed = None
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
def h_select(self,ch):
|
||||
super().h_select(ch)
|
||||
if self.on_changed:
|
||||
self.on_changed(self.value)
|
||||
|
||||
def when_value_edited(self):
|
||||
self.h_select(self.cursor_line)
|
||||
|
||||
def when_cursor_moved(self):
|
||||
self.h_select(self.cursor_line)
|
||||
|
||||
class TextBox(npyscreen.MultiLineEdit):
|
||||
def update(self, clear=True):
|
||||
|
Loading…
Reference in New Issue
Block a user