configure/install basically working; needs edge case testing

This commit is contained in:
Lincoln Stein 2023-06-16 22:54:36 -04:00
parent ada7399753
commit f28d50070e
12 changed files with 701 additions and 588 deletions

View File

@ -56,11 +56,10 @@ from invokeai.frontend.install.widgets import (
) )
from invokeai.backend.install.legacy_arg_parsing import legacy_parser from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import ( from invokeai.backend.install.model_install_backend import (
default_dataset, hf_download_from_pretrained,
download_from_hf,
hf_download_with_resume, hf_download_with_resume,
recommended_datasets, InstallSelections,
UserSelections, ModelInstall,
) )
from invokeai.backend.model_management.model_probe import ( from invokeai.backend.model_management.model_probe import (
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType ModelProbe, ModelType, BaseModelType, SchedulerPredictionType
@ -198,8 +197,8 @@ def download_conversion_models():
# sd-1 # sd-1
repo_id = 'openai/clip-vit-large-patch14' repo_id = 'openai/clip-vit-large-patch14'
download_from_hf(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14') hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
download_from_hf(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14') hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
# sd-2 # sd-2
repo_id = "stabilityai/stable-diffusion-2" repo_id = "stabilityai/stable-diffusion-2"
@ -275,8 +274,8 @@ def download_clipseg():
logger.info("Installing clipseg model for text-based masking...") logger.info("Installing clipseg model for text-based masking...")
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined" CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
try: try:
download_from_hf(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg') hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL,'models/core/misc/clipseg') hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
except Exception: except Exception:
logger.info("Error installing clipseg model:") logger.info("Error installing clipseg model:")
logger.info(traceback.format_exc()) logger.info(traceback.format_exc())
@ -592,7 +591,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
self.program_opts = program_opts self.program_opts = program_opts
self.invokeai_opts = invokeai_opts self.invokeai_opts = invokeai_opts
self.user_cancelled = False self.user_cancelled = False
self.user_selections = default_user_selections(program_opts) self.install_selections = default_user_selections(program_opts)
def onStart(self): def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme) npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@ -627,19 +626,19 @@ def default_startup_options(init_file: Path) -> Namespace:
opts.nsfw_checker = True opts.nsfw_checker = True
return opts return opts
def default_user_selections(program_opts: Namespace) -> UserSelections: def default_user_selections(program_opts: Namespace) -> InstallSelections:
return UserSelections( installer = ModelInstall(config)
install_models=default_dataset() models = installer.all_models()
return InstallSelections(
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
if program_opts.default_only if program_opts.default_only
else recommended_datasets() else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
if program_opts.yes_to_all if program_opts.yes_to_all
else dict(), else list(),
purge_deleted_models=False,
scan_directory=None, scan_directory=None,
autoscan_on_startup=None, autoscan_on_startup=None,
) )
# ------------------------------------- # -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False): def initialize_rootdir(root: Path, yes_to_all: bool = False):
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **") logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
@ -696,7 +695,7 @@ def run_console_ui(
if editApp.user_cancelled: if editApp.user_cancelled:
return (None, None) return (None, None)
else: else:
return (editApp.new_opts, editApp.user_selections) return (editApp.new_opts, editApp.install_selections)
# ------------------------------------- # -------------------------------------

View File

@ -2,18 +2,18 @@
Utility (backend) functions used by model_install.py Utility (backend) functions used by model_install.py
""" """
import os import os
import re
import shutil import shutil
import sys import sys
import traceback
import warnings import warnings
from dataclasses import dataclass,field from dataclasses import dataclass,field
from pathlib import Path from pathlib import Path
from tempfile import TemporaryFile from tempfile import TemporaryDirectory
from typing import List, Dict, Set, Callable from typing import List, Dict, Callable, Union, Set
import requests import requests
from diffusers import AutoencoderKL from diffusers import AutoencoderKL, StableDiffusionPipeline
from huggingface_hub import hf_hub_url, HfFolder from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
from tqdm import tqdm from tqdm import tqdm
@ -21,7 +21,9 @@ from tqdm import tqdm
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume
from ..stable_diffusion import StableDiffusionGeneratorPipeline from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util.logging import InvokeAILogger from ..util.logging import InvokeAILogger
@ -29,19 +31,11 @@ warnings.filterwarnings("ignore")
# --------------------------globals----------------------- # --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.getLogger(name='InvokeAI')
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
# the initial "configs" dir is now bundled in the `invokeai.configs` package # the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
# initial models omegaconf
Datasets = None
# logger
logger = InvokeAILogger.getLogger(name='InvokeAI')
Config_preamble = """ Config_preamble = """
# This file describes the alternative machine learning models # This file describes the alternative machine learning models
# available to InvokeAI script. # available to InvokeAI script.
@ -52,6 +46,24 @@ Config_preamble = """
# was trained on. # was trained on.
""" """
LEGACY_CONFIGS = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: 'v1-inference.yaml',
ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml',
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: 'v2-inference.yaml',
SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml',
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
}
}
}
@dataclass @dataclass
class ModelInstallList: class ModelInstallList:
'''Class for listing models to be installed/removed''' '''Class for listing models to be installed/removed'''
@ -59,18 +71,11 @@ class ModelInstallList:
remove_models: List[str] = field(default_factory=list) remove_models: List[str] = field(default_factory=list)
@dataclass @dataclass
class UserSelections(): class InstallSelections():
install_models: List[str]= field(default_factory=list) install_models: List[str]= field(default_factory=list)
remove_models: List[str]=field(default_factory=list) remove_models: List[str]=field(default_factory=list)
install_cn_models: List[str] = field(default_factory=list)
remove_cn_models: List[str] = field(default_factory=list)
install_lora_models: List[str] = field(default_factory=list)
remove_lora_models: List[str] = field(default_factory=list)
install_ti_models: List[str] = field(default_factory=list)
remove_ti_models: List[str] = field(default_factory=list)
scan_directory: Path = None scan_directory: Path = None
autoscan_on_startup: bool=False autoscan_on_startup: bool=False
import_model_paths: str=None
@dataclass @dataclass
class ModelLoadInfo(): class ModelLoadInfo():
@ -82,18 +87,30 @@ class ModelLoadInfo():
description: str = '' description: str = ''
installed: bool = False installed: bool = False
recommended: bool = False recommended: bool = False
default: bool = False
class ModelInstall(object): class ModelInstall(object):
def __init__(self,config:InvokeAIAppConfig): def __init__(self,
config:InvokeAIAppConfig,
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
access_token:str = None):
self.config = config self.config = config
self.mgr = ModelManager(config.model_conf_path) self.mgr = ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path) self.datasets = OmegaConf.load(Dataset_path)
self.prediction_helper = prediction_type_helper
self.access_token = access_token or HfFolder.get_token()
self.reverse_paths = self._reverse_paths(self.datasets)
def all_models(self)->Dict[str,ModelLoadInfo]: def all_models(self)->Dict[str,ModelLoadInfo]:
''' '''
Return dict of model_key=>ModelStatus Return dict of model_key=>ModelLoadInfo objects.
This method consolidates and simplifies the entries in both
models.yaml and INITIAL_MODELS.yaml so that they can
be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat.
''' '''
model_dict = dict() model_dict = dict()
# first populate with the entries in INITIAL_MODELS.yaml # first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items(): for key, value in self.datasets.items():
name,base,model_type = ModelManager.parse_key(key) name,base,model_type = ModelManager.parse_key(key)
@ -128,102 +145,237 @@ class ModelInstall(object):
if model_type==ModelType.Pipeline: if model_type==ModelType.Pipeline:
models.add(key) models.add(key)
return models return models
def recommended_models(self)->Set[str]:
starters = self.starter_models()
return set([x for x in starters if self.datasets[x].get('recommended',False)])
def default_model(self)->str:
starters = self.starter_models()
defaults = [x for x in starters if self.datasets[x].get('default',False)]
return defaults[0]
def install(self, selections: InstallSelections):
job = 1
jobs = len(selections.remove_models) + len(selections.install_models)
if selections.scan_directory:
jobs += 1
# remove requested models
def default_config_file(): for key in selections.remove_models:
return config.model_conf_path name,base,mtype = self.mgr.parse_key(key)
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
def sd_configs(): self.mgr.del_model(name,base,mtype)
return config.legacy_conf_path job += 1
def initial_models():
global Datasets
if Datasets:
return Datasets
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
def add_models(model_manager, config_file_path: Path, models: List[tuple[str,str,str]]):
print(f'Installing {models}')
def del_models(model_manager, config_file_path: Path, models: List[tuple[str,str,str]]):
for base, model_type, name in models:
logger.info(f"Deleting {name}...")
model_manager.del_model(name, base, model_type)
model_manager.commit(config_file_path)
def install_requested_models(
diffusers: ModelInstallList = None,
controlnet: ModelInstallList = None,
lora: ModelInstallList = None,
ti: ModelInstallList = None,
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
precision: str = "float16",
config_file_path: Path = None,
model_config_file_callback: Callable[[Path],Path] = None,
):
"""
Entry point for installing/deleting starter models, or installing external models.
"""
access_token = HfFolder.get_token()
config_file_path = config_file_path or default_config_file()
if not config_file_path.exists():
open(config_file_path, "w")
# prevent circular import here
from ..model_management import ModelManager
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
for x in [controlnet, lora, ti, diffusers]:
if x:
add_models(model_manager, config_file_path, x.install_models)
del_models(model_manager, config_file_path, x.remove_models)
# if diffusers: # add requested models
for path in selections.install_models:
logger.info(f'Installing {path} [{job}/{jobs}]')
self.heuristic_install(path)
job += 1
# if diffusers.install_models and len(diffusers.install_models) > 0: # import from the scan directory, if any
# logger.info("Installing requested models") if path := selections.scan_directory:
# downloaded_paths = download_weight_datasets( logger.info(f'Scanning and importing models from directory {path} [{job}/{jobs}]')
# models=diffusers.install_models, self.heuristic_install(path)
# access_token=None,
# precision=precision,
# )
# successful = {x:v for x,v in downloaded_paths.items() if v is not None}
# if len(successful) > 0:
# update_config_file(successful, config_file_path)
# if len(successful) < len(diffusers.install_models):
# unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
# logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
# due to above, we have to reload the model manager because conf file self.mgr.commit()
# was changed behind its back
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
external_models = external_models or list() if selections.autoscan_on_startup and Path(selections.scan_directory).is_dir():
if scan_directory: update_autoconvert_dir(selections.scan_directory)
external_models.append(str(scan_directory)) else:
update_autoconvert_dir(None)
if len(external_models) > 0: def heuristic_install(self, model_path_id_or_url: Union[str,Path]):
logger.info("INSTALLING EXTERNAL MODELS") # A little hack to allow nested routines to retrieve info on the requested ID
for path_url_or_repo in external_models: self.current_id = model_path_id_or_url
logger.debug(path_url_or_repo)
try: path = Path(model_path_id_or_url)
model_manager.heuristic_import(
path_url_or_repo, # checkpoint file, or similar
commit_to_conf=config_file_path, if path.is_file():
config_file_callback = model_config_file_callback, self._install_path(path)
return
# folders style or similar
if path.is_dir() and any([(path/x).exists() for x in ['config.json','model_index.json','learned_embeds.bin']]):
self._install_path(path)
return
# recursive scan
if path.is_dir():
for child in path.iterdir():
self.heuristic_install(child)
return
# huggingface repo
parts = str(path).split('/')
if len(parts) == 2:
self._install_repo(str(path))
return
# a URL
if model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
self._install_url(model_path_id_or_url)
return
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None):
try:
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
if info.model_type == ModelType.Pipeline:
attributes = self._make_attributes(path,info)
self.mgr.add_model(model_name = path.stem if info.format=='checkpoint' else path.name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes
)
except Exception as e:
logger.warning(f'{str(e)} Skipping registration.')
def _install_url(self, url: str):
# copy to a staging area, probe, import and delete
with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging))
if not location:
logger.error(f'Unable to download {url}. Skipping.')
info = ModelProbe().heuristic_probe(location)
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
models_path = shutil.move(location,dest)
# staged version will be garbage-collected at this time
self._install_path(Path(models_path), info)
def _get_model_name(self,path_name: str, location: Path)->str:
'''
Calculate a name for the model - primitive implementation.
'''
if key := self.reverse_paths.get(path_name):
(name, base, mtype) = ModelManager.parse_key(key)
return name
else:
return location.stem
def _install_repo(self, repo_id: str):
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically
# list all the files in the repo
files = [x.rfilename for x in hinfo.siblings]
with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging)
if 'model_index.json' in files:
location = self._download_hf_pipeline(repo_id, staging) # pipeline
elif 'pytorch_lora_weights.bin' in files:
location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA
elif self.config.precision=='float16' and 'diffusion_pytorch_model.fp16.safetensors' in files: # vae, controlnet or some other standalone
files = ['config.json', 'diffusion_pytorch_model.fp16.safetensors']
location = self._download_hf_model(repo_id, files, staging)
elif 'diffusion_pytorch_model.safetensors' in files:
files = ['config.json', 'diffusion_pytorch_model.safetensors']
location = self._download_hf_model(repo_id, files, staging)
elif 'learned_embeds.bin' in files:
location = self._download_hf_model(repo_id, ['learned_embeds.bin'], staging)
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
if dest.exists():
shutil.rmtree(dest)
shutil.copytree(location,dest)
self._install_path(dest, info)
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
# convoluted way to retrieve the description from datasets
description = f'{info.base_type.value} {info.model_type.value} model'
if key := self.reverse_paths.get(self.current_id):
if key in self.datasets:
description = self.datasets[key]['description']
attributes = dict(
path = str(path),
description = str(description),
format = info.format,
)
if info.model_type == ModelType.Pipeline:
attributes.update(
dict(
variant = info.variant_type,
)
)
if info.base_type == BaseModelType.StableDiffusion2:
attributes.update(
dict(
prediction_type = info.prediction_type,
upcast_attention = info.prediction_type == SchedulerPredictionType.VPrediction,
)
) )
except KeyboardInterrupt: if info.format=="checkpoint":
sys.exit(-1) try:
except Exception: legacy_conf = LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type] if BaseModelType.StableDiffusion2 \
pass else LEGACY_CONFIGS[info.base_type][info.variant_type]
except KeyError:
legacy_conf = 'v1-inference.yaml' # best guess
attributes.update(
dict(
config = str(self.config.legacy_conf_path / legacy_conf)
)
)
return attributes
if scan_at_startup and scan_directory.is_dir(): def _download_hf_pipeline(self, repo_id: str, staging: Path)->Path:
update_autoconvert_dir(scan_directory) '''
else: This retrieves a StableDiffusion model from cache or remote and then
update_autoconvert_dir(None) does a save_pretrained() to the indicated staging area.
'''
_,name = repo_id.split("/")
revisions = ['fp16','main'] if self.config.precision=='float16' else ['main']
model = None
for revision in revisions:
try:
model = StableDiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
except: # most errors are due to fp16 not being present. Fix this to catch other errors
pass
if model:
break
if not model:
logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.')
return None
model.save_pretrained(staging / name, safe_serialization=True)
return staging / name
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path)->Path:
_,name = repo_id.split("/")
location = staging / name
paths = list()
for filename in files:
p = hf_download_with_resume(repo_id,
model_dir=location,
model_name=filename,
access_token = self.access_token
)
if p:
paths.append(p)
else:
logger.warning(f'Could not download {filename} from {repo_id}.')
return location if len(paths)>0 else None
@classmethod
def _reverse_paths(cls,datasets)->dict:
'''
Reverse mapping from repo_id/path to destination name.
'''
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}
def update_autoconvert_dir(autodir: Path): def update_autoconvert_dir(autodir: Path):
''' '''
@ -249,89 +401,7 @@ def yes_or_no(prompt: str, default_yes=True):
return response[0] in ("y", "Y") return response[0] in ("y", "Y")
# --------------------------------------------- # ---------------------------------------------
def recommended_datasets() -> List['str']: def hf_download_from_pretrained(
datasets = set()
for ds in initial_models().keys():
if initial_models()[ds].get("recommended", False):
datasets.add(ds)
return list(datasets)
# ---------------------------------------------
def default_dataset() -> dict:
datasets = set()
for ds in initial_models().keys():
if initial_models()[ds].get("default", False):
datasets.add(ds)
return list(datasets)
# ---------------------------------------------
def all_datasets() -> dict:
datasets = dict()
for ds in initial_models().keys():
datasets[ds] = True
return datasets
# ---------------------------------------------
# look for legacy model.ckpt in models directory and offer to
# normalize its name
def migrate_models_ckpt():
model_path = os.path.join(config.root_dir, Model_dir, Weights_dir)
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
return
new_name = initial_models()["stable-diffusion-1.4"]["file"]
logger.warning(
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
)
logger.warning(f"model.ckpt => {new_name}")
os.replace(
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
)
# ---------------------------------------------
def download_weight_datasets(
models: List[str], access_token: str, precision: str = "float32"
):
migrate_models_ckpt()
successful = dict()
for mod in models:
logger.info(f"Downloading {mod}:")
successful[mod] = _download_repo_or_file(
initial_models()[mod], access_token, precision=precision
)
return successful
def _download_repo_or_file(
mconfig: DictConfig, access_token: str, precision: str = "float32"
) -> Path:
path = None
if mconfig["format"] == "ckpt":
path = _download_ckpt_weights(mconfig, access_token)
else:
path = _download_diffusion_weights(mconfig, access_token, precision=precision)
if "vae" in mconfig and "repo_id" in mconfig["vae"]:
_download_diffusion_weights(
mconfig["vae"], access_token, precision=precision
)
return path
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
repo_id = mconfig["repo_id"]
filename = mconfig["file"]
cache_dir = os.path.join(config.root_dir, Model_dir, Weights_dir)
return hf_download_with_resume(
repo_id=repo_id,
model_dir=cache_dir,
model_name=filename,
access_token=access_token,
)
# ---------------------------------------------
def download_from_hf(
model_class: object, model_name: str, destination: Path, **kwargs model_class: object, model_name: str, destination: Path, **kwargs
): ):
logger = InvokeAILogger.getLogger('InvokeAI') logger = InvokeAILogger.getLogger('InvokeAI')
@ -345,35 +415,6 @@ def download_from_hf(
model.save_pretrained(destination, safe_serialization=True) model.save_pretrained(destination, safe_serialization=True)
return destination return destination
def _download_diffusion_weights(
mconfig: DictConfig, access_token: str, precision: str = "float32"
):
repo_id = mconfig["repo_id"]
model_class = (
StableDiffusionGeneratorPipeline
if mconfig.get("format", None) == "diffusers"
else AutoencoderKL
)
extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
path = None
for extra_args in extra_arg_list:
try:
path = download_from_hf(
model_class,
repo_id,
safety_checker=None,
**extra_args,
)
except OSError as e:
if 'Revision Not Found' in str(e):
pass
else:
logger.error(str(e))
if path:
break
return path
# --------------------------------------------- # ---------------------------------------------
def hf_download_with_resume( def hf_download_with_resume(
repo_id: str, repo_id: str,
@ -432,128 +473,3 @@ def hf_download_with_resume(
return model_dest return model_dest
# ---------------------------------------------
def update_config_file(successfully_downloaded: dict, config_file: Path):
config_file = (
Path(config_file) if config_file is not None else default_config_file()
)
# In some cases (incomplete setup, etc), the default configs directory might be missing.
# Create it if it doesn't exist.
# this check is ignored if opt.config_file is specified - user is assumed to know what they
# are doing if they are passing a custom config file from elsewhere.
if config_file is default_config_file() and not config_file.parent.exists():
configs_src = Dataset_path.parent
configs_dest = default_config_file().parent
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
yaml = new_config_file_contents(successfully_downloaded, config_file)
try:
backup = None
if os.path.exists(config_file):
logger.warning(
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
)
backup = config_file.with_suffix(".yaml.orig")
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
if sys.platform == "win32" and backup.is_file():
backup.unlink()
config_file.rename(backup)
with TemporaryFile() as tmp:
tmp.write(Config_preamble.encode())
tmp.write(yaml.encode())
with open(str(config_file.expanduser().resolve()), "wb") as new_config:
tmp.seek(0)
new_config.write(tmp.read())
except Exception as e:
logger.error(f"Error creating config file {config_file}: {str(e)}")
if backup is not None:
logger.info("restoring previous config file")
## workaround, for WinError 183, see above
if sys.platform == "win32" and config_file.is_file():
config_file.unlink()
backup.rename(config_file)
return
logger.info(f"Successfully created new configuration file {config_file}")
# ---------------------------------------------
def new_config_file_contents(
successfully_downloaded: dict,
config_file: Path,
) -> str:
if config_file.exists():
conf = OmegaConf.load(str(config_file.expanduser().resolve()))
else:
conf = OmegaConf.create()
default_selected = None
for model in successfully_downloaded:
# a bit hacky - what we are doing here is seeing whether a checkpoint
# version of the model was previously defined, and whether the current
# model is a diffusers (indicated with a path)
if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
delete_weights(model, conf[model])
stanza = {}
mod = initial_models()[model]
stanza["description"] = mod["description"]
stanza["repo_id"] = mod["repo_id"]
stanza["format"] = mod["format"]
# diffusers don't need width and height (probably .ckpt doesn't either)
# so we no longer require these in INITIAL_MODELS.yaml
if "width" in mod:
stanza["width"] = mod["width"]
if "height" in mod:
stanza["height"] = mod["height"]
if "file" in mod:
stanza["weights"] = os.path.relpath(
successfully_downloaded[model], start=config.root_dir
)
stanza["config"] = os.path.normpath(
os.path.join(sd_configs(), mod["config"])
)
if "vae" in mod:
if "file" in mod["vae"]:
stanza["vae"] = os.path.normpath(
os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
)
else:
stanza["vae"] = mod["vae"]
if mod.get("default", False):
stanza["default"] = True
default_selected = True
conf[model] = stanza
# if no default model was chosen, then we select the first
# one in the list
if not default_selected:
conf[list(successfully_downloaded.keys())[0]]["default"] = True
return OmegaConf.to_yaml(conf)
# ---------------------------------------------
def delete_weights(model_name: str, conf_stanza: dict):
if not (weights := conf_stanza.get("weights")):
return
if re.match("/VAE/", conf_stanza.get("config")):
return
logger.warning(
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
)
weights = Path(weights)
if not weights.is_absolute():
weights = config.root_dir / weights
try:
weights.unlink()
except OSError as e:
logger.error(str(e))

View File

@ -4,3 +4,4 @@ Initialization file for invokeai.backend.model_management
from .model_manager import ModelManager, ModelInfo from .model_manager import ModelManager, ModelInfo
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType from .models import BaseModelType, ModelType, SubModelType, ModelVariantType

View File

@ -682,7 +682,7 @@ class ModelManager(object):
for model_key, model_config in list(self.models.items()): for model_key, model_config in list(self.models.items()):
model_name, base_model, model_type = self.parse_key(model_key) model_name, base_model, model_type = self.parse_key(model_key)
model_path = str(self.globals.root / model_config.path) model_path = str(self.globals.root_path / model_config.path)
if not os.path.exists(model_path): if not os.path.exists(model_path):
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config: if model_class.save_to_config:
@ -703,13 +703,14 @@ class ModelManager(object):
for entry_name in os.listdir(models_dir): for entry_name in os.listdir(models_dir):
model_path = os.path.join(models_dir, entry_name) model_path = os.path.join(models_dir, entry_name)
if model_path not in loaded_files: # TODO: check if model_path not in loaded_files: # TODO: check
model_name = Path(model_path).stem model_path = Path(model_path)
model_name = model_path.name if model_path.is_dir else model_path.stem
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
if model_key in self.models: if model_key in self.models:
raise Exception(f"Model with key {model_key} added twice") raise Exception(f"Model with key {model_key} added twice")
model_config: ModelConfigBase = model_class.probe_config(model_path) model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config self.models[model_key] = model_config
new_models_found = True new_models_found = True

View File

@ -15,13 +15,13 @@ import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
@dataclass @dataclass
class ModelVariantInfo(object): class ModelProbeInfo(object):
model_type: ModelType model_type: ModelType
base_type: BaseModelType base_type: BaseModelType
variant_type: ModelVariantType variant_type: ModelVariantType
prediction_type: SchedulerPredictionType prediction_type: SchedulerPredictionType
upcast_attention: bool upcast_attention: bool
format: Literal['folder','checkpoint'] format: Literal['diffusers','checkpoint']
image_size: int image_size: int
class ProbeBase(object): class ProbeBase(object):
@ -31,7 +31,7 @@ class ProbeBase(object):
class ModelProbe(object): class ModelProbe(object):
PROBES = { PROBES = {
'folder': { }, 'diffusers': { },
'checkpoint': { }, 'checkpoint': { },
} }
@ -43,7 +43,7 @@ class ModelProbe(object):
@classmethod @classmethod
def register_probe(cls, def register_probe(cls,
format: Literal['folder','file'], format: Literal['diffusers','checkpoint'],
model_type: ModelType, model_type: ModelType,
probe_class: ProbeBase): probe_class: ProbeBase):
cls.PROBES[format][model_type] = probe_class cls.PROBES[format][model_type] = probe_class
@ -51,8 +51,8 @@ class ModelProbe(object):
@classmethod @classmethod
def heuristic_probe(cls, def heuristic_probe(cls,
model: Union[Dict, ModelMixin, Path], model: Union[Dict, ModelMixin, Path],
prediction_type_helper: Callable[[Path],BaseModelType]=None, prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->ModelVariantInfo: )->ModelProbeInfo:
if isinstance(model,Path): if isinstance(model,Path):
return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper) return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
elif isinstance(model,(dict,ModelMixin,ConfigMixin)): elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
@ -64,7 +64,7 @@ class ModelProbe(object):
def probe(cls, def probe(cls,
model_path: Path, model_path: Path,
model: Union[Dict, ModelMixin] = None, model: Union[Dict, ModelMixin] = None,
prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo: prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo:
''' '''
Probe the model at model_path and return sufficient information about it Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is to place it somewhere in the models directory hierarchy. If the model is
@ -74,14 +74,14 @@ class ModelProbe(object):
between V2-Base and V2-768 SD models. between V2-Base and V2-768 SD models.
''' '''
if model_path: if model_path:
format = 'folder' if model_path.is_dir() else 'checkpoint' format = 'diffusers' if model_path.is_dir() else 'checkpoint'
else: else:
format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint' format = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
model_info = None model_info = None
try: try:
model_type = cls.get_model_type_from_folder(model_path, model) \ model_type = cls.get_model_type_from_folder(model_path, model) \
if format == 'folder' \ if format == 'diffusers' \
else cls.get_model_type_from_checkpoint(model_path, model) else cls.get_model_type_from_checkpoint(model_path, model)
probe_class = cls.PROBES[format].get(model_type) probe_class = cls.PROBES[format].get(model_type)
if not probe_class: if not probe_class:
@ -90,7 +90,7 @@ class ModelProbe(object):
base_type = probe.get_base_type() base_type = probe.get_base_type()
variant_type = probe.get_variant_type() variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type() prediction_type = probe.get_scheduler_prediction_type()
model_info = ModelVariantInfo( model_info = ModelProbeInfo(
model_type = model_type, model_type = model_type,
base_type = base_type, base_type = base_type,
variant_type = variant_type, variant_type = variant_type,
@ -196,7 +196,7 @@ class CheckpointProbeBase(ProbeBase):
def __init__(self, def __init__(self,
checkpoint_path: Path, checkpoint_path: Path,
checkpoint: dict, checkpoint: dict,
helper: Callable[[Path],BaseModelType] = None helper: Callable[[Path],SchedulerPredictionType] = None
)->BaseModelType: )->BaseModelType:
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path) self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
@ -405,11 +405,11 @@ class LoRAFolderProbe(FolderProbeBase):
pass pass
############## register probe classes ###### ############## register probe classes ######
ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe) ModelProbe.register_probe('diffusers', ModelType.Pipeline, PipelineFolderProbe)
ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe) ModelProbe.register_probe('diffusers', ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe) ModelProbe.register_probe('diffusers', ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe) ModelProbe.register_probe('diffusers', ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe('diffusers', ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe) ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe) ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe) ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)

View File

@ -154,7 +154,6 @@ class ModelBase(metaclass=ABCMeta):
def create_config(cls, **kwargs) -> ModelConfigBase: def create_config(cls, **kwargs) -> ModelConfigBase:
if "format" not in kwargs: if "format" not in kwargs:
raise Exception("Field 'format' not found in model config") raise Exception("Field 'format' not found in model config")
configs = cls._get_configs() configs = cls._get_configs()
return configs[kwargs["format"]](**kwargs) return configs[kwargs["format"]](**kwargs)

View File

@ -3,6 +3,7 @@ sd-1/pipeline/stable-diffusion-v1-5:
description: Stable Diffusion version 1.5 diffusers model (4.27 GB) description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
repo_id: runwayml/stable-diffusion-v1-5 repo_id: runwayml/stable-diffusion-v1-5
recommended: True recommended: True
default: True
sd-1/pipeline/stable-diffusion-inpainting: sd-1/pipeline/stable-diffusion-inpainting:
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
repo_id: runwayml/stable-diffusion-inpainting repo_id: runwayml/stable-diffusion-inpainting
@ -27,7 +28,7 @@ sd-1/pipeline/Dungeons-and-Diffusion:
description: Dungeons & Dragons characters (2.13 GB) description: Dungeons & Dragons characters (2.13 GB)
repo_id: 0xJustin/Dungeons-and-Diffusion repo_id: 0xJustin/Dungeons-and-Diffusion
recommended: False recommended: False
sd-1/pipeline/dreamlike-photoreal-2.0: sd-1/pipeline/dreamlike-photoreal-2:
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
repo_id: dreamlike-art/dreamlike-photoreal-2.0 repo_id: dreamlike-art/dreamlike-photoreal-2.0
recommended: False recommended: False

View File

@ -0,0 +1,159 @@
model:
base_learning_rate: 5.0e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
parameterization: "v"
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: null # for concat as in LAION-A
p_unsafe_threshold: 0.1
filter_word_list: "data/filters.yaml"
max_pwatermark: 0.45
batch_size: 8
num_workers: 6
multinode: True
min_size: 512
train:
shards:
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
shuffle: 10000
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
# NOTE use enough shards to avoid empty validation loops in workers
validation:
shards:
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
shuffle: 0
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
lightning:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 10000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
disabled: False
batch_frequency: 1000
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
unconditional_guidance_scale: 5.0
unconditional_guidance_label: [""]
ddim_steps: 50 # todo check these out for depth2img,
ddim_eta: 0.0 # todo check these out for depth2img,
trainer:
benchmark: True
val_check_interval: 5000000
num_sanity_val_steps: 0
accumulate_grad_batches: 1

View File

@ -0,0 +1,158 @@
model:
base_learning_rate: 5.0e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: null # for concat as in LAION-A
p_unsafe_threshold: 0.1
filter_word_list: "data/filters.yaml"
max_pwatermark: 0.45
batch_size: 8
num_workers: 6
multinode: True
min_size: 512
train:
shards:
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
shuffle: 10000
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
# NOTE use enough shards to avoid empty validation loops in workers
validation:
shards:
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
shuffle: 0
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
lightning:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 10000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
disabled: False
batch_frequency: 1000
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
unconditional_guidance_scale: 5.0
unconditional_guidance_label: [""]
ddim_steps: 50 # todo check these out for depth2img,
ddim_eta: 0.0 # todo check these out for depth2img,
trainer:
benchmark: True
val_check_interval: 5000000
num_sanity_val_steps: 0
accumulate_grad_batches: 1

View File

@ -11,7 +11,6 @@ The work is actually done in backend code in model_install_backend.py.
import argparse import argparse
import curses import curses
import os
import sys import sys
import textwrap import textwrap
import traceback import traceback
@ -20,27 +19,21 @@ from multiprocessing import Process
from multiprocessing.connection import Connection, Pipe from multiprocessing.connection import Connection, Pipe
from pathlib import Path from pathlib import Path
from shutil import get_terminal_size from shutil import get_terminal_size
from typing import List
import logging import logging
import npyscreen import npyscreen
import torch import torch
from npyscreen import widget from npyscreen import widget
from omegaconf import OmegaConf
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.install.model_install_backend import ( from invokeai.backend.install.model_install_backend import (
Dataset_path, # most of these should go!!
default_config_file,
default_dataset,
install_requested_models,
recommended_datasets,
ModelInstallList, ModelInstallList,
UserSelections, InstallSelections,
ModelInstall ModelInstall,
SchedulerPredictionType,
) )
from invokeai.backend.model_management import ModelManager, BaseModelType, ModelType from invokeai.backend.model_management import ModelManager, ModelType
from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.frontend.install.widgets import ( from invokeai.frontend.install.widgets import (
CenteredTitleText, CenteredTitleText,
@ -133,7 +126,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
bottom_of_table = self.nextrely bottom_of_table = self.nextrely
self.nextrely = top_of_table self.nextrely = top_of_table
self.pipeline_models = self.add_model_widgets( self.pipeline_models = self.add_pipeline_widgets(
model_type=ModelType.Pipeline, model_type=ModelType.Pipeline,
window_width=window_width, window_width=window_width,
exclude = self.starter_models exclude = self.starter_models
@ -210,11 +203,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
starters = self.starter_models starters = self.starter_models
starter_model_labels = self.model_labels starter_model_labels = self.model_labels
recommended_models = set([
x
for x in starters
if models[x].recommended
])
self.installed_models = sorted( self.installed_models = sorted(
[x for x in starters if models[x].installed] [x for x in starters if models[x].installed]
) )
@ -312,16 +300,18 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
return widgets return widgets
### Tab for arbitrary diffusers widgets ### ### Tab for arbitrary diffusers widgets ###
def add_diffusers_widgets(self, def add_pipeline_widgets(self,
model_type: ModelType=ModelType.Pipeline, model_type: ModelType=ModelType.Pipeline,
window_width: int=120, window_width: int=120,
)->dict[str,npyscreen.widget]: **kwargs,
)->dict[str,npyscreen.widget]:
'''Similar to add_model_widgets() but adds some additional widgets at the bottom '''Similar to add_model_widgets() but adds some additional widgets at the bottom
to support the autoload directory''' to support the autoload directory'''
widgets = self.add_model_widgets( widgets = self.add_model_widgets(
model_type = model_type, model_type = model_type,
window_width = window_width, window_width = window_width,
install_prompt=f"Additional {model_type.value.title()} models already installed.", install_prompt=f"Additional {model_type.value.title()} models already installed.",
**kwargs,
) )
label = "Directory to scan for models to automatically import (<tab> autocompletes):" label = "Directory to scan for models to automatically import (<tab> autocompletes):"
@ -428,7 +418,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
target = process_and_execute, target = process_and_execute,
kwargs=dict( kwargs=dict(
opt = app.program_opts, opt = app.program_opts,
selections = app.user_selections, selections = app.install_selections,
conn_out = child_conn, conn_out = child_conn,
) )
) )
@ -436,8 +426,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
child_conn.close() child_conn.close()
self.subprocess_connection = parent_conn self.subprocess_connection = parent_conn
self.subprocess = p self.subprocess = p
app.user_selections = UserSelections() app.install_selections = InstallSelections()
# process_and_execute(app.opt, app.user_selections) # process_and_execute(app.opt, app.install_selections)
def on_back(self): def on_back(self):
self.parentApp.switchFormPrevious() self.parentApp.switchFormPrevious()
@ -453,7 +443,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.parentApp.setNextForm(None) self.parentApp.setNextForm(None)
self.parentApp.user_cancelled = False self.parentApp.user_cancelled = False
self.editing = False self.editing = False
########## This routine monitors the child process that is performing model installation and removal ##### ########## This routine monitors the child process that is performing model installation and removal #####
def while_waiting(self): def while_waiting(self):
'''Called during idle periods. Main task is to update the Log Messages box with messages '''Called during idle periods. Main task is to update the Log Messages box with messages
@ -532,73 +522,24 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
.autoscan_on_startup: True if invokeai should scan and import at startup time .autoscan_on_startup: True if invokeai should scan and import at startup time
.import_model_paths: list of URLs, repo_ids and file paths to import .import_model_paths: list of URLs, repo_ids and file paths to import
""" """
# we're using a global here rather than storing the result in the parentapp selections = self.parentApp.install_selections
# due to some bug in npyscreen that is causing attributes to be lost all_models = self.all_models
selections = self.parentApp.user_selections
# Starter models to install/remove # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove
# TO DO - turn these into a dict so we don't have to hard-code the attributes ui_sections = [self.starter_pipelines, self.pipeline_models,
print(f'installed={[x for x in self.all_models if self.all_models[x].installed]}',file=f) self.controlnet_models, self.lora_models, self.ti_models]
for section in [self.starter_pipelines, self.pipeline_models, for section in ui_sections:
self.controlnet_models, self.lora_models, self.ti_models]:
selected = set([section['models'][x] for x in section['models_selected'].value]) selected = set([section['models'][x] for x in section['models_selected'].value])
models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_install = [x for x in selected if not self.all_models[x].installed]
models_to_remove = [x for x in section['models'] if x not in selected and self.all_models[x].installed] models_to_remove = [x for x in section['models'] if x not in selected and self.all_models[x].installed]
selections.remove_models.extend(models_to_remove)
# "More" models selections.install_models.extend(all_models[x].path or all_models[x].repo_id \
selections.import_model_paths = self.pipeline_models['download_ids'].value.split() for x in models_to_install if all_models[x].path or all_models[x].repo_id)
if diffusers_selected := self.pipeline_models.get('models_selected'):
selections.remove_models.extend([x
for x in diffusers_selected.values
if self.installed_pipeline_models[x]
and diffusers_selected.values.index(x) not in diffusers_selected.value
]
)
# TODO: REFACTOR THIS REPETITIVE CODE
if cn_models_selected := self.controlnet_models.get('models_selected'):
selections.install_cn_models = [cn_models_selected.values[x]
for x in cn_models_selected.value
if not self.installed_cn_models[cn_models_selected.values[x]]
]
selections.remove_cn_models = [x
for x in cn_models_selected.values
if self.installed_cn_models[x]
and cn_models_selected.values.index(x) not in cn_models_selected.value
]
if (additional_cns := self.controlnet_models['download_ids'].value.split()):
valid_cns = [x for x in additional_cns if '/' in x]
selections.install_cn_models.extend(valid_cns)
# same thing, for LoRAs # models located in the 'download_ids" section
if loras_selected := self.lora_models.get('models_selected'): for section in ui_sections:
selections.install_lora_models = [loras_selected.values[x] if downloads := section.get('download_ids'):
for x in loras_selected.value selections.install_models.extend(downloads.value.split())
if not self.installed_lora_models[loras_selected.values[x]]
]
selections.remove_lora_models = [x
for x in loras_selected.values
if self.installed_lora_models[x]
and loras_selected.values.index(x) not in loras_selected.value
]
if (additional_loras := self.lora_models['download_ids'].value.split()):
selections.install_lora_models.extend(additional_loras)
# same thing, for TIs
# TODO: refactor
if tis_selected := self.ti_models.get('models_selected'):
selections.install_ti_models = [tis_selected.values[x]
for x in tis_selected.value
if not self.installed_ti_models[tis_selected.values[x]]
]
selections.remove_ti_models = [x
for x in tis_selected.values
if self.installed_ti_models[x]
and tis_selected.values.index(x) not in tis_selected.value
]
if (additional_tis := self.ti_models['download_ids'].value.split()):
selections.install_ti_models.extend(additional_tis)
# load directory and whether to scan on startup # load directory and whether to scan on startup
selections.scan_directory = self.pipeline_models['autoload_directory'].value selections.scan_directory = self.pipeline_models['autoload_directory'].value
@ -609,7 +550,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
super().__init__() super().__init__()
self.program_opts = opt self.program_opts = opt
self.user_cancelled = False self.user_cancelled = False
self.user_selections = UserSelections() self.install_selections = InstallSelections()
def onStart(self): def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme) npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@ -628,21 +569,17 @@ class StderrToMessage():
pass pass
# -------------------------------------------------------- # --------------------------------------------------------
def ask_user_for_config_file(model_path: Path, def ask_user_for_prediction_type(model_path: Path,
tui_conn: Connection=None tui_conn: Connection=None
)->Path: )->Path:
if tui_conn: if tui_conn:
logger.debug('Waiting for user response...') logger.debug('Waiting for user response...')
return _ask_user_for_cf_tui(model_path, tui_conn) return _ask_user_for_pt_tui(model_path, tui_conn)
else: else:
return _ask_user_for_cf_cmdline(model_path) return _ask_user_for_pt_cmdline(model_path)
def _ask_user_for_cf_cmdline(model_path): def _ask_user_for_pt_cmdline(model_path):
choices = [ choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
config.legacy_conf_path / x
for x in ['v2-inference.yaml','v2-inference-v.yaml']
]
choices.extend([None])
print( print(
f""" f"""
Please select the type of the V2 checkpoint named {model_path.name}: Please select the type of the V2 checkpoint named {model_path.name}:
@ -664,7 +601,7 @@ Please select the type of the V2 checkpoint named {model_path.name}:
return return
return choice return choice
def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path: def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->Path:
try: try:
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8')) tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
# note that we don't do any status checking here # note that we don't do any status checking here
@ -672,20 +609,20 @@ def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path:
if response is None: if response is None:
return None return None
elif response == 'epsilon': elif response == 'epsilon':
return config.legacy_conf_path / 'v2-inference.yaml' return SchedulerPredictionType.epsilon
elif response == 'v': elif response == 'v':
return config.legacy_conf_path / 'v2-inference-v.yaml' return SchedulerPredictionType.VPrediction
elif response == 'abort': elif response == 'abort':
logger.info('Conversion aborted') logger.info('Conversion aborted')
return None return None
else: else:
return Path(response) return response
except: except:
return None return None
# -------------------------------------------------------- # --------------------------------------------------------
def process_and_execute(opt: Namespace, def process_and_execute(opt: Namespace,
selections: UserSelections, selections: InstallSelections,
conn_out: Connection=None, conn_out: Connection=None,
): ):
# set up so that stderr is sent to conn_out # set up so that stderr is sent to conn_out
@ -696,34 +633,14 @@ def process_and_execute(opt: Namespace,
logger = InvokeAILogger.getLogger() logger = InvokeAILogger.getLogger()
logger.handlers.clear() logger.handlers.clear()
logger.addHandler(logging.StreamHandler(translator)) logger.addHandler(logging.StreamHandler(translator))
models_to_install = selections.install_models
models_to_remove = selections.remove_models
directory_to_scan = selections.scan_directory
scan_at_startup = selections.autoscan_on_startup
potential_models_to_install = selections.import_model_paths
name_map = selections.model_name_map
install_requested_models( installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x,conn_out))
diffusers = ModelInstallList(models_to_install, [name_map[ModelType.Pipeline][x] for x in models_to_remove]), installer.install(selections)
controlnet = ModelInstallList(selections.install_cn_models, [name_map[ModelType.ControlNet][x] for x in selections.remove_cn_models]),
lora = ModelInstallList(selections.install_lora_models, [name_map[ModelType.Lora][x] for x in selections.remove_lora_models]),
ti = ModelInstallList(selections.install_ti_models, [name_map[ModelType.TextualInversion][x] for x in selections.remove_ti_models]),
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
external_models=potential_models_to_install,
scan_at_startup=scan_at_startup,
precision="float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device())),
config_file_path=Path(opt.config_file) if opt.config_file else config.model_conf_path,
model_config_file_callback = lambda x: ask_user_for_config_file(x,conn_out)
)
if conn_out: if conn_out:
conn_out.send_bytes('*done*'.encode('utf-8')) conn_out.send_bytes('*done*'.encode('utf-8'))
conn_out.close() conn_out.close()
def do_listings(opt)->bool: def do_listings(opt)->bool:
"""List installed models of various sorts, and return """List installed models of various sorts, and return
True if any were requested.""" True if any were requested."""
@ -754,38 +671,34 @@ def select_and_download_models(opt: Namespace):
if opt.full_precision if opt.full_precision
else choose_precision(torch.device(choose_torch_device())) else choose_precision(torch.device(choose_torch_device()))
) )
config.precision = precision
if do_listings(opt): helper = lambda x: ask_user_for_prediction_type(x)
pass # if do_listings(opt):
# this processes command line additions/removals # pass
elif opt.diffusers or opt.controlnets or opt.textual_inversions or opt.loras:
action = 'remove_models' if opt.delete else 'install_models' installer = ModelInstall(config, prediction_type_helper=helper)
diffusers_args = {'diffusers':ModelInstallList(remove_models=opt.diffusers or [])} \ if opt.add or opt.delete:
if opt.delete \ selections = InstallSelections(
else {'external_models':opt.diffusers or []} install_models = opt.add or [],
install_requested_models( remove_models = opt.delete or []
**diffusers_args,
controlnet=ModelInstallList(**{action:opt.controlnets or []}),
ti=ModelInstallList(**{action:opt.textual_inversions or []}),
lora=ModelInstallList(**{action:opt.loras or []}),
precision=precision,
model_config_file_callback=lambda x: ask_user_for_config_file(x),
) )
installer.install(selections)
elif opt.default_only: elif opt.default_only:
install_requested_models( selections = InstallSelections(
diffusers=ModelInstallList(install_models=default_dataset()), install_models = installer.default_model()
precision=precision,
) )
installer.install(selections)
elif opt.yes_to_all: elif opt.yes_to_all:
install_requested_models( selections = InstallSelections(
diffusers=ModelInstallList(install_models=recommended_datasets()), install_models = installer.recommended_models()
precision=precision,
) )
installer.install(selections)
# this is where the TUI is called # this is where the TUI is called
else: else:
# needed because the torch library is loaded, even though we don't use it # needed because the torch library is loaded, even though we don't use it
torch.multiprocessing.set_start_method("spawn") # currently commented out because it has started generating errors (?)
# torch.multiprocessing.set_start_method("spawn")
# the third argument is needed in the Windows 11 environment in # the third argument is needed in the Windows 11 environment in
# order to launch and resize a console window running this program # order to launch and resize a console window running this program
@ -801,35 +714,20 @@ def select_and_download_models(opt: Namespace):
installApp.main_form.subprocess.terminate() installApp.main_form.subprocess.terminate()
installApp.main_form.subprocess = None installApp.main_form.subprocess = None
raise e raise e
process_and_execute(opt, installApp.user_selections) process_and_execute(opt, installApp.install_selections)
# ------------------------------------- # -------------------------------------
def main(): def main():
parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser = argparse.ArgumentParser(description="InvokeAI model downloader")
parser.add_argument( parser.add_argument(
"--diffusers", "--add",
nargs="*", nargs="*",
help="List of URLs or repo_ids of diffusers to install/delete", help="List of URLs, local paths or repo_ids of models to install",
)
parser.add_argument(
"--loras",
nargs="*",
help="List of URLs or repo_ids of LoRA/LyCORIS models to install/delete",
)
parser.add_argument(
"--controlnets",
nargs="*",
help="List of URLs or repo_ids of controlnet models to install/delete",
)
parser.add_argument(
"--textual-inversions",
nargs="*",
help="List of URLs or repo_ids of textual inversion embeddings to install/delete",
) )
parser.add_argument( parser.add_argument(
"--delete", "--delete",
action="store_true", nargs="*",
help="Delete models listed on command line rather than installing them", help="List of names of models to idelete",
) )
parser.add_argument( parser.add_argument(
"--full-precision", "--full-precision",
@ -849,7 +747,7 @@ def main():
parser.add_argument( parser.add_argument(
"--default_only", "--default_only",
action="store_true", action="store_true",
help="only install the default model", help="Only install the default model",
) )
parser.add_argument( parser.add_argument(
"--list-models", "--list-models",

View File

@ -17,8 +17,8 @@ from shutil import get_terminal_size
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
# minimum size for UIs # minimum size for UIs
MIN_COLS = 120 MIN_COLS = 180
MIN_LINES = 50 MIN_LINES = 55
# ------------------------------------- # -------------------------------------
def set_terminal_size(columns: int, lines: int, launch_command: str=None): def set_terminal_size(columns: int, lines: int, launch_command: str=None):
@ -384,7 +384,6 @@ def select_stable_diffusion_config_file(
"An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)", "An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)",
"An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)", "An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)",
"Skip installation for now and come back later", "Skip installation for now and come back later",
"Enter config file path manually",
] ]
F = ConfirmCancelPopup( F = ConfirmCancelPopup(
@ -406,35 +405,17 @@ def select_stable_diffusion_config_file(
mlw.values = message mlw.values = message
choice = F.add( choice = F.add(
SingleSelectWithChanged, npyscreen.SelectOne,
values = options, values = options,
value = [0], value = [0],
max_height = len(options)+1, max_height = len(options)+1,
scroll_exit=True, scroll_exit=True,
) )
file = F.add(
FileBox,
name='Path to config file',
max_height=3,
hidden=True,
must_exist=True,
scroll_exit=True
)
def toggle_visible(value):
value = value[0]
if value==3:
file.hidden=False
else:
file.hidden=True
F.display()
choice.on_changed = toggle_visible
F.editw = 1 F.editw = 1
F.edit() F.edit()
if not F.value: if not F.value:
return None return None
assert choice.value[0] in range(0,4),'invalid choice' assert choice.value[0] in range(0,3),'invalid choice'
choices = ['epsilon','v','abort',file.value] choices = ['epsilon','v','abort']
return choices[choice.value[0]] return choices[choice.value[0]]

View File

@ -26,7 +26,7 @@ from transformers import (
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.backend.model_management import ModelManager from invokeai.backend.model_management import ModelManager
from invokeai.backend.model_management.model_probe import ( from invokeai.backend.model_management.model_probe import (
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelVariantInfo ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelProbeInfo
) )
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -171,13 +171,13 @@ def migrate_tuning_models(dest: Path):
logger.info(f'Scanning {subdir}') logger.info(f'Scanning {subdir}')
migrate_models(src, dest) migrate_models(src, dest)
def write_yaml(model_name: str, path:Path, info:ModelVariantInfo, dest_yaml: io.TextIOBase): def write_yaml(model_name: str, path:Path, info:ModelProbeInfo, dest_yaml: io.TextIOBase):
name = unique_name(model_name, info) name = unique_name(model_name, info)
stanza = { stanza = {
f'{info.base_type.value}/{info.model_type.value}/{name}': { f'{info.base_type.value}/{info.model_type.value}/{name}': {
'name': model_name, 'name': model_name,
'path': str(path), 'path': str(path),
'description': f'diffusers model {model_name}', 'description': f'A {info.base_type.value} {info.model_type.value} model',
'format': 'diffusers', 'format': 'diffusers',
'image_size': info.image_size, 'image_size': info.image_size,
'base': info.base_type.value, 'base': info.base_type.value,
@ -266,7 +266,7 @@ def migrate_checkpoints(dest_dir: Path, dest_yaml: io.TextIOBase):
{ {
'name': model_name, 'name': model_name,
'path': str(weights), 'path': str(weights),
'description': f'checkpoint model {model_name}', 'description': f'{info.base_type.value}-based checkpoint',
'format': 'checkpoint', 'format': 'checkpoint',
'image_size': info.image_size, 'image_size': info.image_size,
'base': info.base_type.value, 'base': info.base_type.value,