From f28d50070e14374946dfc335d08cc75220ce4bbe Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 16 Jun 2023 22:54:36 -0400 Subject: [PATCH] configure/install basically working; needs edge case testing --- .../backend/install/invokeai_configure.py | 33 +- .../backend/install/model_install_backend.py | 618 ++++++++---------- invokeai/backend/model_management/__init__.py | 1 + .../backend/model_management/model_manager.py | 7 +- .../backend/model_management/model_probe.py | 34 +- .../backend/model_management/models/base.py | 1 - invokeai/configs/INITIAL_MODELS.yaml | 3 +- .../v2-inpainting-inference-v.yaml | 159 +++++ .../v2-inpainting-inference.yaml | 158 +++++ invokeai/frontend/install/model_install.py | 238 ++----- invokeai/frontend/install/widgets.py | 29 +- scripts/migrate_models_to_3.0.py | 8 +- 12 files changed, 701 insertions(+), 588 deletions(-) create mode 100644 invokeai/configs/stable-diffusion/v2-inpainting-inference-v.yaml create mode 100644 invokeai/configs/stable-diffusion/v2-inpainting-inference.yaml diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index cdb3f47755..582b24cbfa 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -56,11 +56,10 @@ from invokeai.frontend.install.widgets import ( ) from invokeai.backend.install.legacy_arg_parsing import legacy_parser from invokeai.backend.install.model_install_backend import ( - default_dataset, - download_from_hf, + hf_download_from_pretrained, hf_download_with_resume, - recommended_datasets, - UserSelections, + InstallSelections, + ModelInstall, ) from invokeai.backend.model_management.model_probe import ( ModelProbe, ModelType, BaseModelType, SchedulerPredictionType @@ -198,8 +197,8 @@ def download_conversion_models(): # sd-1 repo_id = 'openai/clip-vit-large-patch14' - download_from_hf(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14') - download_from_hf(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14') + hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14') + hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14') # sd-2 repo_id = "stabilityai/stable-diffusion-2" @@ -275,8 +274,8 @@ def download_clipseg(): logger.info("Installing clipseg model for text-based masking...") CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined" try: - download_from_hf(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg') - download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL,'models/core/misc/clipseg') + hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg') + hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg') except Exception: logger.info("Error installing clipseg model:") logger.info(traceback.format_exc()) @@ -592,7 +591,7 @@ class EditOptApplication(npyscreen.NPSAppManaged): self.program_opts = program_opts self.invokeai_opts = invokeai_opts self.user_cancelled = False - self.user_selections = default_user_selections(program_opts) + self.install_selections = default_user_selections(program_opts) def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) @@ -627,19 +626,19 @@ def default_startup_options(init_file: Path) -> Namespace: opts.nsfw_checker = True return opts -def default_user_selections(program_opts: Namespace) -> UserSelections: - return UserSelections( - install_models=default_dataset() +def default_user_selections(program_opts: Namespace) -> InstallSelections: + installer = ModelInstall(config) + models = installer.all_models() + return InstallSelections( + install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id] if program_opts.default_only - else recommended_datasets() + else [models[x].path or models[x].repo_id for x in installer.recommended_models()] if program_opts.yes_to_all - else dict(), - purge_deleted_models=False, + else list(), scan_directory=None, autoscan_on_startup=None, ) - # ------------------------------------- def initialize_rootdir(root: Path, yes_to_all: bool = False): logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **") @@ -696,7 +695,7 @@ def run_console_ui( if editApp.user_cancelled: return (None, None) else: - return (editApp.new_opts, editApp.user_selections) + return (editApp.new_opts, editApp.install_selections) # ------------------------------------- diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 60f2d89748..54e5cdc1d8 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -2,18 +2,18 @@ Utility (backend) functions used by model_install.py """ import os -import re import shutil import sys +import traceback import warnings from dataclasses import dataclass,field from pathlib import Path -from tempfile import TemporaryFile -from typing import List, Dict, Set, Callable +from tempfile import TemporaryDirectory +from typing import List, Dict, Callable, Union, Set import requests -from diffusers import AutoencoderKL -from huggingface_hub import hf_hub_url, HfFolder +from diffusers import AutoencoderKL, StableDiffusionPipeline +from huggingface_hub import hf_hub_url, HfFolder, HfApi from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from tqdm import tqdm @@ -21,7 +21,9 @@ from tqdm import tqdm import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType +from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType +from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo +from invokeai.backend.util import download_with_resume from ..stable_diffusion import StableDiffusionGeneratorPipeline from ..util.logging import InvokeAILogger @@ -29,19 +31,11 @@ warnings.filterwarnings("ignore") # --------------------------globals----------------------- config = InvokeAIAppConfig.get_config() - -Model_dir = "models" -Weights_dir = "ldm/stable-diffusion-v1/" +logger = InvokeAILogger.getLogger(name='InvokeAI') # the initial "configs" dir is now bundled in the `invokeai.configs` package Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" -# initial models omegaconf -Datasets = None - -# logger -logger = InvokeAILogger.getLogger(name='InvokeAI') - Config_preamble = """ # This file describes the alternative machine learning models # available to InvokeAI script. @@ -52,6 +46,24 @@ Config_preamble = """ # was trained on. """ +LEGACY_CONFIGS = { + BaseModelType.StableDiffusion1: { + ModelVariantType.Normal: 'v1-inference.yaml', + ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml', + }, + + BaseModelType.StableDiffusion2: { + ModelVariantType.Normal: { + SchedulerPredictionType.Epsilon: 'v2-inference.yaml', + SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml', + }, + ModelVariantType.Inpaint: { + SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml', + SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml', + } + } +} + @dataclass class ModelInstallList: '''Class for listing models to be installed/removed''' @@ -59,18 +71,11 @@ class ModelInstallList: remove_models: List[str] = field(default_factory=list) @dataclass -class UserSelections(): +class InstallSelections(): install_models: List[str]= field(default_factory=list) remove_models: List[str]=field(default_factory=list) - install_cn_models: List[str] = field(default_factory=list) - remove_cn_models: List[str] = field(default_factory=list) - install_lora_models: List[str] = field(default_factory=list) - remove_lora_models: List[str] = field(default_factory=list) - install_ti_models: List[str] = field(default_factory=list) - remove_ti_models: List[str] = field(default_factory=list) scan_directory: Path = None autoscan_on_startup: bool=False - import_model_paths: str=None @dataclass class ModelLoadInfo(): @@ -82,18 +87,30 @@ class ModelLoadInfo(): description: str = '' installed: bool = False recommended: bool = False - + default: bool = False + class ModelInstall(object): - def __init__(self,config:InvokeAIAppConfig): + def __init__(self, + config:InvokeAIAppConfig, + prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, + access_token:str = None): self.config = config self.mgr = ModelManager(config.model_conf_path) self.datasets = OmegaConf.load(Dataset_path) + self.prediction_helper = prediction_type_helper + self.access_token = access_token or HfFolder.get_token() + self.reverse_paths = self._reverse_paths(self.datasets) def all_models(self)->Dict[str,ModelLoadInfo]: ''' - Return dict of model_key=>ModelStatus + Return dict of model_key=>ModelLoadInfo objects. + This method consolidates and simplifies the entries in both + models.yaml and INITIAL_MODELS.yaml so that they can + be treated uniformly. It also sorts the models alphabetically + by their name, to improve the display somewhat. ''' model_dict = dict() + # first populate with the entries in INITIAL_MODELS.yaml for key, value in self.datasets.items(): name,base,model_type = ModelManager.parse_key(key) @@ -128,102 +145,237 @@ class ModelInstall(object): if model_type==ModelType.Pipeline: models.add(key) return models + + def recommended_models(self)->Set[str]: + starters = self.starter_models() + return set([x for x in starters if self.datasets[x].get('recommended',False)]) + + def default_model(self)->str: + starters = self.starter_models() + defaults = [x for x in starters if self.datasets[x].get('default',False)] + return defaults[0] + + def install(self, selections: InstallSelections): + job = 1 + jobs = len(selections.remove_models) + len(selections.install_models) + if selections.scan_directory: + jobs += 1 - -def default_config_file(): - return config.model_conf_path - -def sd_configs(): - return config.legacy_conf_path - -def initial_models(): - global Datasets - if Datasets: - return Datasets - return (Datasets := OmegaConf.load(Dataset_path)['diffusers']) - -def add_models(model_manager, config_file_path: Path, models: List[tuple[str,str,str]]): - print(f'Installing {models}') - -def del_models(model_manager, config_file_path: Path, models: List[tuple[str,str,str]]): - for base, model_type, name in models: - logger.info(f"Deleting {name}...") - model_manager.del_model(name, base, model_type) - model_manager.commit(config_file_path) - -def install_requested_models( - diffusers: ModelInstallList = None, - controlnet: ModelInstallList = None, - lora: ModelInstallList = None, - ti: ModelInstallList = None, - cn_model_map: Dict[str,str] = None, # temporary - move to model manager - scan_directory: Path = None, - external_models: List[str] = None, - scan_at_startup: bool = False, - precision: str = "float16", - config_file_path: Path = None, - model_config_file_callback: Callable[[Path],Path] = None, -): - """ - Entry point for installing/deleting starter models, or installing external models. - """ - access_token = HfFolder.get_token() - config_file_path = config_file_path or default_config_file() - if not config_file_path.exists(): - open(config_file_path, "w") - - # prevent circular import here - from ..model_management import ModelManager - model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision) - - for x in [controlnet, lora, ti, diffusers]: - if x: - add_models(model_manager, config_file_path, x.install_models) - del_models(model_manager, config_file_path, x.remove_models) + # remove requested models + for key in selections.remove_models: + name,base,mtype = self.mgr.parse_key(key) + logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]') + self.mgr.del_model(name,base,mtype) + job += 1 - # if diffusers: + # add requested models + for path in selections.install_models: + logger.info(f'Installing {path} [{job}/{jobs}]') + self.heuristic_install(path) + job += 1 - # if diffusers.install_models and len(diffusers.install_models) > 0: - # logger.info("Installing requested models") - # downloaded_paths = download_weight_datasets( - # models=diffusers.install_models, - # access_token=None, - # precision=precision, - # ) - # successful = {x:v for x,v in downloaded_paths.items() if v is not None} - # if len(successful) > 0: - # update_config_file(successful, config_file_path) - # if len(successful) < len(diffusers.install_models): - # unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None] - # logger.warning(f"Some of the model downloads were not successful: {unsuccessful}") + # import from the scan directory, if any + if path := selections.scan_directory: + logger.info(f'Scanning and importing models from directory {path} [{job}/{jobs}]') + self.heuristic_install(path) - # due to above, we have to reload the model manager because conf file - # was changed behind its back - model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision) + self.mgr.commit() - external_models = external_models or list() - if scan_directory: - external_models.append(str(scan_directory)) + if selections.autoscan_on_startup and Path(selections.scan_directory).is_dir(): + update_autoconvert_dir(selections.scan_directory) + else: + update_autoconvert_dir(None) - if len(external_models) > 0: - logger.info("INSTALLING EXTERNAL MODELS") - for path_url_or_repo in external_models: - logger.debug(path_url_or_repo) - try: - model_manager.heuristic_import( - path_url_or_repo, - commit_to_conf=config_file_path, - config_file_callback = model_config_file_callback, + def heuristic_install(self, model_path_id_or_url: Union[str,Path]): + # A little hack to allow nested routines to retrieve info on the requested ID + self.current_id = model_path_id_or_url + + path = Path(model_path_id_or_url) + + # checkpoint file, or similar + if path.is_file(): + self._install_path(path) + return + + # folders style or similar + if path.is_dir() and any([(path/x).exists() for x in ['config.json','model_index.json','learned_embeds.bin']]): + self._install_path(path) + return + + # recursive scan + if path.is_dir(): + for child in path.iterdir(): + self.heuristic_install(child) + return + + # huggingface repo + parts = str(path).split('/') + if len(parts) == 2: + self._install_repo(str(path)) + return + + # a URL + if model_path_id_or_url.startswith(("http:", "https:", "ftp:")): + self._install_url(model_path_id_or_url) + return + + logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') + + # install a model from a local path. The optional info parameter is there to prevent + # the model from being probed twice in the event that it has already been probed. + def _install_path(self, path: Path, info: ModelProbeInfo=None): + try: + info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) + if info.model_type == ModelType.Pipeline: + attributes = self._make_attributes(path,info) + self.mgr.add_model(model_name = path.stem if info.format=='checkpoint' else path.name, + base_model = info.base_type, + model_type = info.model_type, + model_attributes = attributes + ) + except Exception as e: + logger.warning(f'{str(e)} Skipping registration.') + + def _install_url(self, url: str): + # copy to a staging area, probe, import and delete + with TemporaryDirectory(dir=self.config.models_path) as staging: + location = download_with_resume(url,Path(staging)) + if not location: + logger.error(f'Unable to download {url}. Skipping.') + info = ModelProbe().heuristic_probe(location) + dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name + models_path = shutil.move(location,dest) + + # staged version will be garbage-collected at this time + self._install_path(Path(models_path), info) + + def _get_model_name(self,path_name: str, location: Path)->str: + ''' + Calculate a name for the model - primitive implementation. + ''' + if key := self.reverse_paths.get(path_name): + (name, base, mtype) = ModelManager.parse_key(key) + return name + else: + return location.stem + + def _install_repo(self, repo_id: str): + hinfo = HfApi().model_info(repo_id) + + # we try to figure out how to download this most economically + # list all the files in the repo + files = [x.rfilename for x in hinfo.siblings] + + with TemporaryDirectory(dir=self.config.models_path) as staging: + staging = Path(staging) + if 'model_index.json' in files: + location = self._download_hf_pipeline(repo_id, staging) # pipeline + + elif 'pytorch_lora_weights.bin' in files: + location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA + + elif self.config.precision=='float16' and 'diffusion_pytorch_model.fp16.safetensors' in files: # vae, controlnet or some other standalone + files = ['config.json', 'diffusion_pytorch_model.fp16.safetensors'] + location = self._download_hf_model(repo_id, files, staging) + + elif 'diffusion_pytorch_model.safetensors' in files: + files = ['config.json', 'diffusion_pytorch_model.safetensors'] + location = self._download_hf_model(repo_id, files, staging) + + elif 'learned_embeds.bin' in files: + location = self._download_hf_model(repo_id, ['learned_embeds.bin'], staging) + + info = ModelProbe().heuristic_probe(location, self.prediction_helper) + dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location) + if dest.exists(): + shutil.rmtree(dest) + shutil.copytree(location,dest) + self._install_path(dest, info) + + def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict: + + # convoluted way to retrieve the description from datasets + description = f'{info.base_type.value} {info.model_type.value} model' + if key := self.reverse_paths.get(self.current_id): + if key in self.datasets: + description = self.datasets[key]['description'] + + attributes = dict( + path = str(path), + description = str(description), + format = info.format, + ) + if info.model_type == ModelType.Pipeline: + attributes.update( + dict( + variant = info.variant_type, + ) + ) + if info.base_type == BaseModelType.StableDiffusion2: + attributes.update( + dict( + prediction_type = info.prediction_type, + upcast_attention = info.prediction_type == SchedulerPredictionType.VPrediction, + ) ) - except KeyboardInterrupt: - sys.exit(-1) - except Exception: - pass + if info.format=="checkpoint": + try: + legacy_conf = LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type] if BaseModelType.StableDiffusion2 \ + else LEGACY_CONFIGS[info.base_type][info.variant_type] + except KeyError: + legacy_conf = 'v1-inference.yaml' # best guess + + attributes.update( + dict( + config = str(self.config.legacy_conf_path / legacy_conf) + ) + ) + return attributes - if scan_at_startup and scan_directory.is_dir(): - update_autoconvert_dir(scan_directory) - else: - update_autoconvert_dir(None) + def _download_hf_pipeline(self, repo_id: str, staging: Path)->Path: + ''' + This retrieves a StableDiffusion model from cache or remote and then + does a save_pretrained() to the indicated staging area. + ''' + _,name = repo_id.split("/") + revisions = ['fp16','main'] if self.config.precision=='float16' else ['main'] + model = None + for revision in revisions: + try: + model = StableDiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None) + except: # most errors are due to fp16 not being present. Fix this to catch other errors + pass + if model: + break + if not model: + logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.') + return None + model.save_pretrained(staging / name, safe_serialization=True) + return staging / name + + def _download_hf_model(self, repo_id: str, files: List[str], staging: Path)->Path: + _,name = repo_id.split("/") + location = staging / name + paths = list() + for filename in files: + p = hf_download_with_resume(repo_id, + model_dir=location, + model_name=filename, + access_token = self.access_token + ) + if p: + paths.append(p) + else: + logger.warning(f'Could not download {filename} from {repo_id}.') + + return location if len(paths)>0 else None + + @classmethod + def _reverse_paths(cls,datasets)->dict: + ''' + Reverse mapping from repo_id/path to destination name. + ''' + return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()} def update_autoconvert_dir(autodir: Path): ''' @@ -249,89 +401,7 @@ def yes_or_no(prompt: str, default_yes=True): return response[0] in ("y", "Y") # --------------------------------------------- -def recommended_datasets() -> List['str']: - datasets = set() - for ds in initial_models().keys(): - if initial_models()[ds].get("recommended", False): - datasets.add(ds) - return list(datasets) - -# --------------------------------------------- -def default_dataset() -> dict: - datasets = set() - for ds in initial_models().keys(): - if initial_models()[ds].get("default", False): - datasets.add(ds) - return list(datasets) - - -# --------------------------------------------- -def all_datasets() -> dict: - datasets = dict() - for ds in initial_models().keys(): - datasets[ds] = True - return datasets - - -# --------------------------------------------- -# look for legacy model.ckpt in models directory and offer to -# normalize its name -def migrate_models_ckpt(): - model_path = os.path.join(config.root_dir, Model_dir, Weights_dir) - if not os.path.exists(os.path.join(model_path, "model.ckpt")): - return - new_name = initial_models()["stable-diffusion-1.4"]["file"] - logger.warning( - 'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.' - ) - logger.warning(f"model.ckpt => {new_name}") - os.replace( - os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name) - ) - - -# --------------------------------------------- -def download_weight_datasets( - models: List[str], access_token: str, precision: str = "float32" -): - migrate_models_ckpt() - successful = dict() - for mod in models: - logger.info(f"Downloading {mod}:") - successful[mod] = _download_repo_or_file( - initial_models()[mod], access_token, precision=precision - ) - return successful - - -def _download_repo_or_file( - mconfig: DictConfig, access_token: str, precision: str = "float32" -) -> Path: - path = None - if mconfig["format"] == "ckpt": - path = _download_ckpt_weights(mconfig, access_token) - else: - path = _download_diffusion_weights(mconfig, access_token, precision=precision) - if "vae" in mconfig and "repo_id" in mconfig["vae"]: - _download_diffusion_weights( - mconfig["vae"], access_token, precision=precision - ) - return path - -def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path: - repo_id = mconfig["repo_id"] - filename = mconfig["file"] - cache_dir = os.path.join(config.root_dir, Model_dir, Weights_dir) - return hf_download_with_resume( - repo_id=repo_id, - model_dir=cache_dir, - model_name=filename, - access_token=access_token, - ) - - -# --------------------------------------------- -def download_from_hf( +def hf_download_from_pretrained( model_class: object, model_name: str, destination: Path, **kwargs ): logger = InvokeAILogger.getLogger('InvokeAI') @@ -345,35 +415,6 @@ def download_from_hf( model.save_pretrained(destination, safe_serialization=True) return destination -def _download_diffusion_weights( - mconfig: DictConfig, access_token: str, precision: str = "float32" -): - repo_id = mconfig["repo_id"] - model_class = ( - StableDiffusionGeneratorPipeline - if mconfig.get("format", None) == "diffusers" - else AutoencoderKL - ) - extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}] - path = None - for extra_args in extra_arg_list: - try: - path = download_from_hf( - model_class, - repo_id, - safety_checker=None, - **extra_args, - ) - except OSError as e: - if 'Revision Not Found' in str(e): - pass - else: - logger.error(str(e)) - if path: - break - return path - - # --------------------------------------------- def hf_download_with_resume( repo_id: str, @@ -432,128 +473,3 @@ def hf_download_with_resume( return model_dest -# --------------------------------------------- -def update_config_file(successfully_downloaded: dict, config_file: Path): - config_file = ( - Path(config_file) if config_file is not None else default_config_file() - ) - - # In some cases (incomplete setup, etc), the default configs directory might be missing. - # Create it if it doesn't exist. - # this check is ignored if opt.config_file is specified - user is assumed to know what they - # are doing if they are passing a custom config file from elsewhere. - if config_file is default_config_file() and not config_file.parent.exists(): - configs_src = Dataset_path.parent - configs_dest = default_config_file().parent - shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True) - - yaml = new_config_file_contents(successfully_downloaded, config_file) - - try: - backup = None - if os.path.exists(config_file): - logger.warning( - f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig" - ) - backup = config_file.with_suffix(".yaml.orig") - ## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183 - if sys.platform == "win32" and backup.is_file(): - backup.unlink() - config_file.rename(backup) - - with TemporaryFile() as tmp: - tmp.write(Config_preamble.encode()) - tmp.write(yaml.encode()) - - with open(str(config_file.expanduser().resolve()), "wb") as new_config: - tmp.seek(0) - new_config.write(tmp.read()) - - except Exception as e: - logger.error(f"Error creating config file {config_file}: {str(e)}") - if backup is not None: - logger.info("restoring previous config file") - ## workaround, for WinError 183, see above - if sys.platform == "win32" and config_file.is_file(): - config_file.unlink() - backup.rename(config_file) - return - - logger.info(f"Successfully created new configuration file {config_file}") - - -# --------------------------------------------- -def new_config_file_contents( - successfully_downloaded: dict, - config_file: Path, -) -> str: - if config_file.exists(): - conf = OmegaConf.load(str(config_file.expanduser().resolve())) - else: - conf = OmegaConf.create() - - default_selected = None - for model in successfully_downloaded: - # a bit hacky - what we are doing here is seeing whether a checkpoint - # version of the model was previously defined, and whether the current - # model is a diffusers (indicated with a path) - if conf.get(model) and Path(successfully_downloaded[model]).is_dir(): - delete_weights(model, conf[model]) - - stanza = {} - mod = initial_models()[model] - stanza["description"] = mod["description"] - stanza["repo_id"] = mod["repo_id"] - stanza["format"] = mod["format"] - # diffusers don't need width and height (probably .ckpt doesn't either) - # so we no longer require these in INITIAL_MODELS.yaml - if "width" in mod: - stanza["width"] = mod["width"] - if "height" in mod: - stanza["height"] = mod["height"] - if "file" in mod: - stanza["weights"] = os.path.relpath( - successfully_downloaded[model], start=config.root_dir - ) - stanza["config"] = os.path.normpath( - os.path.join(sd_configs(), mod["config"]) - ) - if "vae" in mod: - if "file" in mod["vae"]: - stanza["vae"] = os.path.normpath( - os.path.join(Model_dir, Weights_dir, mod["vae"]["file"]) - ) - else: - stanza["vae"] = mod["vae"] - if mod.get("default", False): - stanza["default"] = True - default_selected = True - - conf[model] = stanza - - # if no default model was chosen, then we select the first - # one in the list - if not default_selected: - conf[list(successfully_downloaded.keys())[0]]["default"] = True - - return OmegaConf.to_yaml(conf) - - -# --------------------------------------------- -def delete_weights(model_name: str, conf_stanza: dict): - if not (weights := conf_stanza.get("weights")): - return - if re.match("/VAE/", conf_stanza.get("config")): - return - - logger.warning( - f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?" - ) - - weights = Path(weights) - if not weights.is_absolute(): - weights = config.root_dir / weights - try: - weights.unlink() - except OSError as e: - logger.error(str(e)) diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index aea7b417a1..fb3b20a20a 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -4,3 +4,4 @@ Initialization file for invokeai.backend.model_management from .model_manager import ModelManager, ModelInfo from .model_cache import ModelCache from .models import BaseModelType, ModelType, SubModelType, ModelVariantType + diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 79c6573f4f..7a7a765fd3 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -682,7 +682,7 @@ class ModelManager(object): for model_key, model_config in list(self.models.items()): model_name, base_model, model_type = self.parse_key(model_key) - model_path = str(self.globals.root / model_config.path) + model_path = str(self.globals.root_path / model_config.path) if not os.path.exists(model_path): model_class = MODEL_CLASSES[base_model][model_type] if model_class.save_to_config: @@ -703,13 +703,14 @@ class ModelManager(object): for entry_name in os.listdir(models_dir): model_path = os.path.join(models_dir, entry_name) if model_path not in loaded_files: # TODO: check - model_name = Path(model_path).stem + model_path = Path(model_path) + model_name = model_path.name if model_path.is_dir else model_path.stem model_key = self.create_key(model_name, base_model, model_type) if model_key in self.models: raise Exception(f"Model with key {model_key} added twice") - model_config: ModelConfigBase = model_class.probe_config(model_path) + model_config: ModelConfigBase = model_class.probe_config(str(model_path)) self.models[model_key] = model_config new_models_found = True diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index 54fac5cde1..59e0c8e970 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -15,13 +15,13 @@ import invokeai.backend.util.logging as logger from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings @dataclass -class ModelVariantInfo(object): +class ModelProbeInfo(object): model_type: ModelType base_type: BaseModelType variant_type: ModelVariantType prediction_type: SchedulerPredictionType upcast_attention: bool - format: Literal['folder','checkpoint'] + format: Literal['diffusers','checkpoint'] image_size: int class ProbeBase(object): @@ -31,7 +31,7 @@ class ProbeBase(object): class ModelProbe(object): PROBES = { - 'folder': { }, + 'diffusers': { }, 'checkpoint': { }, } @@ -43,7 +43,7 @@ class ModelProbe(object): @classmethod def register_probe(cls, - format: Literal['folder','file'], + format: Literal['diffusers','checkpoint'], model_type: ModelType, probe_class: ProbeBase): cls.PROBES[format][model_type] = probe_class @@ -51,8 +51,8 @@ class ModelProbe(object): @classmethod def heuristic_probe(cls, model: Union[Dict, ModelMixin, Path], - prediction_type_helper: Callable[[Path],BaseModelType]=None, - )->ModelVariantInfo: + prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, + )->ModelProbeInfo: if isinstance(model,Path): return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper) elif isinstance(model,(dict,ModelMixin,ConfigMixin)): @@ -64,7 +64,7 @@ class ModelProbe(object): def probe(cls, model_path: Path, model: Union[Dict, ModelMixin] = None, - prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo: + prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo: ''' Probe the model at model_path and return sufficient information about it to place it somewhere in the models directory hierarchy. If the model is @@ -74,14 +74,14 @@ class ModelProbe(object): between V2-Base and V2-768 SD models. ''' if model_path: - format = 'folder' if model_path.is_dir() else 'checkpoint' + format = 'diffusers' if model_path.is_dir() else 'checkpoint' else: - format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint' + format = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint' model_info = None try: model_type = cls.get_model_type_from_folder(model_path, model) \ - if format == 'folder' \ + if format == 'diffusers' \ else cls.get_model_type_from_checkpoint(model_path, model) probe_class = cls.PROBES[format].get(model_type) if not probe_class: @@ -90,7 +90,7 @@ class ModelProbe(object): base_type = probe.get_base_type() variant_type = probe.get_variant_type() prediction_type = probe.get_scheduler_prediction_type() - model_info = ModelVariantInfo( + model_info = ModelProbeInfo( model_type = model_type, base_type = base_type, variant_type = variant_type, @@ -196,7 +196,7 @@ class CheckpointProbeBase(ProbeBase): def __init__(self, checkpoint_path: Path, checkpoint: dict, - helper: Callable[[Path],BaseModelType] = None + helper: Callable[[Path],SchedulerPredictionType] = None )->BaseModelType: self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path) self.checkpoint_path = checkpoint_path @@ -405,11 +405,11 @@ class LoRAFolderProbe(FolderProbeBase): pass ############## register probe classes ###### -ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe) -ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe) -ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe) -ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe) -ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe) +ModelProbe.register_probe('diffusers', ModelType.Pipeline, PipelineFolderProbe) +ModelProbe.register_probe('diffusers', ModelType.Vae, VaeFolderProbe) +ModelProbe.register_probe('diffusers', ModelType.Lora, LoRAFolderProbe) +ModelProbe.register_probe('diffusers', ModelType.TextualInversion, TextualInversionFolderProbe) +ModelProbe.register_probe('diffusers', ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe) ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe) ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe) diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index 3bf0045918..f18099b4e7 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -154,7 +154,6 @@ class ModelBase(metaclass=ABCMeta): def create_config(cls, **kwargs) -> ModelConfigBase: if "format" not in kwargs: raise Exception("Field 'format' not found in model config") - configs = cls._get_configs() return configs[kwargs["format"]](**kwargs) diff --git a/invokeai/configs/INITIAL_MODELS.yaml b/invokeai/configs/INITIAL_MODELS.yaml index ccb7ca09aa..cb16f3ed4b 100644 --- a/invokeai/configs/INITIAL_MODELS.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml @@ -3,6 +3,7 @@ sd-1/pipeline/stable-diffusion-v1-5: description: Stable Diffusion version 1.5 diffusers model (4.27 GB) repo_id: runwayml/stable-diffusion-v1-5 recommended: True + default: True sd-1/pipeline/stable-diffusion-inpainting: description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) repo_id: runwayml/stable-diffusion-inpainting @@ -27,7 +28,7 @@ sd-1/pipeline/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) repo_id: 0xJustin/Dungeons-and-Diffusion recommended: False -sd-1/pipeline/dreamlike-photoreal-2.0: +sd-1/pipeline/dreamlike-photoreal-2: description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) repo_id: dreamlike-art/dreamlike-photoreal-2.0 recommended: False diff --git a/invokeai/configs/stable-diffusion/v2-inpainting-inference-v.yaml b/invokeai/configs/stable-diffusion/v2-inpainting-inference-v.yaml new file mode 100644 index 0000000000..37cda460aa --- /dev/null +++ b/invokeai/configs/stable-diffusion/v2-inpainting-inference-v.yaml @@ -0,0 +1,159 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + parameterization: "v" + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 \ No newline at end of file diff --git a/invokeai/configs/stable-diffusion/v2-inpainting-inference.yaml b/invokeai/configs/stable-diffusion/v2-inpainting-inference.yaml new file mode 100644 index 0000000000..5aaf13162d --- /dev/null +++ b/invokeai/configs/stable-diffusion/v2-inpainting-inference.yaml @@ -0,0 +1,158 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 \ No newline at end of file diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 1753364f64..80ddebca84 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -11,7 +11,6 @@ The work is actually done in backend code in model_install_backend.py. import argparse import curses -import os import sys import textwrap import traceback @@ -20,27 +19,21 @@ from multiprocessing import Process from multiprocessing.connection import Connection, Pipe from pathlib import Path from shutil import get_terminal_size -from typing import List import logging import npyscreen import torch from npyscreen import widget -from omegaconf import OmegaConf from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.install.model_install_backend import ( - Dataset_path, # most of these should go!! - default_config_file, - default_dataset, - install_requested_models, - recommended_datasets, ModelInstallList, - UserSelections, - ModelInstall + InstallSelections, + ModelInstall, + SchedulerPredictionType, ) -from invokeai.backend.model_management import ModelManager, BaseModelType, ModelType +from invokeai.backend.model_management import ModelManager, ModelType from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.frontend.install.widgets import ( CenteredTitleText, @@ -133,7 +126,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): bottom_of_table = self.nextrely self.nextrely = top_of_table - self.pipeline_models = self.add_model_widgets( + self.pipeline_models = self.add_pipeline_widgets( model_type=ModelType.Pipeline, window_width=window_width, exclude = self.starter_models @@ -210,11 +203,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): starters = self.starter_models starter_model_labels = self.model_labels - recommended_models = set([ - x - for x in starters - if models[x].recommended - ]) self.installed_models = sorted( [x for x in starters if models[x].installed] ) @@ -312,16 +300,18 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): return widgets ### Tab for arbitrary diffusers widgets ### - def add_diffusers_widgets(self, - model_type: ModelType=ModelType.Pipeline, - window_width: int=120, - )->dict[str,npyscreen.widget]: + def add_pipeline_widgets(self, + model_type: ModelType=ModelType.Pipeline, + window_width: int=120, + **kwargs, + )->dict[str,npyscreen.widget]: '''Similar to add_model_widgets() but adds some additional widgets at the bottom to support the autoload directory''' widgets = self.add_model_widgets( model_type = model_type, window_width = window_width, install_prompt=f"Additional {model_type.value.title()} models already installed.", + **kwargs, ) label = "Directory to scan for models to automatically import ( autocompletes):" @@ -428,7 +418,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): target = process_and_execute, kwargs=dict( opt = app.program_opts, - selections = app.user_selections, + selections = app.install_selections, conn_out = child_conn, ) ) @@ -436,8 +426,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): child_conn.close() self.subprocess_connection = parent_conn self.subprocess = p - app.user_selections = UserSelections() - # process_and_execute(app.opt, app.user_selections) + app.install_selections = InstallSelections() + # process_and_execute(app.opt, app.install_selections) def on_back(self): self.parentApp.switchFormPrevious() @@ -453,7 +443,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): self.parentApp.setNextForm(None) self.parentApp.user_cancelled = False self.editing = False - + ########## This routine monitors the child process that is performing model installation and removal ##### def while_waiting(self): '''Called during idle periods. Main task is to update the Log Messages box with messages @@ -532,73 +522,24 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): .autoscan_on_startup: True if invokeai should scan and import at startup time .import_model_paths: list of URLs, repo_ids and file paths to import """ - # we're using a global here rather than storing the result in the parentapp - # due to some bug in npyscreen that is causing attributes to be lost - selections = self.parentApp.user_selections + selections = self.parentApp.install_selections + all_models = self.all_models - # Starter models to install/remove - # TO DO - turn these into a dict so we don't have to hard-code the attributes - print(f'installed={[x for x in self.all_models if self.all_models[x].installed]}',file=f) - for section in [self.starter_pipelines, self.pipeline_models, - self.controlnet_models, self.lora_models, self.ti_models]: + # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove + ui_sections = [self.starter_pipelines, self.pipeline_models, + self.controlnet_models, self.lora_models, self.ti_models] + for section in ui_sections: selected = set([section['models'][x] for x in section['models_selected'].value]) models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_remove = [x for x in section['models'] if x not in selected and self.all_models[x].installed] - - # "More" models - selections.import_model_paths = self.pipeline_models['download_ids'].value.split() - if diffusers_selected := self.pipeline_models.get('models_selected'): - selections.remove_models.extend([x - for x in diffusers_selected.values - if self.installed_pipeline_models[x] - and diffusers_selected.values.index(x) not in diffusers_selected.value - ] - ) - - # TODO: REFACTOR THIS REPETITIVE CODE - if cn_models_selected := self.controlnet_models.get('models_selected'): - selections.install_cn_models = [cn_models_selected.values[x] - for x in cn_models_selected.value - if not self.installed_cn_models[cn_models_selected.values[x]] - ] - selections.remove_cn_models = [x - for x in cn_models_selected.values - if self.installed_cn_models[x] - and cn_models_selected.values.index(x) not in cn_models_selected.value - ] - if (additional_cns := self.controlnet_models['download_ids'].value.split()): - valid_cns = [x for x in additional_cns if '/' in x] - selections.install_cn_models.extend(valid_cns) + selections.remove_models.extend(models_to_remove) + selections.install_models.extend(all_models[x].path or all_models[x].repo_id \ + for x in models_to_install if all_models[x].path or all_models[x].repo_id) - # same thing, for LoRAs - if loras_selected := self.lora_models.get('models_selected'): - selections.install_lora_models = [loras_selected.values[x] - for x in loras_selected.value - if not self.installed_lora_models[loras_selected.values[x]] - ] - selections.remove_lora_models = [x - for x in loras_selected.values - if self.installed_lora_models[x] - and loras_selected.values.index(x) not in loras_selected.value - ] - if (additional_loras := self.lora_models['download_ids'].value.split()): - selections.install_lora_models.extend(additional_loras) - - # same thing, for TIs - # TODO: refactor - if tis_selected := self.ti_models.get('models_selected'): - selections.install_ti_models = [tis_selected.values[x] - for x in tis_selected.value - if not self.installed_ti_models[tis_selected.values[x]] - ] - selections.remove_ti_models = [x - for x in tis_selected.values - if self.installed_ti_models[x] - and tis_selected.values.index(x) not in tis_selected.value - ] - - if (additional_tis := self.ti_models['download_ids'].value.split()): - selections.install_ti_models.extend(additional_tis) + # models located in the 'download_ids" section + for section in ui_sections: + if downloads := section.get('download_ids'): + selections.install_models.extend(downloads.value.split()) # load directory and whether to scan on startup selections.scan_directory = self.pipeline_models['autoload_directory'].value @@ -609,7 +550,7 @@ class AddModelApplication(npyscreen.NPSAppManaged): super().__init__() self.program_opts = opt self.user_cancelled = False - self.user_selections = UserSelections() + self.install_selections = InstallSelections() def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) @@ -628,21 +569,17 @@ class StderrToMessage(): pass # -------------------------------------------------------- -def ask_user_for_config_file(model_path: Path, - tui_conn: Connection=None - )->Path: +def ask_user_for_prediction_type(model_path: Path, + tui_conn: Connection=None + )->Path: if tui_conn: logger.debug('Waiting for user response...') - return _ask_user_for_cf_tui(model_path, tui_conn) + return _ask_user_for_pt_tui(model_path, tui_conn) else: - return _ask_user_for_cf_cmdline(model_path) + return _ask_user_for_pt_cmdline(model_path) -def _ask_user_for_cf_cmdline(model_path): - choices = [ - config.legacy_conf_path / x - for x in ['v2-inference.yaml','v2-inference-v.yaml'] - ] - choices.extend([None]) +def _ask_user_for_pt_cmdline(model_path): + choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] print( f""" Please select the type of the V2 checkpoint named {model_path.name}: @@ -664,7 +601,7 @@ Please select the type of the V2 checkpoint named {model_path.name}: return return choice -def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path: +def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->Path: try: tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8')) # note that we don't do any status checking here @@ -672,20 +609,20 @@ def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path: if response is None: return None elif response == 'epsilon': - return config.legacy_conf_path / 'v2-inference.yaml' + return SchedulerPredictionType.epsilon elif response == 'v': - return config.legacy_conf_path / 'v2-inference-v.yaml' + return SchedulerPredictionType.VPrediction elif response == 'abort': logger.info('Conversion aborted') return None else: - return Path(response) + return response except: return None # -------------------------------------------------------- def process_and_execute(opt: Namespace, - selections: UserSelections, + selections: InstallSelections, conn_out: Connection=None, ): # set up so that stderr is sent to conn_out @@ -696,34 +633,14 @@ def process_and_execute(opt: Namespace, logger = InvokeAILogger.getLogger() logger.handlers.clear() logger.addHandler(logging.StreamHandler(translator)) - - models_to_install = selections.install_models - models_to_remove = selections.remove_models - directory_to_scan = selections.scan_directory - scan_at_startup = selections.autoscan_on_startup - potential_models_to_install = selections.import_model_paths - name_map = selections.model_name_map - install_requested_models( - diffusers = ModelInstallList(models_to_install, [name_map[ModelType.Pipeline][x] for x in models_to_remove]), - controlnet = ModelInstallList(selections.install_cn_models, [name_map[ModelType.ControlNet][x] for x in selections.remove_cn_models]), - lora = ModelInstallList(selections.install_lora_models, [name_map[ModelType.Lora][x] for x in selections.remove_lora_models]), - ti = ModelInstallList(selections.install_ti_models, [name_map[ModelType.TextualInversion][x] for x in selections.remove_ti_models]), - scan_directory=Path(directory_to_scan) if directory_to_scan else None, - external_models=potential_models_to_install, - scan_at_startup=scan_at_startup, - precision="float32" - if opt.full_precision - else choose_precision(torch.device(choose_torch_device())), - config_file_path=Path(opt.config_file) if opt.config_file else config.model_conf_path, - model_config_file_callback = lambda x: ask_user_for_config_file(x,conn_out) - ) + installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x,conn_out)) + installer.install(selections) if conn_out: conn_out.send_bytes('*done*'.encode('utf-8')) conn_out.close() - def do_listings(opt)->bool: """List installed models of various sorts, and return True if any were requested.""" @@ -754,38 +671,34 @@ def select_and_download_models(opt: Namespace): if opt.full_precision else choose_precision(torch.device(choose_torch_device())) ) - - if do_listings(opt): - pass - # this processes command line additions/removals - elif opt.diffusers or opt.controlnets or opt.textual_inversions or opt.loras: - action = 'remove_models' if opt.delete else 'install_models' - diffusers_args = {'diffusers':ModelInstallList(remove_models=opt.diffusers or [])} \ - if opt.delete \ - else {'external_models':opt.diffusers or []} - install_requested_models( - **diffusers_args, - controlnet=ModelInstallList(**{action:opt.controlnets or []}), - ti=ModelInstallList(**{action:opt.textual_inversions or []}), - lora=ModelInstallList(**{action:opt.loras or []}), - precision=precision, - model_config_file_callback=lambda x: ask_user_for_config_file(x), + config.precision = precision + helper = lambda x: ask_user_for_prediction_type(x) + # if do_listings(opt): + # pass + + installer = ModelInstall(config, prediction_type_helper=helper) + if opt.add or opt.delete: + selections = InstallSelections( + install_models = opt.add or [], + remove_models = opt.delete or [] ) + installer.install(selections) elif opt.default_only: - install_requested_models( - diffusers=ModelInstallList(install_models=default_dataset()), - precision=precision, + selections = InstallSelections( + install_models = installer.default_model() ) + installer.install(selections) elif opt.yes_to_all: - install_requested_models( - diffusers=ModelInstallList(install_models=recommended_datasets()), - precision=precision, + selections = InstallSelections( + install_models = installer.recommended_models() ) + installer.install(selections) # this is where the TUI is called else: # needed because the torch library is loaded, even though we don't use it - torch.multiprocessing.set_start_method("spawn") + # currently commented out because it has started generating errors (?) + # torch.multiprocessing.set_start_method("spawn") # the third argument is needed in the Windows 11 environment in # order to launch and resize a console window running this program @@ -801,35 +714,20 @@ def select_and_download_models(opt: Namespace): installApp.main_form.subprocess.terminate() installApp.main_form.subprocess = None raise e - process_and_execute(opt, installApp.user_selections) + process_and_execute(opt, installApp.install_selections) # ------------------------------------- def main(): parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( - "--diffusers", + "--add", nargs="*", - help="List of URLs or repo_ids of diffusers to install/delete", - ) - parser.add_argument( - "--loras", - nargs="*", - help="List of URLs or repo_ids of LoRA/LyCORIS models to install/delete", - ) - parser.add_argument( - "--controlnets", - nargs="*", - help="List of URLs or repo_ids of controlnet models to install/delete", - ) - parser.add_argument( - "--textual-inversions", - nargs="*", - help="List of URLs or repo_ids of textual inversion embeddings to install/delete", + help="List of URLs, local paths or repo_ids of models to install", ) parser.add_argument( "--delete", - action="store_true", - help="Delete models listed on command line rather than installing them", + nargs="*", + help="List of names of models to idelete", ) parser.add_argument( "--full-precision", @@ -849,7 +747,7 @@ def main(): parser.add_argument( "--default_only", action="store_true", - help="only install the default model", + help="Only install the default model", ) parser.add_argument( "--list-models", diff --git a/invokeai/frontend/install/widgets.py b/invokeai/frontend/install/widgets.py index 14167d4ee0..5ef7f6924e 100644 --- a/invokeai/frontend/install/widgets.py +++ b/invokeai/frontend/install/widgets.py @@ -17,8 +17,8 @@ from shutil import get_terminal_size from curses import BUTTON2_CLICKED,BUTTON3_CLICKED # minimum size for UIs -MIN_COLS = 120 -MIN_LINES = 50 +MIN_COLS = 180 +MIN_LINES = 55 # ------------------------------------- def set_terminal_size(columns: int, lines: int, launch_command: str=None): @@ -384,7 +384,6 @@ def select_stable_diffusion_config_file( "An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)", "An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)", "Skip installation for now and come back later", - "Enter config file path manually", ] F = ConfirmCancelPopup( @@ -406,35 +405,17 @@ def select_stable_diffusion_config_file( mlw.values = message choice = F.add( - SingleSelectWithChanged, + npyscreen.SelectOne, values = options, value = [0], max_height = len(options)+1, scroll_exit=True, ) - file = F.add( - FileBox, - name='Path to config file', - max_height=3, - hidden=True, - must_exist=True, - scroll_exit=True - ) - - def toggle_visible(value): - value = value[0] - if value==3: - file.hidden=False - else: - file.hidden=True - F.display() - - choice.on_changed = toggle_visible F.editw = 1 F.edit() if not F.value: return None - assert choice.value[0] in range(0,4),'invalid choice' - choices = ['epsilon','v','abort',file.value] + assert choice.value[0] in range(0,3),'invalid choice' + choices = ['epsilon','v','abort'] return choices[choice.value[0]] diff --git a/scripts/migrate_models_to_3.0.py b/scripts/migrate_models_to_3.0.py index 2d498df237..23db6d63da 100644 --- a/scripts/migrate_models_to_3.0.py +++ b/scripts/migrate_models_to_3.0.py @@ -26,7 +26,7 @@ from transformers import ( import invokeai.backend.util.logging as logger from invokeai.backend.model_management import ModelManager from invokeai.backend.model_management.model_probe import ( - ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelVariantInfo + ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelProbeInfo ) warnings.filterwarnings("ignore") @@ -171,13 +171,13 @@ def migrate_tuning_models(dest: Path): logger.info(f'Scanning {subdir}') migrate_models(src, dest) -def write_yaml(model_name: str, path:Path, info:ModelVariantInfo, dest_yaml: io.TextIOBase): +def write_yaml(model_name: str, path:Path, info:ModelProbeInfo, dest_yaml: io.TextIOBase): name = unique_name(model_name, info) stanza = { f'{info.base_type.value}/{info.model_type.value}/{name}': { 'name': model_name, 'path': str(path), - 'description': f'diffusers model {model_name}', + 'description': f'A {info.base_type.value} {info.model_type.value} model', 'format': 'diffusers', 'image_size': info.image_size, 'base': info.base_type.value, @@ -266,7 +266,7 @@ def migrate_checkpoints(dest_dir: Path, dest_yaml: io.TextIOBase): { 'name': model_name, 'path': str(weights), - 'description': f'checkpoint model {model_name}', + 'description': f'{info.base_type.value}-based checkpoint', 'format': 'checkpoint', 'image_size': info.image_size, 'base': info.base_type.value,