mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implemented tabbed model selection; not wired to backend yet
This commit is contained in:
parent
d6530df635
commit
e9821ab711
@ -6,6 +6,7 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryFile
|
from tempfile import TemporaryFile
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
@ -20,9 +21,8 @@ from tqdm import tqdm
|
|||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
|
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
from ..model_management import ModelManager
|
|
||||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||||
|
from ..util.logging import InvokeAILogger
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
@ -37,6 +37,9 @@ Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
|||||||
# initial models omegaconf
|
# initial models omegaconf
|
||||||
Datasets = None
|
Datasets = None
|
||||||
|
|
||||||
|
# logger
|
||||||
|
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
||||||
|
|
||||||
Config_preamble = """
|
Config_preamble = """
|
||||||
# This file describes the alternative machine learning models
|
# This file describes the alternative machine learning models
|
||||||
# available to InvokeAI script.
|
# available to InvokeAI script.
|
||||||
@ -47,6 +50,11 @@ Config_preamble = """
|
|||||||
# was trained on.
|
# was trained on.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelInstallList:
|
||||||
|
'''Class for listing models to be installed/removed'''
|
||||||
|
install_models: List[str]
|
||||||
|
remove_models: List[str]
|
||||||
|
|
||||||
def default_config_file():
|
def default_config_file():
|
||||||
return config.model_conf_path
|
return config.model_conf_path
|
||||||
@ -61,11 +69,11 @@ def initial_models():
|
|||||||
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
|
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
|
||||||
|
|
||||||
def install_requested_models(
|
def install_requested_models(
|
||||||
install_initial_models: List[str] = None,
|
diffusers: ModelInstallList = None,
|
||||||
remove_models: List[str] = None,
|
controlnet: ModelInstallList = None,
|
||||||
install_cn_models: List[str] = None,
|
lora: ModelInstallList = None,
|
||||||
remove_cn_models: List[str] = None,
|
ti: ModelInstallList = None,
|
||||||
cn_model_map: Dict[str,str] = None,
|
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
|
||||||
scan_directory: Path = None,
|
scan_directory: Path = None,
|
||||||
external_models: List[str] = None,
|
external_models: List[str] = None,
|
||||||
scan_at_startup: bool = False,
|
scan_at_startup: bool = False,
|
||||||
@ -81,33 +89,30 @@ def install_requested_models(
|
|||||||
if not config_file_path.exists():
|
if not config_file_path.exists():
|
||||||
open(config_file_path, "w")
|
open(config_file_path, "w")
|
||||||
|
|
||||||
install_controlnet_models(
|
# prevent circular import here
|
||||||
install_cn_models,
|
from ..model_management import ModelManager
|
||||||
short_name_map = cn_model_map,
|
|
||||||
precision=precision,
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
|
||||||
delete_controlnet_models(remove_cn_models)
|
|
||||||
|
|
||||||
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
|
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
|
||||||
|
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
|
||||||
|
model_manager.delete_controlnet_models(controlnet.remove_models)
|
||||||
|
|
||||||
if remove_models and len(remove_models) > 0:
|
# TODO: Replace next three paragraphs with calls into new model manager
|
||||||
print("== DELETING UNCHECKED STARTER MODELS ==")
|
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||||
for model in remove_models:
|
logger.info("DELETING UNCHECKED STARTER MODELS")
|
||||||
print(f"{model}...")
|
for model in diffusers.remove_models:
|
||||||
|
logger.info(f"{model}...")
|
||||||
model_manager.del_model(model, delete_files=purge_deleted)
|
model_manager.del_model(model, delete_files=purge_deleted)
|
||||||
model_manager.commit(config_file_path)
|
model_manager.commit(config_file_path)
|
||||||
|
|
||||||
if install_initial_models and len(install_initial_models) > 0:
|
if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||||
print("== INSTALLING SELECTED STARTER MODELS ==")
|
logger.info("INSTALLING SELECTED STARTER MODELS")
|
||||||
successfully_downloaded = download_weight_datasets(
|
successfully_downloaded = download_weight_datasets(
|
||||||
models=install_initial_models,
|
models=diffusers.install_initial_models,
|
||||||
access_token=None,
|
access_token=None,
|
||||||
precision=precision,
|
precision=precision,
|
||||||
) # FIX: for historical reasons, we don't use model manager here
|
) # FIX: for historical reasons, we don't use model manager here
|
||||||
update_config_file(successfully_downloaded, config_file_path)
|
update_config_file(successfully_downloaded, config_file_path)
|
||||||
if len(successfully_downloaded) < len(install_initial_models):
|
if len(successfully_downloaded) < len(diffusers.install_models):
|
||||||
print("** Some of the model downloads were not successful")
|
logger.warning("Some of the model downloads were not successful")
|
||||||
|
|
||||||
# due to above, we have to reload the model manager because conf file
|
# due to above, we have to reload the model manager because conf file
|
||||||
# was changed behind its back
|
# was changed behind its back
|
||||||
@ -118,7 +123,7 @@ def install_requested_models(
|
|||||||
external_models.append(str(scan_directory))
|
external_models.append(str(scan_directory))
|
||||||
|
|
||||||
if len(external_models) > 0:
|
if len(external_models) > 0:
|
||||||
print("== INSTALLING EXTERNAL MODELS ==")
|
logger.info("INSTALLING EXTERNAL MODELS")
|
||||||
for path_url_or_repo in external_models:
|
for path_url_or_repo in external_models:
|
||||||
try:
|
try:
|
||||||
model_manager.heuristic_import(
|
model_manager.heuristic_import(
|
||||||
@ -190,10 +195,10 @@ def migrate_models_ckpt():
|
|||||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
||||||
return
|
return
|
||||||
new_name = initial_models()["stable-diffusion-1.4"]["file"]
|
new_name = initial_models()["stable-diffusion-1.4"]["file"]
|
||||||
print(
|
logger.warning(
|
||||||
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
|
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
|
||||||
)
|
)
|
||||||
print(f"model.ckpt => {new_name}")
|
logger.warning(f"model.ckpt => {new_name}")
|
||||||
os.replace(
|
os.replace(
|
||||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
||||||
)
|
)
|
||||||
@ -206,7 +211,7 @@ def download_weight_datasets(
|
|||||||
migrate_models_ckpt()
|
migrate_models_ckpt()
|
||||||
successful = dict()
|
successful = dict()
|
||||||
for mod in models:
|
for mod in models:
|
||||||
print(f"Downloading {mod}:")
|
logger.info(f"Downloading {mod}:")
|
||||||
successful[mod] = _download_repo_or_file(
|
successful[mod] = _download_repo_or_file(
|
||||||
initial_models()[mod], access_token, precision=precision
|
initial_models()[mod], access_token, precision=precision
|
||||||
)
|
)
|
||||||
@ -240,68 +245,6 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def install_controlnet_models(
|
|
||||||
short_names: List[str],
|
|
||||||
short_name_map: Dict[str,str],
|
|
||||||
precision: str='float16',
|
|
||||||
access_token: str = None,
|
|
||||||
):
|
|
||||||
'''
|
|
||||||
Download list of controlnet models, using their HuggingFace
|
|
||||||
repo_ids.
|
|
||||||
'''
|
|
||||||
dest_dir = config.controlnet_path
|
|
||||||
if not dest_dir.exists():
|
|
||||||
dest_dir.mkdir(parents=True,exist_ok=False)
|
|
||||||
|
|
||||||
# The model file may be fp32 or fp16, and may be either a
|
|
||||||
# .bin file or a .safetensors. We try each until we get one,
|
|
||||||
# preferring 'fp16' if using half precision, and preferring
|
|
||||||
# safetensors over over bin.
|
|
||||||
precisions = ['.fp16',''] if precision=='float16' else ['']
|
|
||||||
formats = ['.safetensors','.bin']
|
|
||||||
possible_filenames = list()
|
|
||||||
for p in precisions:
|
|
||||||
for f in formats:
|
|
||||||
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
|
|
||||||
|
|
||||||
for directory_name in short_names:
|
|
||||||
repo_id = short_name_map[directory_name]
|
|
||||||
safe_name = directory_name.replace('/','--')
|
|
||||||
print(f'Downloading ControlNet model {directory_name} ({repo_id})')
|
|
||||||
hf_download_with_resume(
|
|
||||||
repo_id = repo_id,
|
|
||||||
model_dir = dest_dir / safe_name,
|
|
||||||
model_name = 'config.json',
|
|
||||||
access_token = access_token
|
|
||||||
)
|
|
||||||
|
|
||||||
path = None
|
|
||||||
for filename in possible_filenames:
|
|
||||||
suffix = filename.suffix
|
|
||||||
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
|
|
||||||
print(f'Probing {directory_name}/{filename}...')
|
|
||||||
path = hf_download_with_resume(
|
|
||||||
repo_id = repo_id,
|
|
||||||
model_dir = dest_dir / safe_name,
|
|
||||||
model_name = str(filename),
|
|
||||||
access_token = access_token,
|
|
||||||
model_dest = Path(dest_dir, safe_name, dest_filename),
|
|
||||||
)
|
|
||||||
if path:
|
|
||||||
(path.parent / '.download_complete').touch()
|
|
||||||
break
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def delete_controlnet_models(short_names: List[str]):
|
|
||||||
for name in short_names:
|
|
||||||
safe_name = name.replace('/','--')
|
|
||||||
directory = config.controlnet_path / safe_name
|
|
||||||
if directory.exists():
|
|
||||||
print(f'Purging controlnet model {name}')
|
|
||||||
shutil.rmtree(str(directory))
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_from_hf(
|
def download_from_hf(
|
||||||
model_class: object, model_name: str, **kwargs
|
model_class: object, model_name: str, **kwargs
|
||||||
@ -340,7 +283,7 @@ def _download_diffusion_weights(
|
|||||||
if str(e).startswith("fp16 is not a valid"):
|
if str(e).startswith("fp16 is not a valid"):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
logger.error(f"An unexpected error occurred while downloading the model: {e})")
|
||||||
if path:
|
if path:
|
||||||
break
|
break
|
||||||
return path
|
return path
|
||||||
@ -374,17 +317,17 @@ def hf_download_with_resume(
|
|||||||
if (
|
if (
|
||||||
resp.status_code == 416
|
resp.status_code == 416
|
||||||
): # "range not satisfiable", which means nothing to return
|
): # "range not satisfiable", which means nothing to return
|
||||||
print(f"* {model_name}: complete file found. Skipping.")
|
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||||
return model_dest
|
return model_dest
|
||||||
elif resp.status_code == 404:
|
elif resp.status_code == 404:
|
||||||
print("** File not found")
|
logger.warning("File not found")
|
||||||
return None
|
return None
|
||||||
elif resp.status_code != 200:
|
elif resp.status_code != 200:
|
||||||
print(f"** Warning: {model_name}: {resp.reason}")
|
logger.warning(f"{model_name}: {resp.reason}")
|
||||||
elif exist_size > 0:
|
elif exist_size > 0:
|
||||||
print(f"* {model_name}: partial file found. Resuming...")
|
logger.info(f"{model_name}: partial file found. Resuming...")
|
||||||
else:
|
else:
|
||||||
print(f"* {model_name}: Downloading...")
|
logger.info(f"{model_name}: Downloading...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(model_dest, open_mode) as file, tqdm(
|
with open(model_dest, open_mode) as file, tqdm(
|
||||||
@ -399,7 +342,7 @@ def hf_download_with_resume(
|
|||||||
size = file.write(data)
|
size = file.write(data)
|
||||||
bar.update(size)
|
bar.update(size)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred while downloading {model_name}: {str(e)}")
|
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
return model_dest
|
return model_dest
|
||||||
|
|
||||||
@ -424,8 +367,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
|||||||
try:
|
try:
|
||||||
backup = None
|
backup = None
|
||||||
if os.path.exists(config_file):
|
if os.path.exists(config_file):
|
||||||
print(
|
logger.warning(
|
||||||
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||||
)
|
)
|
||||||
backup = config_file.with_suffix(".yaml.orig")
|
backup = config_file.with_suffix(".yaml.orig")
|
||||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
||||||
@ -442,16 +385,16 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
|||||||
new_config.write(tmp.read())
|
new_config.write(tmp.read())
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"**Error creating config file {config_file}: {str(e)} **")
|
logger.error(f"Error creating config file {config_file}: {str(e)}")
|
||||||
if backup is not None:
|
if backup is not None:
|
||||||
print("restoring previous config file")
|
logger.info("restoring previous config file")
|
||||||
## workaround, for WinError 183, see above
|
## workaround, for WinError 183, see above
|
||||||
if sys.platform == "win32" and config_file.is_file():
|
if sys.platform == "win32" and config_file.is_file():
|
||||||
config_file.unlink()
|
config_file.unlink()
|
||||||
backup.rename(config_file)
|
backup.rename(config_file)
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"Successfully created new configuration file {config_file}")
|
logger.info(f"Successfully created new configuration file {config_file}")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
@ -518,8 +461,8 @@ def delete_weights(model_name: str, conf_stanza: dict):
|
|||||||
if re.match("/VAE/", conf_stanza.get("config")):
|
if re.match("/VAE/", conf_stanza.get("config")):
|
||||||
return
|
return
|
||||||
|
|
||||||
print(
|
logger.warning(
|
||||||
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||||
)
|
)
|
||||||
|
|
||||||
weights = Path(weights)
|
weights = Path(weights)
|
||||||
@ -528,4 +471,4 @@ def delete_weights(model_name: str, conf_stanza: dict):
|
|||||||
try:
|
try:
|
||||||
weights.unlink()
|
weights.unlink()
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
print(str(e))
|
logger.error(str(e))
|
||||||
|
@ -11,6 +11,7 @@ import gc
|
|||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
import time
|
import time
|
||||||
@ -49,6 +50,10 @@ from ..stable_diffusion import (
|
|||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
|
from ..install.model_install_backend import (
|
||||||
|
Dataset_path,
|
||||||
|
hf_download_with_resume,
|
||||||
|
)
|
||||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||||
|
|
||||||
class SDLegacyType(Enum):
|
class SDLegacyType(Enum):
|
||||||
@ -1316,3 +1321,74 @@ class ModelManager(object):
|
|||||||
return (
|
return (
|
||||||
os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None
|
os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def list_controlnet_models(self)->Dict[str,bool]:
|
||||||
|
'''Return a dict of installed controlnet models; key is repo_id or short name
|
||||||
|
of model (defined in INITIAL_MODELS), and valule is True if installed'''
|
||||||
|
|
||||||
|
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
|
||||||
|
installed_models = {x: False for x in cn_models.keys()}
|
||||||
|
|
||||||
|
cn_dir = self.globals.controlnet_path
|
||||||
|
installed_cn_models = dict()
|
||||||
|
for root, dirs, files in os.walk(cn_dir):
|
||||||
|
for name in dirs:
|
||||||
|
if Path(root, name, '.download_complete').exists():
|
||||||
|
installed_models.update({name.replace('--','/'): True})
|
||||||
|
return installed_models
|
||||||
|
|
||||||
|
def install_controlnet_models(self, model_names: list[str], access_token: str=None):
|
||||||
|
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
|
||||||
|
short_names = OmegaConf.load(Dataset_path).get('controlnet') or {}
|
||||||
|
dest_dir = self.globals.controlnet_path
|
||||||
|
dest_dir.mkdir(parents=True,exist_ok=True)
|
||||||
|
|
||||||
|
# The model file may be fp32 or fp16, and may be either a
|
||||||
|
# .bin file or a .safetensors. We try each until we get one,
|
||||||
|
# preferring 'fp16' if using half precision, and preferring
|
||||||
|
# safetensors over over bin.
|
||||||
|
precisions = ['.fp16',''] if self.precision=='float16' else ['']
|
||||||
|
formats = ['.safetensors','.bin']
|
||||||
|
possible_filenames = list()
|
||||||
|
for p in precisions:
|
||||||
|
for f in formats:
|
||||||
|
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
|
||||||
|
|
||||||
|
for directory_name in model_names:
|
||||||
|
repo_id = short_names.get(directory_name) or directory_name
|
||||||
|
safe_name = directory_name.replace('/','--')
|
||||||
|
self.logger.info(f'Downloading ControlNet model {directory_name} ({repo_id})')
|
||||||
|
hf_download_with_resume(
|
||||||
|
repo_id = repo_id,
|
||||||
|
model_dir = dest_dir / safe_name,
|
||||||
|
model_name = 'config.json',
|
||||||
|
access_token = access_token
|
||||||
|
)
|
||||||
|
|
||||||
|
path = None
|
||||||
|
for filename in possible_filenames:
|
||||||
|
suffix = filename.suffix
|
||||||
|
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
|
||||||
|
self.logger.info(f'Checking availability of {directory_name}/{filename}...')
|
||||||
|
path = hf_download_with_resume(
|
||||||
|
repo_id = repo_id,
|
||||||
|
model_dir = dest_dir / safe_name,
|
||||||
|
model_name = str(filename),
|
||||||
|
access_token = access_token,
|
||||||
|
model_dest = Path(dest_dir, safe_name, dest_filename),
|
||||||
|
)
|
||||||
|
if path:
|
||||||
|
(path.parent / '.download_complete').touch()
|
||||||
|
break
|
||||||
|
|
||||||
|
def delete_controlnet_models(self, model_names: List[str]):
|
||||||
|
'''Remove the list of controlnet models'''
|
||||||
|
for name in model_names:
|
||||||
|
safe_name = name.replace('/','--')
|
||||||
|
directory = self.globals.controlnet_path / safe_name
|
||||||
|
if directory.exists():
|
||||||
|
self.logger.info(f'Purging controlnet model {name}')
|
||||||
|
shutil.rmtree(str(directory))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,11 +30,15 @@ from ...backend.install.model_install_backend import (
|
|||||||
default_dataset,
|
default_dataset,
|
||||||
install_requested_models,
|
install_requested_models,
|
||||||
recommended_datasets,
|
recommended_datasets,
|
||||||
|
ModelInstallList,
|
||||||
|
dataclass,
|
||||||
)
|
)
|
||||||
|
from ...backend import ModelManager
|
||||||
from ...backend.util import choose_precision, choose_torch_device
|
from ...backend.util import choose_precision, choose_torch_device
|
||||||
from .widgets import (
|
from .widgets import (
|
||||||
CenteredTitleText,
|
CenteredTitleText,
|
||||||
MultiSelectColumns,
|
MultiSelectColumns,
|
||||||
|
SingleSelectColumns,
|
||||||
OffsetButtonPress,
|
OffsetButtonPress,
|
||||||
TextBox,
|
TextBox,
|
||||||
set_min_terminal_size,
|
set_min_terminal_size,
|
||||||
@ -53,12 +57,12 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
|
|
||||||
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
|
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
|
||||||
self.multipage = multipage
|
self.multipage = multipage
|
||||||
|
|
||||||
|
model_manager = ModelManager(config.model_conf_path)
|
||||||
|
|
||||||
self.initial_models = OmegaConf.load(Dataset_path)['diffusers']
|
self.initial_models = OmegaConf.load(Dataset_path)['diffusers']
|
||||||
self.control_net_models = OmegaConf.load(Dataset_path)['controlnet']
|
self.installed_cn_models = model_manager.list_controlnet_models()
|
||||||
self.installed_cn_models = self._get_installed_cn_models()
|
|
||||||
self._add_additional_cn_models(self.control_net_models,self.installed_cn_models)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.existing_models = OmegaConf.load(default_config_file())
|
self.existing_models = OmegaConf.load(default_config_file())
|
||||||
except:
|
except:
|
||||||
@ -79,7 +83,7 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
[x for x in list(self.initial_models.keys()) if x in self.existing_models]
|
[x for x in list(self.initial_models.keys()) if x in self.existing_models]
|
||||||
)
|
)
|
||||||
|
|
||||||
cn_model_list = sorted(self.control_net_models.keys())
|
cn_model_list = sorted(self.installed_cn_models.keys())
|
||||||
|
|
||||||
self.nextrely -= 1
|
self.nextrely -= 1
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
@ -180,16 +184,34 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
relx=4,
|
relx=4,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
CenteredTitleText,
|
CenteredTitleText,
|
||||||
name="== CONTROLNET MODELS ==",
|
name='_' * (window_width-5),
|
||||||
editable=False,
|
editable=False,
|
||||||
color="CONTROL",
|
labelColor='CAUTION'
|
||||||
)
|
)
|
||||||
self.nextrely -= 1
|
|
||||||
self.add_widget_intelligent(
|
self.nextrely += 1
|
||||||
|
self.tabs = self.add_widget_intelligent(
|
||||||
|
SingleSelectColumns,
|
||||||
|
values=['ADD CONTROLNET MODELS','ADD LORA/LYCORIS MODELS', 'ADD TEXTUAL INVERSION MODELS'],
|
||||||
|
value=0,
|
||||||
|
columns = 4,
|
||||||
|
max_height = 2,
|
||||||
|
relx=8,
|
||||||
|
scroll_exit = True,
|
||||||
|
)
|
||||||
|
# self.add_widget_intelligent(
|
||||||
|
# CenteredTitleText,
|
||||||
|
# name="== CONTROLNET MODELS ==",
|
||||||
|
# editable=False,
|
||||||
|
# color="CONTROL",
|
||||||
|
# )
|
||||||
|
top_of_table = self.nextrely
|
||||||
|
self.cn_label_1 = self.add_widget_intelligent(
|
||||||
CenteredTitleText,
|
CenteredTitleText,
|
||||||
name="Select the desired ControlNet models. Unchecked models will be purged from disk.",
|
name="Select the desired models to install. Unchecked models will be purged from disk.",
|
||||||
editable=False,
|
editable=False,
|
||||||
labelColor="CAUTION",
|
labelColor="CAUTION",
|
||||||
)
|
)
|
||||||
@ -202,14 +224,14 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
value=[
|
value=[
|
||||||
cn_model_list.index(x)
|
cn_model_list.index(x)
|
||||||
for x in cn_model_list
|
for x in cn_model_list
|
||||||
if x in self.installed_cn_models
|
if self.installed_cn_models[x]
|
||||||
],
|
],
|
||||||
max_height=len(cn_model_list)//columns + 1,
|
max_height=len(cn_model_list)//columns + 1,
|
||||||
relx=4,
|
relx=4,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
self.add_widget_intelligent(
|
self.cn_label_2 = self.add_widget_intelligent(
|
||||||
npyscreen.TitleFixedText,
|
npyscreen.TitleFixedText,
|
||||||
name='Additional ControlNet HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):',
|
name='Additional ControlNet HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):',
|
||||||
relx=4,
|
relx=4,
|
||||||
@ -221,6 +243,55 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
self.additional_controlnet_ids = self.add_widget_intelligent(
|
self.additional_controlnet_ids = self.add_widget_intelligent(
|
||||||
TextBox, max_height=2, scroll_exit=True, editable=True, relx=4
|
TextBox, max_height=2, scroll_exit=True, editable=True, relx=4
|
||||||
)
|
)
|
||||||
|
|
||||||
|
bottom_of_table = self.nextrely
|
||||||
|
self.nextrely = top_of_table
|
||||||
|
self.lora_label_1 = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleFixedText,
|
||||||
|
name='LoRA/LYCORIS models to download and install (Space separated. Use shift-control-V to paste):',
|
||||||
|
relx=4,
|
||||||
|
color='CONTROL',
|
||||||
|
editable=False,
|
||||||
|
hidden=True,
|
||||||
|
scroll_exit=True
|
||||||
|
)
|
||||||
|
self.nextrely -= 1
|
||||||
|
self.loras = self.add_widget_intelligent(
|
||||||
|
TextBox,
|
||||||
|
max_height=2,
|
||||||
|
scroll_exit=True,
|
||||||
|
editable=True,
|
||||||
|
relx=4,
|
||||||
|
hidden=True,
|
||||||
|
)
|
||||||
|
self.nextrely = top_of_table
|
||||||
|
self.ti_label_1 = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleFixedText,
|
||||||
|
name='Textual Inversion models to download and install (Space separated. Use shift-control-V to paste):',
|
||||||
|
relx=4,
|
||||||
|
color='CONTROL',
|
||||||
|
editable=False,
|
||||||
|
hidden=True,
|
||||||
|
scroll_exit=True
|
||||||
|
)
|
||||||
|
self.nextrely -= 1
|
||||||
|
self.tis = self.add_widget_intelligent(
|
||||||
|
TextBox,
|
||||||
|
max_height=2,
|
||||||
|
scroll_exit=True,
|
||||||
|
editable=True,
|
||||||
|
relx=4,
|
||||||
|
hidden=True,
|
||||||
|
)
|
||||||
|
self.nextrely = bottom_of_table
|
||||||
|
self.nextrely += 1
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
CenteredTitleText,
|
||||||
|
name='_' * (window_width-5),
|
||||||
|
editable=False,
|
||||||
|
labelColor='CAUTION'
|
||||||
|
)
|
||||||
|
|
||||||
self.cancel = self.add_widget_intelligent(
|
self.cancel = self.add_widget_intelligent(
|
||||||
npyscreen.ButtonPress,
|
npyscreen.ButtonPress,
|
||||||
name="CANCEL",
|
name="CANCEL",
|
||||||
@ -254,6 +325,9 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
for i in [self.autoload_directory, self.autoscan_on_startup]:
|
for i in [self.autoload_directory, self.autoscan_on_startup]:
|
||||||
self.show_directory_fields.addVisibleWhenSelected(i)
|
self.show_directory_fields.addVisibleWhenSelected(i)
|
||||||
|
|
||||||
|
# self.tabs.when_value_edited = self._toggle_tables
|
||||||
|
self.tabs.on_changed = self._toggle_tables
|
||||||
|
|
||||||
self.show_directory_fields.when_value_edited = self._clear_scan_directory
|
self.show_directory_fields.when_value_edited = self._clear_scan_directory
|
||||||
|
|
||||||
def resize(self):
|
def resize(self):
|
||||||
@ -261,6 +335,21 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
if hasattr(self, "models_selected"):
|
if hasattr(self, "models_selected"):
|
||||||
self.models_selected.values = self._get_starter_model_labels()
|
self.models_selected.values = self._get_starter_model_labels()
|
||||||
|
|
||||||
|
def _toggle_tables(self, value=None):
|
||||||
|
selected_tab = value[0] if value else self.tabs.value[0]
|
||||||
|
widgets = [
|
||||||
|
[self.cn_label_1, self.cn_models_selected, self.cn_label_2, self.additional_controlnet_ids],
|
||||||
|
[self.lora_label_1,self.loras],
|
||||||
|
[self.ti_label_1,self.tis],
|
||||||
|
]
|
||||||
|
|
||||||
|
for group in widgets:
|
||||||
|
for w in group:
|
||||||
|
w.hidden = True
|
||||||
|
for w in widgets[selected_tab]:
|
||||||
|
w.hidden = False
|
||||||
|
self.display()
|
||||||
|
|
||||||
def _clear_scan_directory(self):
|
def _clear_scan_directory(self):
|
||||||
if not self.show_directory_fields.value:
|
if not self.show_directory_fields.value:
|
||||||
self.autoload_directory.value = ""
|
self.autoload_directory.value = ""
|
||||||
@ -361,20 +450,18 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
selections.install_models = [x for x in starter_models if x not in self.existing_models]
|
selections.install_models = [x for x in starter_models if x not in self.existing_models]
|
||||||
selections.remove_models = [x for x in self.starter_model_list if x in self.existing_models and x not in starter_models]
|
selections.remove_models = [x for x in self.starter_model_list if x in self.existing_models and x not in starter_models]
|
||||||
|
|
||||||
selections.control_net_map = self.control_net_models
|
|
||||||
selections.install_cn_models = [self.cn_models_selected.values[x]
|
selections.install_cn_models = [self.cn_models_selected.values[x]
|
||||||
for x in self.cn_models_selected.value
|
for x in self.cn_models_selected.value
|
||||||
if self.cn_models_selected.values[x] not in self.installed_cn_models
|
if not self.installed_cn_models[self.cn_models_selected.values[x]]
|
||||||
]
|
]
|
||||||
selections.remove_cn_models = [x
|
selections.remove_cn_models = [x
|
||||||
for x in self.cn_models_selected.values
|
for x in self.cn_models_selected.values
|
||||||
if x in self.installed_cn_models
|
if self.installed_cn_models[x]
|
||||||
and self.cn_models_selected.values.index(x) not in self.cn_models_selected.value
|
and self.cn_models_selected.values.index(x) not in self.cn_models_selected.value
|
||||||
]
|
]
|
||||||
if (additional_cns := self.additional_controlnet_ids.value.split()):
|
if (additional_cns := self.additional_controlnet_ids.value.split()):
|
||||||
valid_cns = [x for x in additional_cns if '/' in x]
|
valid_cns = [x for x in additional_cns if '/' in x]
|
||||||
selections.install_cn_models.extend(valid_cns)
|
selections.install_cn_models.extend(valid_cns)
|
||||||
selections.control_net_map.update({x: x for x in valid_cns})
|
|
||||||
|
|
||||||
# load directory and whether to scan on startup
|
# load directory and whether to scan on startup
|
||||||
if self.show_directory_fields.value:
|
if self.show_directory_fields.value:
|
||||||
@ -387,22 +474,22 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
# URLs and the like
|
# URLs and the like
|
||||||
selections.import_model_paths = self.import_model_paths.value.split()
|
selections.import_model_paths = self.import_model_paths.value.split()
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UserSelections():
|
||||||
|
install_models: List[str]=None
|
||||||
|
remove_models: List[str]=None
|
||||||
|
purge_deleted_models: bool=False,
|
||||||
|
install_cn_models: List[str] = None,
|
||||||
|
remove_cn_models: List[str] = None,
|
||||||
|
scan_directory: Path=None,
|
||||||
|
autoscan_on_startup: bool=False,
|
||||||
|
import_model_paths: str=None,
|
||||||
|
|
||||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.user_cancelled = False
|
self.user_cancelled = False
|
||||||
self.user_selections = Namespace(
|
self.user_selections = UserSelections()
|
||||||
install_models=None,
|
|
||||||
remove_models=None,
|
|
||||||
purge_deleted_models=False,
|
|
||||||
install_cn_models = None,
|
|
||||||
remove_cn_models = None,
|
|
||||||
control_net_map = None,
|
|
||||||
scan_directory=None,
|
|
||||||
autoscan_on_startup=None,
|
|
||||||
import_model_paths=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||||
@ -418,15 +505,9 @@ def process_and_execute(opt: Namespace, selections: Namespace):
|
|||||||
directory_to_scan = selections.scan_directory
|
directory_to_scan = selections.scan_directory
|
||||||
scan_at_startup = selections.autoscan_on_startup
|
scan_at_startup = selections.autoscan_on_startup
|
||||||
potential_models_to_install = selections.import_model_paths
|
potential_models_to_install = selections.import_model_paths
|
||||||
print(f'selections.install_cn_models={selections.install_cn_models}')
|
|
||||||
print(f'selections.remove_cn_models={selections.remove_cn_models}')
|
|
||||||
print(f'selections.cn_model_map={selections.control_net_map}')
|
|
||||||
install_requested_models(
|
install_requested_models(
|
||||||
install_initial_models=models_to_install,
|
diffusers = ModelInstallList(models_to_install, models_to_remove),
|
||||||
remove_models=models_to_remove,
|
controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
|
||||||
install_cn_models=selections.install_cn_models,
|
|
||||||
remove_cn_models=selections.remove_cn_models,
|
|
||||||
cn_model_map=selections.control_net_map,
|
|
||||||
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
|
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
|
||||||
external_models=potential_models_to_install,
|
external_models=potential_models_to_install,
|
||||||
scan_at_startup=scan_at_startup,
|
scan_at_startup=scan_at_startup,
|
||||||
|
@ -94,13 +94,7 @@ class FloatTitleSlider(npyscreen.TitleText):
|
|||||||
_entry_type = FloatSlider
|
_entry_type = FloatSlider
|
||||||
|
|
||||||
|
|
||||||
class MultiSelectColumns(npyscreen.MultiSelect):
|
class SelectColumnBase():
|
||||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
|
||||||
self.columns = columns
|
|
||||||
self.value_cnt = len(values)
|
|
||||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
|
||||||
super().__init__(screen, values=values, **keywords)
|
|
||||||
|
|
||||||
def make_contained_widgets(self):
|
def make_contained_widgets(self):
|
||||||
self._my_widgets = []
|
self._my_widgets = []
|
||||||
column_width = self.width // self.columns
|
column_width = self.width // self.columns
|
||||||
@ -150,6 +144,32 @@ class MultiSelectColumns(npyscreen.MultiSelect):
|
|||||||
def h_cursor_line_right(self, ch):
|
def h_cursor_line_right(self, ch):
|
||||||
super().h_cursor_line_down(ch)
|
super().h_cursor_line_down(ch)
|
||||||
|
|
||||||
|
class MultiSelectColumns( SelectColumnBase, npyscreen.MultiSelect):
|
||||||
|
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||||
|
self.columns = columns
|
||||||
|
self.value_cnt = len(values)
|
||||||
|
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||||
|
super().__init__(screen, values=values, **keywords)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleSelectColumns(SelectColumnBase, npyscreen.SelectOne):
|
||||||
|
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||||
|
self.columns = columns
|
||||||
|
self.value_cnt = len(values)
|
||||||
|
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||||
|
self.on_changed = None
|
||||||
|
super().__init__(screen, values=values, **keywords)
|
||||||
|
|
||||||
|
def h_select(self,ch):
|
||||||
|
super().h_select(ch)
|
||||||
|
if self.on_changed:
|
||||||
|
self.on_changed(self.value)
|
||||||
|
|
||||||
|
def when_value_edited(self):
|
||||||
|
self.h_select(self.cursor_line)
|
||||||
|
|
||||||
|
def when_cursor_moved(self):
|
||||||
|
self.h_select(self.cursor_line)
|
||||||
|
|
||||||
class TextBox(npyscreen.MultiLineEdit):
|
class TextBox(npyscreen.MultiLineEdit):
|
||||||
def update(self, clear=True):
|
def update(self, clear=True):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user