new TUI is fully functional; needs some polishing

This commit is contained in:
Lincoln Stein 2023-06-02 17:20:50 -04:00
parent 41f7758977
commit 1390b65a9c
5 changed files with 186 additions and 71 deletions

View File

@ -95,9 +95,12 @@ def install_requested_models(
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token) model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
model_manager.delete_controlnet_models(controlnet.remove_models) model_manager.delete_controlnet_models(controlnet.remove_models)
model_manager.install_lora_models(lora.install_models) model_manager.install_lora_models(lora.install_models, access_token=access_token)
model_manager.delete_lora_models(lora.remove_models) model_manager.delete_lora_models(lora.remove_models)
model_manager.install_ti_models(ti.install_models, access_token=access_token)
model_manager.delete_ti_models(ti.remove_models)
# TODO: Replace next three paragraphs with calls into new model manager # TODO: Replace next three paragraphs with calls into new model manager
if diffusers.remove_models and len(diffusers.remove_models) > 0: if diffusers.remove_models and len(diffusers.remove_models) > 0:
logger.info("DELETING UNCHECKED STARTER MODELS") logger.info("DELETING UNCHECKED STARTER MODELS")
@ -109,7 +112,7 @@ def install_requested_models(
if diffusers.install_models and len(diffusers.install_models) > 0: if diffusers.install_models and len(diffusers.install_models) > 0:
logger.info("INSTALLING SELECTED STARTER MODELS") logger.info("INSTALLING SELECTED STARTER MODELS")
successfully_downloaded = download_weight_datasets( successfully_downloaded = download_weight_datasets(
models=diffusers.install_initial_models, models=diffusers.install_models,
access_token=None, access_token=None,
precision=precision, precision=precision,
) # FIX: for historical reasons, we don't use model manager here ) # FIX: for historical reasons, we don't use model manager here

View File

@ -1334,24 +1334,54 @@ class ModelManager(object):
installed_models = dict() installed_models = dict()
for root, dirs, files in os.walk(dir): for root, dirs, files in os.walk(dir):
for name in files: for name in files:
if Path(name).suffix in ['.safetensors','.ckpt','.pt']: if Path(name).suffix not in ['.safetensors','.ckpt','.pt','.bin']:
installed_models.update({name: True}) continue
if name == 'pytorch_lora_weights.bin':
name = Path(root,name).parent.stem #Path(root,name).stem
else:
name = Path(name).stem
installed_models.update({name: True})
return installed_models return installed_models
def install_lora_models(self, model_names: list[str]): def install_lora_models(self, model_names: list[str], access_token:str=None):
'''Download list of LoRA/LyCORIS models''' '''Download list of LoRA/LyCORIS models'''
short_names = OmegaConf.load(Dataset_path).get('lora') or {} short_names = OmegaConf.load(Dataset_path).get('lora') or {}
for name in model_names: for name in model_names:
url = short_names.get(name) or name print(name)
download_with_resume(url, self.globals.lora_path)
name = short_names.get(name) or name
# HuggingFace style LoRA
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
self.logger.info(f'Downloading LoRA/LyCORIS model {name}')
_,dest_dir = name.split("/")
hf_download_with_resume(
repo_id = name,
model_dir = self.globals.lora_path / dest_dir,
model_name = 'pytorch_lora_weights.bin',
access_token = access_token,
)
elif name.startswith(("http:", "https:", "ftp:")):
download_with_resume(name, self.globals.lora_path)
else:
self.logger.error(f"Unknown repo_id or URL: {name}")
def delete_lora_models(self, model_names: List[str]): def delete_lora_models(self, model_names: List[str]):
'''Remove the list of lora models''' '''Remove the list of lora models'''
for name in model_names: for name in model_names:
path = self.globals.lora_path / name file_or_directory = self.globals.lora_path / name
if path.exists(): if file_or_directory.is_dir():
self.logger.info(f'Purging lora model {name}') self.logger.info(f'Purging LoRA/LyCORIS {name}')
path.unlink() shutil.rmtree(str(file_or_directory))
else:
for path in self.globals.lora_path.glob(f'{name}.*'):
self.logger.info(f'Purging LoRA/LyCORIS {name}')
path.unlink()
def list_ti_models(self)->Dict[str,bool]: def list_ti_models(self)->Dict[str,bool]:
'''Return a dict of installed textual models; key is either the shortname '''Return a dict of installed textual models; key is either the shortname
@ -1362,21 +1392,50 @@ class ModelManager(object):
installed_models = {x: False for x in models.keys()} installed_models = {x: False for x in models.keys()}
dir = self.globals.embedding_path dir = self.globals.embedding_path
installed_models = dict()
for root, dirs, files in os.walk(dir): for root, dirs, files in os.walk(dir):
for name in files: for name in files:
if not Path(name).suffix in ['.bin','.pt','.ckpt','.safetensors']:
continue
if name == 'learned_embeds.bin': if name == 'learned_embeds.bin':
name = str(Path(root,name).parent) name = Path(root,name).parent.stem #Path(root,name).stem
else:
name = Path(name).stem
installed_models.update({name: True}) installed_models.update({name: True})
return installed_models return installed_models
def install_ti_models(self, model_names: list[str]): def install_ti_models(self, model_names: list[str], access_token: str=None):
'''Download list of textual inversion embeddings''' '''Download list of textual inversion embeddings'''
short_names = OmegaConf.load(Dataset_path).get('textual_inversion') or {} short_names = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
for name in model_names: for name in model_names:
url = short_names.get(name) or name name = short_names.get(name) or name
download_with_resume(url, self.globals.embedding_path)
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
self.logger.info(f'Downloading Textual Inversion embedding {name}')
_,dest_dir = name.split("/")
hf_download_with_resume(
repo_id = name,
model_dir = self.globals.embedding_path / dest_dir,
model_name = 'learned_embeds.bin',
access_token = access_token
)
elif name.startswith(('http:','https:','ftp:')):
download_with_resume(name, self.globals.embedding_path)
else:
self.logger.error(f'{name} does not look like either a HuggingFace repo_id or a downloadable URL')
def delete_ti_models(self, model_names: list[str]):
'''Remove TI embeddings from disk'''
for name in model_names:
file_or_directory = self.globals.embedding_path / name
if file_or_directory.is_dir():
self.logger.info(f'Purging textual inversion embedding {name}')
shutil.rmtree(str(file_or_directory))
else:
for path in self.globals.embedding_path.glob(f'{name}.*'):
self.logger.info(f'Purging textual inversion embedding {name}')
path.unlink()
def list_controlnet_models(self)->Dict[str,bool]: def list_controlnet_models(self)->Dict[str,bool]:
'''Return a dict of installed controlnet models; key is repo_id or short name '''Return a dict of installed controlnet models; key is repo_id or short name
of model (defined in INITIAL_MODELS), and value is True if installed''' of model (defined in INITIAL_MODELS), and value is True if installed'''

View File

@ -322,8 +322,8 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
logger.warning("corrupt existing file found. re-downloading") logger.warning("corrupt existing file found. re-downloading")
os.remove(dest) os.remove(dest)
exist_size = 0 exist_size = 0
if resp.status_code == 416 or exist_size == content_length: if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
logger.warning(f"{dest}: complete file found. Skipping.") logger.warning(f"{dest}: complete file found. Skipping.")
return dest return dest
elif resp.status_code == 206 or exist_size > 0: elif resp.status_code == 206 or exist_size > 0:

View File

@ -97,7 +97,10 @@ controlnet:
tile: lllyasviel/control_v11f1e_sd15_tile tile: lllyasviel/control_v11f1e_sd15_tile
ip2p: lllyasviel/control_v11e_sd15_ip2p ip2p: lllyasviel/control_v11e_sd15_ip2p
textual_inversion: textual_inversion:
'EasyNegative.safetensors': https://huggingface.co/embed/EasyNegative/blob/main/EasyNegative.safetensors 'EasyNegative': https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
'ahx-beta-453407d': sd-concepts-library/ahx-beta-453407d
lora: lora:
'LowRA.safetensors': https://civitai.com/api/download/models/63006 'LowRA': https://civitai.com/api/download/models/63006
'Ink scenery.safetensors': https://civitai.com/api/download/models/83390 'Ink scenery': https://civitai.com/api/download/models/83390
'sd-model-finetuned-lora-t4': sayakpaul/sd-model-finetuned-lora-t4

View File

@ -10,6 +10,7 @@ The work is actually done in backend code in model_install_backend.py.
""" """
import argparse import argparse
import curses
import os import os
import sys import sys
from argparse import Namespace from argparse import Namespace
@ -23,6 +24,7 @@ from npyscreen import widget
from omegaconf import OmegaConf from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from dataclasses import dataclass,field
from ...backend.install.model_install_backend import ( from ...backend.install.model_install_backend import (
Dataset_path, Dataset_path,
@ -31,7 +33,6 @@ from ...backend.install.model_install_backend import (
install_requested_models, install_requested_models,
recommended_datasets, recommended_datasets,
ModelInstallList, ModelInstallList,
dataclass,
) )
from ...backend import ModelManager from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device from ...backend.util import choose_precision, choose_torch_device
@ -51,9 +52,13 @@ MIN_LINES = 50
config = get_invokeai_config() config = get_invokeai_config()
class addModelsForm(npyscreen.FormMultiPage): class addModelsForm(npyscreen.FormMultiPage):
# for responsive resizing - disabled # for responsive resizing - disabled
# FIX_MINIMUM_SIZE_WHEN_CREATED = False # FIX_MINIMUM_SIZE_WHEN_CREATED = False
# for persistence
current_tab = 0
def __init__(self, parentApp, name, multipage=False, *args, **keywords): def __init__(self, parentApp, name, multipage=False, *args, **keywords):
self.multipage = multipage self.multipage = multipage
@ -89,7 +94,6 @@ class addModelsForm(npyscreen.FormMultiPage):
editable=False, editable=False,
color="CAUTION", color="CAUTION",
) )
self.nextrely += 1 self.nextrely += 1
self.tabs = self.add_widget_intelligent( self.tabs = self.add_widget_intelligent(
SingleSelectColumns, SingleSelectColumns,
@ -99,7 +103,7 @@ class addModelsForm(npyscreen.FormMultiPage):
'LORA/LYCORIS MODELS', 'LORA/LYCORIS MODELS',
'TEXTUAL INVERSION MODELS' 'TEXTUAL INVERSION MODELS'
], ],
value=0, value=[self.current_tab],
columns = 4, columns = 4,
max_height = 2, max_height = 2,
relx=8, relx=8,
@ -110,7 +114,7 @@ class addModelsForm(npyscreen.FormMultiPage):
top_of_table = self.nextrely top_of_table = self.nextrely
self.diffusers_models = self.add_diffusers() self.diffusers_models = self.add_diffusers()
bottom_of_table = self.nextrely bottom_of_table = self.nextrely
self.nextrely = top_of_table self.nextrely = top_of_table
self.controlnet_models = self.add_controlnets() self.controlnet_models = self.add_controlnets()
@ -123,14 +127,7 @@ class addModelsForm(npyscreen.FormMultiPage):
self.nextrely = bottom_of_table self.nextrely = bottom_of_table
self.nextrely += 1 self.nextrely += 1
self.cancel = self.add_widget_intelligent( done_label = "INSTALL/REMOVE"
npyscreen.ButtonPress,
name="CANCEL",
rely=-3,
when_pressed_function=self.on_cancel,
)
done_label = "DONE"
back_label = "BACK" back_label = "BACK"
button_length = len(done_label) button_length = len(done_label)
button_offset = 0 button_offset = 0
@ -154,7 +151,18 @@ class addModelsForm(npyscreen.FormMultiPage):
when_pressed_function=self.on_ok, when_pressed_function=self.on_ok,
) )
self._toggle_tables([0]) self.cancel = self.add_widget_intelligent(
npyscreen.ButtonPress,
name="QUIT",
rely=-3,
relx=window_width-20,
when_pressed_function=self.on_cancel,
)
# This restores the selected page on return from an installation
for i in range(1,self.current_tab+1):
self.tabs.h_cursor_line_down(1)
self._toggle_tables([self.current_tab])
def add_diffusers(self)->dict[str, npyscreen.widget]: def add_diffusers(self)->dict[str, npyscreen.widget]:
'''Add widgets responsible for selecting diffusers models''' '''Add widgets responsible for selecting diffusers models'''
@ -172,16 +180,6 @@ class addModelsForm(npyscreen.FormMultiPage):
widgets.update( widgets.update(
label1 = self.add_widget_intelligent( label1 = self.add_widget_intelligent(
CenteredTitleText,
name="== DIFFUSERS MODEL STARTER PACK ==",
editable=False,
color="CONTROL",
)
)
self.nextrely -= 1
widgets.update(
label2 = self.add_widget_intelligent(
CenteredTitleText, CenteredTitleText,
name="Select from a starter set of Stable Diffusion models from HuggingFace.", name="Select from a starter set of Stable Diffusion models from HuggingFace.",
editable=False, editable=False,
@ -283,7 +281,7 @@ class addModelsForm(npyscreen.FormMultiPage):
widgets.update( widgets.update(
label1 = self.add_widget_intelligent( label1 = self.add_widget_intelligent(
CenteredTitleText, CenteredTitleText,
name="Select the desired models to install. Unchecked models will be purged from disk.", name="Select the desired ControlNet models to install. Unchecked models will be purged from disk.",
editable=False, editable=False,
labelColor="CAUTION", labelColor="CAUTION",
) )
@ -322,7 +320,11 @@ class addModelsForm(npyscreen.FormMultiPage):
self.nextrely -= 1 self.nextrely -= 1
widgets.update( widgets.update(
download_ids = self.add_widget_intelligent( download_ids = self.add_widget_intelligent(
TextBox, max_height=2, scroll_exit=True, editable=True, relx=4 TextBox,
max_height=4,
scroll_exit=True,
editable=True,
relx=4
) )
) )
return widgets return widgets
@ -341,7 +343,7 @@ class addModelsForm(npyscreen.FormMultiPage):
) )
) )
columns=min(len(model_list),3) columns=min(len(model_list),3) or 1
widgets.update( widgets.update(
models_selected = self.add_widget_intelligent( models_selected = self.add_widget_intelligent(
MultiSelectColumns, MultiSelectColumns,
@ -376,7 +378,7 @@ class addModelsForm(npyscreen.FormMultiPage):
widgets.update( widgets.update(
download_ids = self.add_widget_intelligent( download_ids = self.add_widget_intelligent(
TextBox, TextBox,
max_height=2, max_height=4,
scroll_exit=True, scroll_exit=True,
editable=True, editable=True,
relx=4, relx=4,
@ -387,10 +389,39 @@ class addModelsForm(npyscreen.FormMultiPage):
def add_tis(self)->dict[str, npyscreen.widget]: def add_tis(self)->dict[str, npyscreen.widget]:
widgets = dict() widgets = dict()
model_list = sorted(self.installed_ti_models.keys())
widgets.update( widgets.update(
label1 = self.add_widget_intelligent( label1 = self.add_widget_intelligent(
CenteredTitleText,
name="Select the desired models to install. Unchecked models will be purged from disk.",
editable=False,
labelColor="CAUTION",
)
)
columns=min(len(model_list),6) or 1
widgets.update(
models_selected = self.add_widget_intelligent(
MultiSelectColumns,
columns=columns,
name="Install Textual Inversion Embeddings",
values=model_list,
value=[
model_list.index(x)
for x in model_list
if self.installed_ti_models[x]
],
max_height=len(model_list)//columns + 1,
relx=4,
scroll_exit=True,
)
)
widgets.update(
label2 = self.add_widget_intelligent(
npyscreen.TitleFixedText, npyscreen.TitleFixedText,
name='Textual Inversion models to download and install (Space separated. Use shift-control-V to paste):', name='Textual Inversion models to download, use URLs or HugggingFace repo_ids (Space separated. Use shift-control-V to paste):',
relx=4, relx=4,
color='CONTROL', color='CONTROL',
editable=False, editable=False,
@ -403,7 +434,7 @@ class addModelsForm(npyscreen.FormMultiPage):
widgets.update( widgets.update(
download_ids = self.add_widget_intelligent( download_ids = self.add_widget_intelligent(
TextBox, TextBox,
max_height=2, max_height=4,
scroll_exit=True, scroll_exit=True,
editable=True, editable=True,
relx=4, relx=4,
@ -431,6 +462,7 @@ class addModelsForm(npyscreen.FormMultiPage):
v.hidden = True v.hidden = True
for k,v in widgets[selected_tab].items(): for k,v in widgets[selected_tab].items():
v.hidden = False v.hidden = False
self.__class__.current_tab = selected_tab # for persistence
self.display() self.display()
def _get_starter_model_labels(self) -> List[str]: def _get_starter_model_labels(self) -> List[str]:
@ -477,12 +509,9 @@ class addModelsForm(npyscreen.FormMultiPage):
self.editing = False self.editing = False
def on_cancel(self): def on_cancel(self):
if npyscreen.notify_yes_no( self.parentApp.setNextForm(None)
"Are you sure you want to cancel?\nYou may re-run this script later using the invoke.sh or invoke.bat command.\n" self.parentApp.user_cancelled = True
): self.editing = False
self.parentApp.setNextForm(None)
self.parentApp.user_cancelled = True
self.editing = False
def marshall_arguments(self): def marshall_arguments(self):
""" """
@ -510,7 +539,7 @@ class addModelsForm(npyscreen.FormMultiPage):
selections.install_models = [x for x in starter_models if x not in self.existing_models] selections.install_models = [x for x in starter_models if x not in self.existing_models]
selections.remove_models = [x for x in self.starter_model_list if x in self.existing_models and x not in starter_models] selections.remove_models = [x for x in self.starter_model_list if x in self.existing_models and x not in starter_models]
# TODO: REFACTOR CUT AND PASTE CODE # TODO: REFACTOR THIS REPETITIVE CODE
cn_models_selected = self.controlnet_models['models_selected'] cn_models_selected = self.controlnet_models['models_selected']
selections.install_cn_models = [cn_models_selected.values[x] selections.install_cn_models = [cn_models_selected.values[x]
for x in cn_models_selected.value for x in cn_models_selected.value
@ -538,9 +567,24 @@ class addModelsForm(npyscreen.FormMultiPage):
] ]
if (additional_loras := self.lora_models['download_ids'].value.split()): if (additional_loras := self.lora_models['download_ids'].value.split()):
valid_loras = [x for x in additional_loras if x.startswith(('http:','https:','ftp:'))] selections.install_lora_models.extend(additional_loras)
selections.install_lora_models.extend(valid_loras)
# same thing, for TIs
# TODO: refactor
tis_selected = self.ti_models['models_selected']
selections.install_ti_models = [tis_selected.values[x]
for x in tis_selected.value
if not self.installed_ti_models[tis_selected.values[x]]
]
selections.remove_ti_models = [x
for x in tis_selected.values
if self.installed_ti_models[x]
and tis_selected.values.index(x) not in tis_selected.value
]
if (additional_tis := self.ti_models['download_ids'].value.split()):
selections.install_ti_models.extend(additional_tis)
# load directory and whether to scan on startup # load directory and whether to scan on startup
selections.scan_directory = self.diffusers_models['autoload_directory'].value selections.scan_directory = self.diffusers_models['autoload_directory'].value
selections.autoscan_on_startup = self.diffusers_models['autoscan_on_startup'].value selections.autoscan_on_startup = self.diffusers_models['autoscan_on_startup'].value
@ -548,16 +592,21 @@ class addModelsForm(npyscreen.FormMultiPage):
# URLs and the like # URLs and the like
selections.import_model_paths = self.diffusers_models['download_ids'].value.split() selections.import_model_paths = self.diffusers_models['download_ids'].value.split()
@dataclass @dataclass
class UserSelections(): class UserSelections():
install_models: List[str]=None install_models: List[str]= field(default_factory=list)
remove_models: List[str]=None remove_models: List[str]=field(default_factory=list)
purge_deleted_models: bool=False, purge_deleted_models: bool=field(default_factory=list)
install_cn_models: List[str] = None, install_cn_models: List[str] = field(default_factory=list)
remove_cn_models: List[str] = None, remove_cn_models: List[str] = field(default_factory=list)
scan_directory: Path=None, install_lora_models: List[str] = field(default_factory=list)
autoscan_on_startup: bool=False, remove_lora_models: List[str] = field(default_factory=list)
import_model_paths: str=None, install_ti_models: List[str] = field(default_factory=list)
remove_ti_models: List[str] = field(default_factory=list)
scan_directory: Path = None
autoscan_on_startup: bool=False
import_model_paths: str=None
class AddModelApplication(npyscreen.NPSAppManaged): class AddModelApplication(npyscreen.NPSAppManaged):
def __init__(self): def __init__(self):
@ -583,6 +632,7 @@ def process_and_execute(opt: Namespace, selections: Namespace):
diffusers = ModelInstallList(models_to_install, models_to_remove), diffusers = ModelInstallList(models_to_install, models_to_remove),
controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models), controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
lora = ModelInstallList(selections.install_lora_models, selections.remove_lora_models), lora = ModelInstallList(selections.install_lora_models, selections.remove_lora_models),
ti = ModelInstallList(selections.install_ti_models, selections.remove_ti_models),
scan_directory=Path(directory_to_scan) if directory_to_scan else None, scan_directory=Path(directory_to_scan) if directory_to_scan else None,
external_models=potential_models_to_install, external_models=potential_models_to_install,
scan_at_startup=scan_at_startup, scan_at_startup=scan_at_startup,
@ -615,10 +665,7 @@ def select_and_download_models(opt: Namespace):
set_min_terminal_size(MIN_COLS, MIN_LINES) set_min_terminal_size(MIN_COLS, MIN_LINES)
installApp = AddModelApplication() installApp = AddModelApplication()
installApp.run() installApp.run()
process_and_execute(opt, installApp.user_selections)
if not installApp.user_cancelled:
process_and_execute(opt, installApp.user_selections)
# ------------------------------------- # -------------------------------------
def main(): def main():
@ -679,6 +726,9 @@ def main():
logger.error(e) logger.error(e)
sys.exit(-1) sys.exit(-1)
except KeyboardInterrupt: except KeyboardInterrupt:
curses.nocbreak()
curses.echo()
curses.endwin()
logger.info("Goodbye! Come back soon.") logger.info("Goodbye! Come back soon.")
except widget.NotEnoughSpaceForWidget as e: except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"): if str(e).startswith("Height of 1 allocated"):