implemented tabbed model selection; not wired to backend yet

This commit is contained in:
Lincoln Stein 2023-06-01 00:31:46 -04:00
parent d6530df635
commit e9821ab711
4 changed files with 269 additions and 149 deletions

View File

@ -6,6 +6,7 @@ import re
import shutil
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryFile
from typing import List, Dict
@ -20,9 +21,8 @@ from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import get_invokeai_config
from ..model_management import ModelManager
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore")
@ -37,6 +37,9 @@ Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
# initial models omegaconf
Datasets = None
# logger
logger = InvokeAILogger.getLogger(name='InvokeAI')
Config_preamble = """
# This file describes the alternative machine learning models
# available to InvokeAI script.
@ -47,6 +50,11 @@ Config_preamble = """
# was trained on.
"""
@dataclass
class ModelInstallList:
'''Class for listing models to be installed/removed'''
install_models: List[str]
remove_models: List[str]
def default_config_file():
return config.model_conf_path
@ -61,11 +69,11 @@ def initial_models():
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
def install_requested_models(
install_initial_models: List[str] = None,
remove_models: List[str] = None,
install_cn_models: List[str] = None,
remove_cn_models: List[str] = None,
cn_model_map: Dict[str,str] = None,
diffusers: ModelInstallList = None,
controlnet: ModelInstallList = None,
lora: ModelInstallList = None,
ti: ModelInstallList = None,
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
@ -81,33 +89,30 @@ def install_requested_models(
if not config_file_path.exists():
open(config_file_path, "w")
install_controlnet_models(
install_cn_models,
short_name_map = cn_model_map,
precision=precision,
access_token=access_token,
)
delete_controlnet_models(remove_cn_models)
# prevent circular import here
from ..model_management import ModelManager
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
model_manager.delete_controlnet_models(controlnet.remove_models)
if remove_models and len(remove_models) > 0:
print("== DELETING UNCHECKED STARTER MODELS ==")
for model in remove_models:
print(f"{model}...")
# TODO: Replace next three paragraphs with calls into new model manager
if diffusers.remove_models and len(diffusers.remove_models) > 0:
logger.info("DELETING UNCHECKED STARTER MODELS")
for model in diffusers.remove_models:
logger.info(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path)
if install_initial_models and len(install_initial_models) > 0:
print("== INSTALLING SELECTED STARTER MODELS ==")
if diffusers.install_models and len(diffusers.install_models) > 0:
logger.info("INSTALLING SELECTED STARTER MODELS")
successfully_downloaded = download_weight_datasets(
models=install_initial_models,
models=diffusers.install_initial_models,
access_token=None,
precision=precision,
) # FIX: for historical reasons, we don't use model manager here
update_config_file(successfully_downloaded, config_file_path)
if len(successfully_downloaded) < len(install_initial_models):
print("** Some of the model downloads were not successful")
if len(successfully_downloaded) < len(diffusers.install_models):
logger.warning("Some of the model downloads were not successful")
# due to above, we have to reload the model manager because conf file
# was changed behind its back
@ -118,7 +123,7 @@ def install_requested_models(
external_models.append(str(scan_directory))
if len(external_models) > 0:
print("== INSTALLING EXTERNAL MODELS ==")
logger.info("INSTALLING EXTERNAL MODELS")
for path_url_or_repo in external_models:
try:
model_manager.heuristic_import(
@ -190,10 +195,10 @@ def migrate_models_ckpt():
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
return
new_name = initial_models()["stable-diffusion-1.4"]["file"]
print(
logger.warning(
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
)
print(f"model.ckpt => {new_name}")
logger.warning(f"model.ckpt => {new_name}")
os.replace(
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
)
@ -206,7 +211,7 @@ def download_weight_datasets(
migrate_models_ckpt()
successful = dict()
for mod in models:
print(f"Downloading {mod}:")
logger.info(f"Downloading {mod}:")
successful[mod] = _download_repo_or_file(
initial_models()[mod], access_token, precision=precision
)
@ -240,68 +245,6 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
)
# ---------------------------------------------
def install_controlnet_models(
short_names: List[str],
short_name_map: Dict[str,str],
precision: str='float16',
access_token: str = None,
):
'''
Download list of controlnet models, using their HuggingFace
repo_ids.
'''
dest_dir = config.controlnet_path
if not dest_dir.exists():
dest_dir.mkdir(parents=True,exist_ok=False)
# The model file may be fp32 or fp16, and may be either a
# .bin file or a .safetensors. We try each until we get one,
# preferring 'fp16' if using half precision, and preferring
# safetensors over over bin.
precisions = ['.fp16',''] if precision=='float16' else ['']
formats = ['.safetensors','.bin']
possible_filenames = list()
for p in precisions:
for f in formats:
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
for directory_name in short_names:
repo_id = short_name_map[directory_name]
safe_name = directory_name.replace('/','--')
print(f'Downloading ControlNet model {directory_name} ({repo_id})')
hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = 'config.json',
access_token = access_token
)
path = None
for filename in possible_filenames:
suffix = filename.suffix
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
print(f'Probing {directory_name}/{filename}...')
path = hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = str(filename),
access_token = access_token,
model_dest = Path(dest_dir, safe_name, dest_filename),
)
if path:
(path.parent / '.download_complete').touch()
break
# ---------------------------------------------
def delete_controlnet_models(short_names: List[str]):
for name in short_names:
safe_name = name.replace('/','--')
directory = config.controlnet_path / safe_name
if directory.exists():
print(f'Purging controlnet model {name}')
shutil.rmtree(str(directory))
# ---------------------------------------------
def download_from_hf(
model_class: object, model_name: str, **kwargs
@ -340,7 +283,7 @@ def _download_diffusion_weights(
if str(e).startswith("fp16 is not a valid"):
pass
else:
print(f"An unexpected error occurred while downloading the model: {e})")
logger.error(f"An unexpected error occurred while downloading the model: {e})")
if path:
break
return path
@ -374,17 +317,17 @@ def hf_download_with_resume(
if (
resp.status_code == 416
): # "range not satisfiable", which means nothing to return
print(f"* {model_name}: complete file found. Skipping.")
logger.info(f"{model_name}: complete file found. Skipping.")
return model_dest
elif resp.status_code == 404:
print("** File not found")
logger.warning("File not found")
return None
elif resp.status_code != 200:
print(f"** Warning: {model_name}: {resp.reason}")
logger.warning(f"{model_name}: {resp.reason}")
elif exist_size > 0:
print(f"* {model_name}: partial file found. Resuming...")
logger.info(f"{model_name}: partial file found. Resuming...")
else:
print(f"* {model_name}: Downloading...")
logger.info(f"{model_name}: Downloading...")
try:
with open(model_dest, open_mode) as file, tqdm(
@ -399,7 +342,7 @@ def hf_download_with_resume(
size = file.write(data)
bar.update(size)
except Exception as e:
print(f"An error occurred while downloading {model_name}: {str(e)}")
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
return None
return model_dest
@ -424,8 +367,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
try:
backup = None
if os.path.exists(config_file):
print(
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
logger.warning(
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
)
backup = config_file.with_suffix(".yaml.orig")
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
@ -442,16 +385,16 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
new_config.write(tmp.read())
except Exception as e:
print(f"**Error creating config file {config_file}: {str(e)} **")
logger.error(f"Error creating config file {config_file}: {str(e)}")
if backup is not None:
print("restoring previous config file")
logger.info("restoring previous config file")
## workaround, for WinError 183, see above
if sys.platform == "win32" and config_file.is_file():
config_file.unlink()
backup.rename(config_file)
return
print(f"Successfully created new configuration file {config_file}")
logger.info(f"Successfully created new configuration file {config_file}")
# ---------------------------------------------
@ -518,8 +461,8 @@ def delete_weights(model_name: str, conf_stanza: dict):
if re.match("/VAE/", conf_stanza.get("config")):
return
print(
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
logger.warning(
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
)
weights = Path(weights)
@ -528,4 +471,4 @@ def delete_weights(model_name: str, conf_stanza: dict):
try:
weights.unlink()
except OSError as e:
print(str(e))
logger.error(str(e))

View File

@ -11,6 +11,7 @@ import gc
import hashlib
import os
import re
import shutil
import sys
import textwrap
import time
@ -49,6 +50,10 @@ from ..stable_diffusion import (
StableDiffusionGeneratorPipeline,
)
from invokeai.app.services.config import get_invokeai_config
from ..install.model_install_backend import (
Dataset_path,
hf_download_with_resume,
)
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
@ -1316,3 +1321,74 @@ class ModelManager(object):
return (
os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None
)
def list_controlnet_models(self)->Dict[str,bool]:
'''Return a dict of installed controlnet models; key is repo_id or short name
of model (defined in INITIAL_MODELS), and valule is True if installed'''
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
installed_models = {x: False for x in cn_models.keys()}
cn_dir = self.globals.controlnet_path
installed_cn_models = dict()
for root, dirs, files in os.walk(cn_dir):
for name in dirs:
if Path(root, name, '.download_complete').exists():
installed_models.update({name.replace('--','/'): True})
return installed_models
def install_controlnet_models(self, model_names: list[str], access_token: str=None):
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
short_names = OmegaConf.load(Dataset_path).get('controlnet') or {}
dest_dir = self.globals.controlnet_path
dest_dir.mkdir(parents=True,exist_ok=True)
# The model file may be fp32 or fp16, and may be either a
# .bin file or a .safetensors. We try each until we get one,
# preferring 'fp16' if using half precision, and preferring
# safetensors over over bin.
precisions = ['.fp16',''] if self.precision=='float16' else ['']
formats = ['.safetensors','.bin']
possible_filenames = list()
for p in precisions:
for f in formats:
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
for directory_name in model_names:
repo_id = short_names.get(directory_name) or directory_name
safe_name = directory_name.replace('/','--')
self.logger.info(f'Downloading ControlNet model {directory_name} ({repo_id})')
hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = 'config.json',
access_token = access_token
)
path = None
for filename in possible_filenames:
suffix = filename.suffix
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
self.logger.info(f'Checking availability of {directory_name}/{filename}...')
path = hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = str(filename),
access_token = access_token,
model_dest = Path(dest_dir, safe_name, dest_filename),
)
if path:
(path.parent / '.download_complete').touch()
break
def delete_controlnet_models(self, model_names: List[str]):
'''Remove the list of controlnet models'''
for name in model_names:
safe_name = name.replace('/','--')
directory = self.globals.controlnet_path / safe_name
if directory.exists():
self.logger.info(f'Purging controlnet model {name}')
shutil.rmtree(str(directory))

View File

@ -30,11 +30,15 @@ from ...backend.install.model_install_backend import (
default_dataset,
install_requested_models,
recommended_datasets,
ModelInstallList,
dataclass,
)
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from .widgets import (
CenteredTitleText,
MultiSelectColumns,
SingleSelectColumns,
OffsetButtonPress,
TextBox,
set_min_terminal_size,
@ -53,12 +57,12 @@ class addModelsForm(npyscreen.FormMultiPage):
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
self.multipage = multipage
model_manager = ModelManager(config.model_conf_path)
self.initial_models = OmegaConf.load(Dataset_path)['diffusers']
self.control_net_models = OmegaConf.load(Dataset_path)['controlnet']
self.installed_cn_models = self._get_installed_cn_models()
self._add_additional_cn_models(self.control_net_models,self.installed_cn_models)
self.installed_cn_models = model_manager.list_controlnet_models()
try:
self.existing_models = OmegaConf.load(default_config_file())
except:
@ -79,7 +83,7 @@ class addModelsForm(npyscreen.FormMultiPage):
[x for x in list(self.initial_models.keys()) if x in self.existing_models]
)
cn_model_list = sorted(self.control_net_models.keys())
cn_model_list = sorted(self.installed_cn_models.keys())
self.nextrely -= 1
self.add_widget_intelligent(
@ -180,16 +184,34 @@ class addModelsForm(npyscreen.FormMultiPage):
relx=4,
scroll_exit=True,
)
self.add_widget_intelligent(
CenteredTitleText,
name="== CONTROLNET MODELS ==",
name='_' * (window_width-5),
editable=False,
color="CONTROL",
labelColor='CAUTION'
)
self.nextrely -= 1
self.add_widget_intelligent(
self.nextrely += 1
self.tabs = self.add_widget_intelligent(
SingleSelectColumns,
values=['ADD CONTROLNET MODELS','ADD LORA/LYCORIS MODELS', 'ADD TEXTUAL INVERSION MODELS'],
value=0,
columns = 4,
max_height = 2,
relx=8,
scroll_exit = True,
)
# self.add_widget_intelligent(
# CenteredTitleText,
# name="== CONTROLNET MODELS ==",
# editable=False,
# color="CONTROL",
# )
top_of_table = self.nextrely
self.cn_label_1 = self.add_widget_intelligent(
CenteredTitleText,
name="Select the desired ControlNet models. Unchecked models will be purged from disk.",
name="Select the desired models to install. Unchecked models will be purged from disk.",
editable=False,
labelColor="CAUTION",
)
@ -202,14 +224,14 @@ class addModelsForm(npyscreen.FormMultiPage):
value=[
cn_model_list.index(x)
for x in cn_model_list
if x in self.installed_cn_models
if self.installed_cn_models[x]
],
max_height=len(cn_model_list)//columns + 1,
relx=4,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
self.cn_label_2 = self.add_widget_intelligent(
npyscreen.TitleFixedText,
name='Additional ControlNet HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):',
relx=4,
@ -221,6 +243,55 @@ class addModelsForm(npyscreen.FormMultiPage):
self.additional_controlnet_ids = self.add_widget_intelligent(
TextBox, max_height=2, scroll_exit=True, editable=True, relx=4
)
bottom_of_table = self.nextrely
self.nextrely = top_of_table
self.lora_label_1 = self.add_widget_intelligent(
npyscreen.TitleFixedText,
name='LoRA/LYCORIS models to download and install (Space separated. Use shift-control-V to paste):',
relx=4,
color='CONTROL',
editable=False,
hidden=True,
scroll_exit=True
)
self.nextrely -= 1
self.loras = self.add_widget_intelligent(
TextBox,
max_height=2,
scroll_exit=True,
editable=True,
relx=4,
hidden=True,
)
self.nextrely = top_of_table
self.ti_label_1 = self.add_widget_intelligent(
npyscreen.TitleFixedText,
name='Textual Inversion models to download and install (Space separated. Use shift-control-V to paste):',
relx=4,
color='CONTROL',
editable=False,
hidden=True,
scroll_exit=True
)
self.nextrely -= 1
self.tis = self.add_widget_intelligent(
TextBox,
max_height=2,
scroll_exit=True,
editable=True,
relx=4,
hidden=True,
)
self.nextrely = bottom_of_table
self.nextrely += 1
self.add_widget_intelligent(
CenteredTitleText,
name='_' * (window_width-5),
editable=False,
labelColor='CAUTION'
)
self.cancel = self.add_widget_intelligent(
npyscreen.ButtonPress,
name="CANCEL",
@ -254,6 +325,9 @@ class addModelsForm(npyscreen.FormMultiPage):
for i in [self.autoload_directory, self.autoscan_on_startup]:
self.show_directory_fields.addVisibleWhenSelected(i)
# self.tabs.when_value_edited = self._toggle_tables
self.tabs.on_changed = self._toggle_tables
self.show_directory_fields.when_value_edited = self._clear_scan_directory
def resize(self):
@ -261,6 +335,21 @@ class addModelsForm(npyscreen.FormMultiPage):
if hasattr(self, "models_selected"):
self.models_selected.values = self._get_starter_model_labels()
def _toggle_tables(self, value=None):
selected_tab = value[0] if value else self.tabs.value[0]
widgets = [
[self.cn_label_1, self.cn_models_selected, self.cn_label_2, self.additional_controlnet_ids],
[self.lora_label_1,self.loras],
[self.ti_label_1,self.tis],
]
for group in widgets:
for w in group:
w.hidden = True
for w in widgets[selected_tab]:
w.hidden = False
self.display()
def _clear_scan_directory(self):
if not self.show_directory_fields.value:
self.autoload_directory.value = ""
@ -361,20 +450,18 @@ class addModelsForm(npyscreen.FormMultiPage):
selections.install_models = [x for x in starter_models if x not in self.existing_models]
selections.remove_models = [x for x in self.starter_model_list if x in self.existing_models and x not in starter_models]
selections.control_net_map = self.control_net_models
selections.install_cn_models = [self.cn_models_selected.values[x]
for x in self.cn_models_selected.value
if self.cn_models_selected.values[x] not in self.installed_cn_models
if not self.installed_cn_models[self.cn_models_selected.values[x]]
]
selections.remove_cn_models = [x
for x in self.cn_models_selected.values
if x in self.installed_cn_models
if self.installed_cn_models[x]
and self.cn_models_selected.values.index(x) not in self.cn_models_selected.value
]
if (additional_cns := self.additional_controlnet_ids.value.split()):
valid_cns = [x for x in additional_cns if '/' in x]
selections.install_cn_models.extend(valid_cns)
selections.control_net_map.update({x: x for x in valid_cns})
# load directory and whether to scan on startup
if self.show_directory_fields.value:
@ -387,22 +474,22 @@ class addModelsForm(npyscreen.FormMultiPage):
# URLs and the like
selections.import_model_paths = self.import_model_paths.value.split()
@dataclass
class UserSelections():
install_models: List[str]=None
remove_models: List[str]=None
purge_deleted_models: bool=False,
install_cn_models: List[str] = None,
remove_cn_models: List[str] = None,
scan_directory: Path=None,
autoscan_on_startup: bool=False,
import_model_paths: str=None,
class AddModelApplication(npyscreen.NPSAppManaged):
def __init__(self):
super().__init__()
self.user_cancelled = False
self.user_selections = Namespace(
install_models=None,
remove_models=None,
purge_deleted_models=False,
install_cn_models = None,
remove_cn_models = None,
control_net_map = None,
scan_directory=None,
autoscan_on_startup=None,
import_model_paths=None,
)
self.user_selections = UserSelections()
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@ -418,15 +505,9 @@ def process_and_execute(opt: Namespace, selections: Namespace):
directory_to_scan = selections.scan_directory
scan_at_startup = selections.autoscan_on_startup
potential_models_to_install = selections.import_model_paths
print(f'selections.install_cn_models={selections.install_cn_models}')
print(f'selections.remove_cn_models={selections.remove_cn_models}')
print(f'selections.cn_model_map={selections.control_net_map}')
install_requested_models(
install_initial_models=models_to_install,
remove_models=models_to_remove,
install_cn_models=selections.install_cn_models,
remove_cn_models=selections.remove_cn_models,
cn_model_map=selections.control_net_map,
diffusers = ModelInstallList(models_to_install, models_to_remove),
controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
external_models=potential_models_to_install,
scan_at_startup=scan_at_startup,

View File

@ -94,13 +94,7 @@ class FloatTitleSlider(npyscreen.TitleText):
_entry_type = FloatSlider
class MultiSelectColumns(npyscreen.MultiSelect):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
super().__init__(screen, values=values, **keywords)
class SelectColumnBase():
def make_contained_widgets(self):
self._my_widgets = []
column_width = self.width // self.columns
@ -150,6 +144,32 @@ class MultiSelectColumns(npyscreen.MultiSelect):
def h_cursor_line_right(self, ch):
super().h_cursor_line_down(ch)
class MultiSelectColumns( SelectColumnBase, npyscreen.MultiSelect):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
super().__init__(screen, values=values, **keywords)
class SingleSelectColumns(SelectColumnBase, npyscreen.SelectOne):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
self.on_changed = None
super().__init__(screen, values=values, **keywords)
def h_select(self,ch):
super().h_select(ch)
if self.on_changed:
self.on_changed(self.value)
def when_value_edited(self):
self.h_select(self.cursor_line)
def when_cursor_moved(self):
self.h_select(self.cursor_line)
class TextBox(npyscreen.MultiLineEdit):
def update(self, clear=True):