mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

configure/install basically working; needs edge case testing

This commit is contained in:
parent ada7399753
commit f28d50070e
@@ -56,11 +56,10 @@ from invokeai.frontend.install.widgets import (
 )
 from invokeai.backend.install.legacy_arg_parsing import legacy_parser
 from invokeai.backend.install.model_install_backend import (
-    default_dataset,
-    download_from_hf,
+    hf_download_from_pretrained,
     hf_download_with_resume,
-    recommended_datasets,
-    UserSelections,
+    InstallSelections,
+    ModelInstall,
 )
 from invokeai.backend.model_management.model_probe import (
     ModelProbe, ModelType, BaseModelType, SchedulerPredictionType
@@ -198,8 +197,8 @@ def download_conversion_models():

     # sd-1
     repo_id = 'openai/clip-vit-large-patch14'
-    download_from_hf(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
-    download_from_hf(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
+    hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
+    hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')

     # sd-2
     repo_id = "stabilityai/stable-diffusion-2"
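A note on the rename above: `download_from_hf` becomes `hf_download_from_pretrained` with the same call shape. Judging only from the fragments visible later in this diff (the function takes `model_class, model_name, destination, **kwargs`, calls `model.save_pretrained(destination, safe_serialization=True)` and returns the destination), it presumably wraps the usual from_pretrained/save_pretrained round trip. A minimal sketch, not the verbatim implementation:

    # Sketch only; the full body is not shown in this diff.
    def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
        model = model_class.from_pretrained(model_name, **kwargs)    # pulls into the HF cache
        model.save_pretrained(destination, safe_serialization=True)  # then exports to destination
        return destination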
@@ -275,8 +274,8 @@ def download_clipseg():
     logger.info("Installing clipseg model for text-based masking...")
     CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
     try:
-        download_from_hf(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
-        download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL,'models/core/misc/clipseg')
+        hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
+        hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
     except Exception:
         logger.info("Error installing clipseg model:")
         logger.info(traceback.format_exc())
@@ -592,7 +591,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
         self.program_opts = program_opts
         self.invokeai_opts = invokeai_opts
         self.user_cancelled = False
-        self.user_selections = default_user_selections(program_opts)
+        self.install_selections = default_user_selections(program_opts)

     def onStart(self):
         npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@@ -627,19 +626,19 @@ def default_startup_options(init_file: Path) -> Namespace:
     opts.nsfw_checker = True
     return opts

-def default_user_selections(program_opts: Namespace) -> UserSelections:
-    return UserSelections(
-        install_models=default_dataset()
+def default_user_selections(program_opts: Namespace) -> InstallSelections:
+    installer = ModelInstall(config)
+    models = installer.all_models()
+    return InstallSelections(
+        install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
         if program_opts.default_only
-        else recommended_datasets()
+        else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
         if program_opts.yes_to_all
-        else dict(),
-        purge_deleted_models=False,
+        else list(),
         scan_directory=None,
         autoscan_on_startup=None,
     )


 # -------------------------------------
 def initialize_rootdir(root: Path, yes_to_all: bool = False):
     logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
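The rewritten `default_user_selections` resolves each starter-model key through the new `ModelInstall` catalog instead of the old `default_dataset()`/`recommended_datasets()` helpers. The `path or repo_id` expression prefers a local path when the model is already on disk and falls back to the HuggingFace repo ID; for example (values hypothetical):

    info = models['sd-1/pipeline/stable-diffusion-v1-5']  # a ModelLoadInfo
    source = info.path or info.repo_id                    # e.g. 'runwayml/stable-diffusion-v1-5'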
@@ -696,7 +695,7 @@ def run_console_ui(
     if editApp.user_cancelled:
         return (None, None)
     else:
-        return (editApp.new_opts, editApp.user_selections)
+        return (editApp.new_opts, editApp.install_selections)


 # -------------------------------------
@@ -2,18 +2,18 @@
 Utility (backend) functions used by model_install.py
 """
 import os
-import re
 import shutil
 import sys
+import traceback
 import warnings
 from dataclasses import dataclass,field
 from pathlib import Path
-from tempfile import TemporaryFile
-from typing import List, Dict, Set, Callable
+from tempfile import TemporaryDirectory
+from typing import List, Dict, Callable, Union, Set

 import requests
-from diffusers import AutoencoderKL
-from huggingface_hub import hf_hub_url, HfFolder
+from diffusers import AutoencoderKL, StableDiffusionPipeline
+from huggingface_hub import hf_hub_url, HfFolder, HfApi
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 from tqdm import tqdm
@@ -21,7 +21,9 @@ from tqdm import tqdm
 import invokeai.configs as configs

 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType
+from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
+from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
+from invokeai.backend.util import download_with_resume
 from ..stable_diffusion import StableDiffusionGeneratorPipeline
 from ..util.logging import InvokeAILogger
@@ -29,19 +31,11 @@ warnings.filterwarnings("ignore")

 # --------------------------globals-----------------------
 config = InvokeAIAppConfig.get_config()
-Model_dir = "models"
-Weights_dir = "ldm/stable-diffusion-v1/"
+logger = InvokeAILogger.getLogger(name='InvokeAI')

 # the initial "configs" dir is now bundled in the `invokeai.configs` package
 Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"

-# initial models omegaconf
-Datasets = None
-
-# logger
-logger = InvokeAILogger.getLogger(name='InvokeAI')
-
 Config_preamble = """
 # This file describes the alternative machine learning models
 # available to InvokeAI script.
@@ -52,6 +46,24 @@ Config_preamble = """
 # was trained on.
 """

+LEGACY_CONFIGS = {
+    BaseModelType.StableDiffusion1: {
+        ModelVariantType.Normal: 'v1-inference.yaml',
+        ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml',
+    },
+
+    BaseModelType.StableDiffusion2: {
+        ModelVariantType.Normal: {
+            SchedulerPredictionType.Epsilon: 'v2-inference.yaml',
+            SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml',
+        },
+        ModelVariantType.Inpaint: {
+            SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
+            SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
+        }
+    }
+}
+
 @dataclass
 class ModelInstallList:
     '''Class for listing models to be installed/removed'''
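The `LEGACY_CONFIGS` table maps a probed checkpoint to the legacy .yaml needed for conversion: SD-1 entries are keyed by variant only, while SD-2 entries add a third level for the scheduler prediction type. Two lookups straight from the table:

    LEGACY_CONFIGS[BaseModelType.StableDiffusion1][ModelVariantType.Inpaint]
    # -> 'v1-inpainting-inference.yaml'
    LEGACY_CONFIGS[BaseModelType.StableDiffusion2][ModelVariantType.Inpaint][SchedulerPredictionType.VPrediction]
    # -> 'v2-inpainting-inference-v.yaml' (one of the two config files added by this commit)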
@@ -59,18 +71,11 @@ class ModelInstallList:
     remove_models: List[str] = field(default_factory=list)

 @dataclass
-class UserSelections():
+class InstallSelections():
     install_models: List[str]= field(default_factory=list)
     remove_models: List[str]=field(default_factory=list)
-    install_cn_models: List[str] = field(default_factory=list)
-    remove_cn_models: List[str] = field(default_factory=list)
-    install_lora_models: List[str] = field(default_factory=list)
-    remove_lora_models: List[str] = field(default_factory=list)
-    install_ti_models: List[str] = field(default_factory=list)
-    remove_ti_models: List[str] = field(default_factory=list)
     scan_directory: Path = None
     autoscan_on_startup: bool=False
-    import_model_paths: str=None

 @dataclass
 class ModelLoadInfo():
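`InstallSelections` is the slimmed-down successor of `UserSelections`: the per-category controlnet/lora/ti lists and `import_model_paths` are gone, leaving a single install/remove pair plus the directory-scan options. A minimal construction:

    selections = InstallSelections(
        install_models=['runwayml/stable-diffusion-v1-5'],  # paths, repo IDs or URLs
        remove_models=[],
        scan_directory=None,
        autoscan_on_startup=False,
    )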
@@ -82,18 +87,30 @@ class ModelLoadInfo():
     description: str = ''
     installed: bool = False
     recommended: bool = False
+    default: bool = False

 class ModelInstall(object):
-    def __init__(self,config:InvokeAIAppConfig):
+    def __init__(self,
+                 config:InvokeAIAppConfig,
+                 prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
+                 access_token:str = None):
         self.config = config
         self.mgr = ModelManager(config.model_conf_path)
         self.datasets = OmegaConf.load(Dataset_path)
+        self.prediction_helper = prediction_type_helper
+        self.access_token = access_token or HfFolder.get_token()
+        self.reverse_paths = self._reverse_paths(self.datasets)

     def all_models(self)->Dict[str,ModelLoadInfo]:
         '''
-        Return dict of model_key=>ModelStatus
+        Return dict of model_key=>ModelLoadInfo objects.
+        This method consolidates and simplifies the entries in both
+        models.yaml and INITIAL_MODELS.yaml so that they can
+        be treated uniformly. It also sorts the models alphabetically
+        by their name, to improve the display somewhat.
         '''
         model_dict = dict()

         # first populate with the entries in INITIAL_MODELS.yaml
         for key, value in self.datasets.items():
             name,base,model_type = ModelManager.parse_key(key)
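`ModelInstall` now takes two optional collaborators: a `prediction_type_helper` callback that, given a checkpoint path, reports whether an SD-2 model uses epsilon or v-prediction (the probe cannot always tell from the weights alone), and an HF `access_token` defaulting to the cached `HfFolder` token. A hedged instantiation sketch:

    installer = ModelInstall(
        config,
        prediction_type_helper=lambda path: SchedulerPredictionType.Epsilon,  # placeholder helper
    )
    all_models = installer.all_models()   # Dict[str, ModelLoadInfo]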
@@ -128,102 +145,237 @@ class ModelInstall(object):
             if model_type==ModelType.Pipeline:
                 models.add(key)
         return models

-def default_config_file():
-    return config.model_conf_path
-
-def sd_configs():
-    return config.legacy_conf_path
-
-def initial_models():
-    global Datasets
-    if Datasets:
-        return Datasets
-    return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
-
-def add_models(model_manager, config_file_path: Path, models: List[tuple[str,str,str]]):
-    print(f'Installing {models}')
-
-def del_models(model_manager, config_file_path: Path, models: List[tuple[str,str,str]]):
-    for base, model_type, name in models:
-        logger.info(f"Deleting {name}...")
-        model_manager.del_model(name, base, model_type)
-    model_manager.commit(config_file_path)
-
-def install_requested_models(
-    diffusers: ModelInstallList = None,
-    controlnet: ModelInstallList = None,
-    lora: ModelInstallList = None,
-    ti: ModelInstallList = None,
-    cn_model_map: Dict[str,str] = None, # temporary - move to model manager
-    scan_directory: Path = None,
-    external_models: List[str] = None,
-    scan_at_startup: bool = False,
-    precision: str = "float16",
-    config_file_path: Path = None,
-    model_config_file_callback: Callable[[Path],Path] = None,
-):
-    """
-    Entry point for installing/deleting starter models, or installing external models.
-    """
-    access_token = HfFolder.get_token()
-    config_file_path = config_file_path or default_config_file()
-    if not config_file_path.exists():
-        open(config_file_path, "w")
-
-    # prevent circular import here
-    from ..model_management import ModelManager
-    model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
-
-    for x in [controlnet, lora, ti, diffusers]:
-        if x:
-            add_models(model_manager, config_file_path, x.install_models)
-            del_models(model_manager, config_file_path, x.remove_models)
-
-    # if diffusers:
-    # if diffusers.install_models and len(diffusers.install_models) > 0:
-    #     logger.info("Installing requested models")
-    #     downloaded_paths = download_weight_datasets(
-    #         models=diffusers.install_models,
-    #         access_token=None,
-    #         precision=precision,
-    #     )
-    #     successful = {x:v for x,v in downloaded_paths.items() if v is not None}
-    #     if len(successful) > 0:
-    #         update_config_file(successful, config_file_path)
-    #     if len(successful) < len(diffusers.install_models):
-    #         unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
-    #         logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
-
-    # due to above, we have to reload the model manager because conf file
-    # was changed behind its back
-    model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
-
-    external_models = external_models or list()
-    if scan_directory:
-        external_models.append(str(scan_directory))
-
-    if len(external_models) > 0:
-        logger.info("INSTALLING EXTERNAL MODELS")
-        for path_url_or_repo in external_models:
-            logger.debug(path_url_or_repo)
-            try:
-                model_manager.heuristic_import(
-                    path_url_or_repo,
-                    commit_to_conf=config_file_path,
-                    config_file_callback = model_config_file_callback,
-                )
-            except KeyboardInterrupt:
-                sys.exit(-1)
-            except Exception:
-                pass
-
-    if scan_at_startup and scan_directory.is_dir():
-        update_autoconvert_dir(scan_directory)
-    else:
-        update_autoconvert_dir(None)
+    def recommended_models(self)->Set[str]:
+        starters = self.starter_models()
+        return set([x for x in starters if self.datasets[x].get('recommended',False)])
+
+    def default_model(self)->str:
+        starters = self.starter_models()
+        defaults = [x for x in starters if self.datasets[x].get('default',False)]
+        return defaults[0]
+
+    def install(self, selections: InstallSelections):
+        job = 1
+        jobs = len(selections.remove_models) + len(selections.install_models)
+        if selections.scan_directory:
+            jobs += 1
+
+        # remove requested models
+        for key in selections.remove_models:
+            name,base,mtype = self.mgr.parse_key(key)
+            logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
+            self.mgr.del_model(name,base,mtype)
+            job += 1
+
+        # add requested models
+        for path in selections.install_models:
+            logger.info(f'Installing {path} [{job}/{jobs}]')
+            self.heuristic_install(path)
+            job += 1
+
+        # import from the scan directory, if any
+        if path := selections.scan_directory:
+            logger.info(f'Scanning and importing models from directory {path} [{job}/{jobs}]')
+            self.heuristic_install(path)
+
+        self.mgr.commit()
+
+        if selections.autoscan_on_startup and Path(selections.scan_directory).is_dir():
+            update_autoconvert_dir(selections.scan_directory)
+        else:
+            update_autoconvert_dir(None)
+
+    def heuristic_install(self, model_path_id_or_url: Union[str,Path]):
+        # A little hack to allow nested routines to retrieve info on the requested ID
+        self.current_id = model_path_id_or_url
+
+        path = Path(model_path_id_or_url)
+
+        # checkpoint file, or similar
+        if path.is_file():
+            self._install_path(path)
+            return
+
+        # folders style or similar
+        if path.is_dir() and any([(path/x).exists() for x in ['config.json','model_index.json','learned_embeds.bin']]):
+            self._install_path(path)
+            return
+
+        # recursive scan
+        if path.is_dir():
+            for child in path.iterdir():
+                self.heuristic_install(child)
+            return
+
+        # huggingface repo
+        parts = str(path).split('/')
+        if len(parts) == 2:
+            self._install_repo(str(path))
+            return
+
+        # a URL
+        if model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
+            self._install_url(model_path_id_or_url)
+            return
+
+        logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
+
+    # install a model from a local path. The optional info parameter is there to prevent
+    # the model from being probed twice in the event that it has already been probed.
+    def _install_path(self, path: Path, info: ModelProbeInfo=None):
+        try:
+            info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
+            if info.model_type == ModelType.Pipeline:
+                attributes = self._make_attributes(path,info)
+                self.mgr.add_model(model_name = path.stem if info.format=='checkpoint' else path.name,
+                                   base_model = info.base_type,
+                                   model_type = info.model_type,
+                                   model_attributes = attributes
+                                   )
+        except Exception as e:
+            logger.warning(f'{str(e)} Skipping registration.')
+
+    def _install_url(self, url: str):
+        # copy to a staging area, probe, import and delete
+        with TemporaryDirectory(dir=self.config.models_path) as staging:
+            location = download_with_resume(url,Path(staging))
+            if not location:
+                logger.error(f'Unable to download {url}. Skipping.')
+            info = ModelProbe().heuristic_probe(location)
+            dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
+            models_path = shutil.move(location,dest)
+
+        # staged version will be garbage-collected at this time
+        self._install_path(Path(models_path), info)
+
+    def _get_model_name(self,path_name: str, location: Path)->str:
+        '''
+        Calculate a name for the model - primitive implementation.
+        '''
+        if key := self.reverse_paths.get(path_name):
+            (name, base, mtype) = ModelManager.parse_key(key)
+            return name
+        else:
+            return location.stem
+
+    def _install_repo(self, repo_id: str):
+        hinfo = HfApi().model_info(repo_id)
+
+        # we try to figure out how to download this most economically
+        # list all the files in the repo
+        files = [x.rfilename for x in hinfo.siblings]
+
+        with TemporaryDirectory(dir=self.config.models_path) as staging:
+            staging = Path(staging)
+            if 'model_index.json' in files:
+                location = self._download_hf_pipeline(repo_id, staging)   # pipeline
+            elif 'pytorch_lora_weights.bin' in files:
+                location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging)  # LoRA
+            elif self.config.precision=='float16' and 'diffusion_pytorch_model.fp16.safetensors' in files: # vae, controlnet or some other standalone
+                files = ['config.json', 'diffusion_pytorch_model.fp16.safetensors']
+                location = self._download_hf_model(repo_id, files, staging)
+            elif 'diffusion_pytorch_model.safetensors' in files:
+                files = ['config.json', 'diffusion_pytorch_model.safetensors']
+                location = self._download_hf_model(repo_id, files, staging)
+            elif 'learned_embeds.bin' in files:
+                location = self._download_hf_model(repo_id, ['learned_embeds.bin'], staging)
+
+            info = ModelProbe().heuristic_probe(location, self.prediction_helper)
+            dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
+            if dest.exists():
+                shutil.rmtree(dest)
+            shutil.copytree(location,dest)
+            self._install_path(dest, info)
+
+    def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
+
+        # convoluted way to retrieve the description from datasets
+        description = f'{info.base_type.value} {info.model_type.value} model'
+        if key := self.reverse_paths.get(self.current_id):
+            if key in self.datasets:
+                description = self.datasets[key]['description']
+
+        attributes = dict(
+            path = str(path),
+            description = str(description),
+            format = info.format,
+        )
+        if info.model_type == ModelType.Pipeline:
+            attributes.update(
+                dict(
+                    variant = info.variant_type,
+                )
+            )
+            if info.base_type == BaseModelType.StableDiffusion2:
+                attributes.update(
+                    dict(
+                        prediction_type = info.prediction_type,
+                        upcast_attention = info.prediction_type == SchedulerPredictionType.VPrediction,
+                    )
+                )
+            if info.format=="checkpoint":
+                try:
+                    legacy_conf = LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type] if BaseModelType.StableDiffusion2 \
+                        else LEGACY_CONFIGS[info.base_type][info.variant_type]
+                except KeyError:
+                    legacy_conf = 'v1-inference.yaml' # best guess
+
+                attributes.update(
+                    dict(
+                        config = str(self.config.legacy_conf_path / legacy_conf)
+                    )
+                )
+        return attributes
+
+    def _download_hf_pipeline(self, repo_id: str, staging: Path)->Path:
+        '''
+        This retrieves a StableDiffusion model from cache or remote and then
+        does a save_pretrained() to the indicated staging area.
+        '''
+        _,name = repo_id.split("/")
+        revisions = ['fp16','main'] if self.config.precision=='float16' else ['main']
+        model = None
+        for revision in revisions:
+            try:
+                model = StableDiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
+            except:  # most errors are due to fp16 not being present. Fix this to catch other errors
+                pass
+            if model:
+                break
+        if not model:
+            logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.')
+            return None
+        model.save_pretrained(staging / name, safe_serialization=True)
+        return staging / name
+
+    def _download_hf_model(self, repo_id: str, files: List[str], staging: Path)->Path:
+        _,name = repo_id.split("/")
+        location = staging / name
+        paths = list()
+        for filename in files:
+            p = hf_download_with_resume(repo_id,
+                                        model_dir=location,
+                                        model_name=filename,
+                                        access_token = self.access_token
+                                        )
+            if p:
+                paths.append(p)
+            else:
+                logger.warning(f'Could not download {filename} from {repo_id}.')
+
+        return location if len(paths)>0 else None
+
+    @classmethod
+    def _reverse_paths(cls,datasets)->dict:
+        '''
+        Reverse mapping from repo_id/path to destination name.
+        '''
+        return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}

 def update_autoconvert_dir(autodir: Path):
     '''
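The dispatch order in `heuristic_install` above is: existing file → `_install_path`; directory that looks like a model (config.json / model_index.json / learned_embeds.bin) → `_install_path`; any other directory → recurse over children; a two-part `owner/name` string → `_install_repo`; http/https/ftp → `_install_url`; otherwise warn and skip. For example:

    installer = ModelInstall(config)
    installer.heuristic_install('/tmp/models/my-model.safetensors')      # hypothetical local checkpoint
    installer.heuristic_install('runwayml/stable-diffusion-inpainting')  # HuggingFace repo ID
    installer.heuristic_install('https://example.com/model.ckpt')        # hypothetical URL

One edge case worth flagging (the commit message itself says edge cases are untested): in `_make_attributes`, the ternary tests `if BaseModelType.StableDiffusion2` — an enum member, hence always truthy — rather than `info.base_type == BaseModelType.StableDiffusion2`, so SD-1 checkpoints also take the three-level lookup; indexing the resulting string raises a TypeError that `except KeyError` does not catch.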
@@ -249,89 +401,7 @@ def yes_or_no(prompt: str, default_yes=True):
     return response[0] in ("y", "Y")

 # ---------------------------------------------
-def recommended_datasets() -> List['str']:
-    datasets = set()
-    for ds in initial_models().keys():
-        if initial_models()[ds].get("recommended", False):
-            datasets.add(ds)
-    return list(datasets)
-
-# ---------------------------------------------
-def default_dataset() -> dict:
-    datasets = set()
-    for ds in initial_models().keys():
-        if initial_models()[ds].get("default", False):
-            datasets.add(ds)
-    return list(datasets)
-
-
-# ---------------------------------------------
-def all_datasets() -> dict:
-    datasets = dict()
-    for ds in initial_models().keys():
-        datasets[ds] = True
-    return datasets
-
-
-# ---------------------------------------------
-# look for legacy model.ckpt in models directory and offer to
-# normalize its name
-def migrate_models_ckpt():
-    model_path = os.path.join(config.root_dir, Model_dir, Weights_dir)
-    if not os.path.exists(os.path.join(model_path, "model.ckpt")):
-        return
-    new_name = initial_models()["stable-diffusion-1.4"]["file"]
-    logger.warning(
-        'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
-    )
-    logger.warning(f"model.ckpt => {new_name}")
-    os.replace(
-        os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
-    )
-
-
-# ---------------------------------------------
-def download_weight_datasets(
-    models: List[str], access_token: str, precision: str = "float32"
-):
-    migrate_models_ckpt()
-    successful = dict()
-    for mod in models:
-        logger.info(f"Downloading {mod}:")
-        successful[mod] = _download_repo_or_file(
-            initial_models()[mod], access_token, precision=precision
-        )
-    return successful
-
-
-def _download_repo_or_file(
-    mconfig: DictConfig, access_token: str, precision: str = "float32"
-) -> Path:
-    path = None
-    if mconfig["format"] == "ckpt":
-        path = _download_ckpt_weights(mconfig, access_token)
-    else:
-        path = _download_diffusion_weights(mconfig, access_token, precision=precision)
-        if "vae" in mconfig and "repo_id" in mconfig["vae"]:
-            _download_diffusion_weights(
-                mconfig["vae"], access_token, precision=precision
-            )
-    return path
-
-def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
-    repo_id = mconfig["repo_id"]
-    filename = mconfig["file"]
-    cache_dir = os.path.join(config.root_dir, Model_dir, Weights_dir)
-    return hf_download_with_resume(
-        repo_id=repo_id,
-        model_dir=cache_dir,
-        model_name=filename,
-        access_token=access_token,
-    )
-
-
-# ---------------------------------------------
-def download_from_hf(
+def hf_download_from_pretrained(
     model_class: object, model_name: str, destination: Path, **kwargs
 ):
     logger = InvokeAILogger.getLogger('InvokeAI')
@@ -345,35 +415,6 @@ def download_from_hf(
     model.save_pretrained(destination, safe_serialization=True)
     return destination

-def _download_diffusion_weights(
-    mconfig: DictConfig, access_token: str, precision: str = "float32"
-):
-    repo_id = mconfig["repo_id"]
-    model_class = (
-        StableDiffusionGeneratorPipeline
-        if mconfig.get("format", None) == "diffusers"
-        else AutoencoderKL
-    )
-    extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
-    path = None
-    for extra_args in extra_arg_list:
-        try:
-            path = download_from_hf(
-                model_class,
-                repo_id,
-                safety_checker=None,
-                **extra_args,
-            )
-        except OSError as e:
-            if 'Revision Not Found' in str(e):
-                pass
-            else:
-                logger.error(str(e))
-        if path:
-            break
-    return path
-
-
 # ---------------------------------------------
 def hf_download_with_resume(
     repo_id: str,
@@ -432,128 +473,3 @@ def hf_download_with_resume(
     return model_dest

-
-# ---------------------------------------------
-def update_config_file(successfully_downloaded: dict, config_file: Path):
-    config_file = (
-        Path(config_file) if config_file is not None else default_config_file()
-    )
-
-    # In some cases (incomplete setup, etc), the default configs directory might be missing.
-    # Create it if it doesn't exist.
-    # this check is ignored if opt.config_file is specified - user is assumed to know what they
-    # are doing if they are passing a custom config file from elsewhere.
-    if config_file is default_config_file() and not config_file.parent.exists():
-        configs_src = Dataset_path.parent
-        configs_dest = default_config_file().parent
-        shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
-
-    yaml = new_config_file_contents(successfully_downloaded, config_file)
-
-    try:
-        backup = None
-        if os.path.exists(config_file):
-            logger.warning(
-                f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
-            )
-            backup = config_file.with_suffix(".yaml.orig")
-            ## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
-            if sys.platform == "win32" and backup.is_file():
-                backup.unlink()
-            config_file.rename(backup)
-
-        with TemporaryFile() as tmp:
-            tmp.write(Config_preamble.encode())
-            tmp.write(yaml.encode())
-
-            with open(str(config_file.expanduser().resolve()), "wb") as new_config:
-                tmp.seek(0)
-                new_config.write(tmp.read())
-
-    except Exception as e:
-        logger.error(f"Error creating config file {config_file}: {str(e)}")
-        if backup is not None:
-            logger.info("restoring previous config file")
-            ## workaround, for WinError 183, see above
-            if sys.platform == "win32" and config_file.is_file():
-                config_file.unlink()
-            backup.rename(config_file)
-        return
-
-    logger.info(f"Successfully created new configuration file {config_file}")
-
-
-# ---------------------------------------------
-def new_config_file_contents(
-    successfully_downloaded: dict,
-    config_file: Path,
-) -> str:
-    if config_file.exists():
-        conf = OmegaConf.load(str(config_file.expanduser().resolve()))
-    else:
-        conf = OmegaConf.create()
-
-    default_selected = None
-    for model in successfully_downloaded:
-        # a bit hacky - what we are doing here is seeing whether a checkpoint
-        # version of the model was previously defined, and whether the current
-        # model is a diffusers (indicated with a path)
-        if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
-            delete_weights(model, conf[model])
-
-        stanza = {}
-        mod = initial_models()[model]
-        stanza["description"] = mod["description"]
-        stanza["repo_id"] = mod["repo_id"]
-        stanza["format"] = mod["format"]
-        # diffusers don't need width and height (probably .ckpt doesn't either)
-        # so we no longer require these in INITIAL_MODELS.yaml
-        if "width" in mod:
-            stanza["width"] = mod["width"]
-        if "height" in mod:
-            stanza["height"] = mod["height"]
-        if "file" in mod:
-            stanza["weights"] = os.path.relpath(
-                successfully_downloaded[model], start=config.root_dir
-            )
-            stanza["config"] = os.path.normpath(
-                os.path.join(sd_configs(), mod["config"])
-            )
-        if "vae" in mod:
-            if "file" in mod["vae"]:
-                stanza["vae"] = os.path.normpath(
-                    os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
-                )
-            else:
-                stanza["vae"] = mod["vae"]
-        if mod.get("default", False):
-            stanza["default"] = True
-            default_selected = True
-
-        conf[model] = stanza
-
-    # if no default model was chosen, then we select the first
-    # one in the list
-    if not default_selected:
-        conf[list(successfully_downloaded.keys())[0]]["default"] = True
-
-    return OmegaConf.to_yaml(conf)
-
-
-# ---------------------------------------------
-def delete_weights(model_name: str, conf_stanza: dict):
-    if not (weights := conf_stanza.get("weights")):
-        return
-    if re.match("/VAE/", conf_stanza.get("config")):
-        return
-
-    logger.warning(
-        f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
-    )
-
-    weights = Path(weights)
-    if not weights.is_absolute():
-        weights = config.root_dir / weights
-    try:
-        weights.unlink()
-    except OSError as e:
-        logger.error(str(e))
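`hf_download_with_resume` survives the cleanup and is the primitive the new `_download_hf_model` builds on (per the call site above: `repo_id`, `model_dir`, `model_name`, `access_token`). An illustrative call; the repo and filename here are examples only:

    dest = hf_download_with_resume(
        repo_id='runwayml/stable-diffusion-v1-5',
        model_dir=Path('/tmp/staging/stable-diffusion-v1-5'),
        model_name='v1-5-pruned-emaonly.safetensors',
        access_token=HfFolder.get_token(),
    )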
@@ -4,3 +4,4 @@ Initialization file for invokeai.backend.model_management
 from .model_manager import ModelManager, ModelInfo
 from .model_cache import ModelCache
 from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
+
@@ -682,7 +682,7 @@ class ModelManager(object):

         for model_key, model_config in list(self.models.items()):
             model_name, base_model, model_type = self.parse_key(model_key)
-            model_path = str(self.globals.root / model_config.path)
+            model_path = str(self.globals.root_path / model_config.path)
             if not os.path.exists(model_path):
                 model_class = MODEL_CLASSES[base_model][model_type]
                 if model_class.save_to_config:
@@ -703,13 +703,14 @@ class ModelManager(object):
             for entry_name in os.listdir(models_dir):
                 model_path = os.path.join(models_dir, entry_name)
                 if model_path not in loaded_files: # TODO: check
-                    model_name = Path(model_path).stem
+                    model_path = Path(model_path)
+                    model_name = model_path.name if model_path.is_dir else model_path.stem
                     model_key = self.create_key(model_name, base_model, model_type)

                     if model_key in self.models:
                         raise Exception(f"Model with key {model_key} added twice")

-                    model_config: ModelConfigBase = model_class.probe_config(model_path)
+                    model_config: ModelConfigBase = model_class.probe_config(str(model_path))
                     self.models[model_key] = model_config
                     new_models_found = True
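Caution on the new naming line above: `model_path.is_dir` lacks the call parentheses, so it evaluates the bound method object, which is always truthy; plain files therefore get `model_path.name` (extension included) rather than `model_path.stem`. The presumably intended test is the call:

    model_path = Path(model_path)
    model_name = model_path.name if model_path.is_dir() else model_path.stem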
@@ -15,13 +15,13 @@ import invokeai.backend.util.logging as logger
 from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings

 @dataclass
-class ModelVariantInfo(object):
+class ModelProbeInfo(object):
     model_type: ModelType
     base_type: BaseModelType
     variant_type: ModelVariantType
     prediction_type: SchedulerPredictionType
     upcast_attention: bool
-    format: Literal['folder','checkpoint']
+    format: Literal['diffusers','checkpoint']
     image_size: int

 class ProbeBase(object):
@@ -31,7 +31,7 @@ class ProbeBase(object):
 class ModelProbe(object):

     PROBES = {
-        'folder': { },
+        'diffusers': { },
         'checkpoint': { },
     }
@@ -43,7 +43,7 @@ class ModelProbe(object):

     @classmethod
     def register_probe(cls,
-                       format: Literal['folder','file'],
+                       format: Literal['diffusers','checkpoint'],
                        model_type: ModelType,
                        probe_class: ProbeBase):
         cls.PROBES[format][model_type] = probe_class
@@ -51,8 +51,8 @@ class ModelProbe(object):
     @classmethod
     def heuristic_probe(cls,
                         model: Union[Dict, ModelMixin, Path],
-                        prediction_type_helper: Callable[[Path],BaseModelType]=None,
-                        )->ModelVariantInfo:
+                        prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
+                        )->ModelProbeInfo:
         if isinstance(model,Path):
             return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
         elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
@@ -64,7 +64,7 @@ class ModelProbe(object):
     def probe(cls,
               model_path: Path,
               model: Union[Dict, ModelMixin] = None,
-              prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
+              prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo:
         '''
         Probe the model at model_path and return sufficient information about it
         to place it somewhere in the models directory hierarchy. If the model is
@@ -74,14 +74,14 @@ class ModelProbe(object):
         between V2-Base and V2-768 SD models.
         '''
         if model_path:
-            format = 'folder' if model_path.is_dir() else 'checkpoint'
+            format = 'diffusers' if model_path.is_dir() else 'checkpoint'
         else:
-            format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
+            format = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'

         model_info = None
         try:
             model_type = cls.get_model_type_from_folder(model_path, model) \
-                         if format == 'folder' \
+                         if format == 'diffusers' \
                          else cls.get_model_type_from_checkpoint(model_path, model)
             probe_class = cls.PROBES[format].get(model_type)
             if not probe_class:
|
|||||||
base_type = probe.get_base_type()
|
base_type = probe.get_base_type()
|
||||||
variant_type = probe.get_variant_type()
|
variant_type = probe.get_variant_type()
|
||||||
prediction_type = probe.get_scheduler_prediction_type()
|
prediction_type = probe.get_scheduler_prediction_type()
|
||||||
model_info = ModelVariantInfo(
|
model_info = ModelProbeInfo(
|
||||||
model_type = model_type,
|
model_type = model_type,
|
||||||
base_type = base_type,
|
base_type = base_type,
|
||||||
variant_type = variant_type,
|
variant_type = variant_type,
|
||||||
@@ -196,7 +196,7 @@ class CheckpointProbeBase(ProbeBase):
     def __init__(self,
                  checkpoint_path: Path,
                  checkpoint: dict,
-                 helper: Callable[[Path],BaseModelType] = None
+                 helper: Callable[[Path],SchedulerPredictionType] = None
                 )->BaseModelType:
         self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
         self.checkpoint_path = checkpoint_path
@@ -405,11 +405,11 @@ class LoRAFolderProbe(FolderProbeBase):
     pass

 ############## register probe classes ######
-ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe)
-ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe)
-ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe)
-ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe)
-ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe)
+ModelProbe.register_probe('diffusers', ModelType.Pipeline, PipelineFolderProbe)
+ModelProbe.register_probe('diffusers', ModelType.Vae, VaeFolderProbe)
+ModelProbe.register_probe('diffusers', ModelType.Lora, LoRAFolderProbe)
+ModelProbe.register_probe('diffusers', ModelType.TextualInversion, TextualInversionFolderProbe)
+ModelProbe.register_probe('diffusers', ModelType.ControlNet, ControlNetFolderProbe)
 ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe)
 ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
 ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
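With the registry key renamed from 'folder' to 'diffusers', probing routes as format ('diffusers' for a directory or loaded model object, 'checkpoint' for a file) → model type → registered probe class. A usage sketch with a hypothetical path:

    info = ModelProbe().heuristic_probe(Path('/tmp/models/dreamlike-photoreal-2'))
    # info is a ModelProbeInfo: base_type, model_type, variant_type, prediction_type, format, ...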
@@ -154,7 +154,6 @@ class ModelBase(metaclass=ABCMeta):
     def create_config(cls, **kwargs) -> ModelConfigBase:
         if "format" not in kwargs:
             raise Exception("Field 'format' not found in model config")
-
         configs = cls._get_configs()
         return configs[kwargs["format"]](**kwargs)
@@ -3,6 +3,7 @@ sd-1/pipeline/stable-diffusion-v1-5:
   description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
   repo_id: runwayml/stable-diffusion-v1-5
   recommended: True
+  default: True
 sd-1/pipeline/stable-diffusion-inpainting:
   description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
   repo_id: runwayml/stable-diffusion-inpainting
@@ -27,7 +28,7 @@ sd-1/pipeline/Dungeons-and-Diffusion:
   description: Dungeons & Dragons characters (2.13 GB)
   repo_id: 0xJustin/Dungeons-and-Diffusion
   recommended: False
-sd-1/pipeline/dreamlike-photoreal-2.0:
+sd-1/pipeline/dreamlike-photoreal-2:
   description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
   repo_id: dreamlike-art/dreamlike-photoreal-2.0
   recommended: False
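INITIAL_MODELS.yaml keys encode base, type and name, e.g. 'sd-1/pipeline/stable-diffusion-v1-5'; the unpacking `name,base,model_type = ModelManager.parse_key(key)` seen earlier in this diff suggests a parse along these lines (the split itself is an assumption):

    base, model_type, name = 'sd-1/pipeline/stable-diffusion-v1-5'.split('/')
    # base='sd-1', model_type='pipeline', name='stable-diffusion-v1-5'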
159
invokeai/configs/stable-diffusion/v2-inpainting-inference-v.yaml
Normal file
159
invokeai/configs/stable-diffusion/v2-inpainting-inference-v.yaml
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-05
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
parameterization: "v"
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: hybrid
|
||||||
|
scale_factor: 0.18215
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
finetune_keys: null
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 9
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
layer: "penultimate"
|
||||||
|
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: ldm.data.laion.WebDataModuleFromConfig
|
||||||
|
params:
|
||||||
|
tar_base: null # for concat as in LAION-A
|
||||||
|
p_unsafe_threshold: 0.1
|
||||||
|
filter_word_list: "data/filters.yaml"
|
||||||
|
max_pwatermark: 0.45
|
||||||
|
batch_size: 8
|
||||||
|
num_workers: 6
|
||||||
|
multinode: True
|
||||||
|
min_size: 512
|
||||||
|
train:
|
||||||
|
shards:
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
|
||||||
|
shuffle: 10000
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.RandomCrop
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
postprocess:
|
||||||
|
target: ldm.data.laion.AddMask
|
||||||
|
params:
|
||||||
|
mode: "512train-large"
|
||||||
|
p_drop: 0.25
|
||||||
|
# NOTE use enough shards to avoid empty validation loops in workers
|
||||||
|
validation:
|
||||||
|
shards:
|
||||||
|
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
|
||||||
|
shuffle: 0
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.CenterCrop
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
postprocess:
|
||||||
|
target: ldm.data.laion.AddMask
|
||||||
|
params:
|
||||||
|
mode: "512train-large"
|
||||||
|
p_drop: 0.25
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
find_unused_parameters: True
|
||||||
|
modelcheckpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 5000
|
||||||
|
|
||||||
|
callbacks:
|
||||||
|
metrics_over_trainsteps_checkpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 10000
|
||||||
|
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
enable_autocast: False
|
||||||
|
disabled: False
|
||||||
|
batch_frequency: 1000
|
||||||
|
max_images: 4
|
||||||
|
increase_log_steps: False
|
||||||
|
log_first_step: False
|
||||||
|
log_images_kwargs:
|
||||||
|
use_ema_scope: False
|
||||||
|
inpaint: False
|
||||||
|
plot_progressive_rows: False
|
||||||
|
plot_diffusion_rows: False
|
||||||
|
N: 4
|
||||||
|
unconditional_guidance_scale: 5.0
|
||||||
|
unconditional_guidance_label: [""]
|
||||||
|
ddim_steps: 50 # todo check these out for depth2img,
|
||||||
|
ddim_eta: 0.0 # todo check these out for depth2img,
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
val_check_interval: 5000000
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
accumulate_grad_batches: 1
|
158  invokeai/configs/stable-diffusion/v2-inpainting-inference.yaml  Normal file
@ -0,0 +1,158 @@
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: hybrid
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    finetune_keys: null
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        image_size: 32  # unused
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64  # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: null  # for concat as in LAION-A
    p_unsafe_threshold: 0.1
    filter_word_list: "data/filters.yaml"
    max_pwatermark: 0.45
    batch_size: 8
    num_workers: 6
    multinode: True
    min_size: 512
    train:
      shards:
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
      shuffle: 10000
      image_key: jpg
      image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25
    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards:
        - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
      shuffle: 0
      image_key: jpg
      image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.CenterCrop
          params:
            size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25


lightning:
  find_unused_parameters: True

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 10000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: False
        disabled: False
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 5.0
          unconditional_guidance_label: [""]
          ddim_steps: 50  # todo check these out for depth2img,
          ddim_eta: 0.0  # todo check these out for depth2img,

  trainer:
    benchmark: True
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
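The stanza layout above follows the LDM convention: every `target:` is a dotted import path that gets instantiated with the neighboring `params:`. As a quick orientation aid (not part of this commit), the file can be loaded and the model built like so; `instantiate_from_config` is the helper from `ldm.util`:

    # A minimal sketch, assuming the stable-diffusion 'ldm' package is importable.
    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    cfg = OmegaConf.load('invokeai/configs/stable-diffusion/v2-inpainting-inference.yaml')
    model = instantiate_from_config(cfg.model)   # builds LatentInpaintDiffusion
    # Note in_channels: 9 in unet_config: 4 noisy-latent channels plus 4
    # masked-image latent channels plus 1 mask channel, which is what makes
    # this checkpoint an inpainting model rather than a plain txt2img one.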
@ -11,7 +11,6 @@ The work is actually done in backend code in model_install_backend.py.
 import argparse
 import curses
-import os
 import sys
 import textwrap
 import traceback
@ -20,27 +19,21 @@ from multiprocessing import Process
 from multiprocessing.connection import Connection, Pipe
 from pathlib import Path
 from shutil import get_terminal_size
-from typing import List

 import logging
 import npyscreen
 import torch
 from npyscreen import widget
-from omegaconf import OmegaConf

 from invokeai.backend.util.logging import InvokeAILogger

 from invokeai.backend.install.model_install_backend import (
-    Dataset_path, # most of these should go!!
-    default_config_file,
-    default_dataset,
-    install_requested_models,
-    recommended_datasets,
     ModelInstallList,
-    UserSelections,
-    ModelInstall
+    InstallSelections,
+    ModelInstall,
+    SchedulerPredictionType,
 )
-from invokeai.backend.model_management import ModelManager, BaseModelType, ModelType
+from invokeai.backend.model_management import ModelManager, ModelType
 from invokeai.backend.util import choose_precision, choose_torch_device
 from invokeai.frontend.install.widgets import (
     CenteredTitleText,
@ -133,7 +126,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         bottom_of_table = self.nextrely

         self.nextrely = top_of_table
-        self.pipeline_models = self.add_model_widgets(
+        self.pipeline_models = self.add_pipeline_widgets(
             model_type=ModelType.Pipeline,
             window_width=window_width,
             exclude = self.starter_models
@ -210,11 +203,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         starters = self.starter_models
         starter_model_labels = self.model_labels
-        recommended_models = set([
-            x
-            for x in starters
-            if models[x].recommended
-        ])
         self.installed_models = sorted(
             [x for x in starters if models[x].installed]
         )
@ -312,16 +300,18 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         return widgets

     ### Tab for arbitrary diffusers widgets ###
-    def add_diffusers_widgets(self,
+    def add_pipeline_widgets(self,
                               model_type: ModelType=ModelType.Pipeline,
                               window_width: int=120,
+                              **kwargs,
                               )->dict[str,npyscreen.widget]:
         '''Similar to add_model_widgets() but adds some additional widgets at the bottom
         to support the autoload directory'''
         widgets = self.add_model_widgets(
             model_type = model_type,
             window_width = window_width,
             install_prompt=f"Additional {model_type.value.title()} models already installed.",
+            **kwargs,
         )

         label = "Directory to scan for models to automatically import (<tab> autocompletes):"
@ -428,7 +418,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
             target = process_and_execute,
             kwargs=dict(
                 opt = app.program_opts,
-                selections = app.user_selections,
+                selections = app.install_selections,
                 conn_out = child_conn,
             )
         )
@ -436,8 +426,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         child_conn.close()
         self.subprocess_connection = parent_conn
         self.subprocess = p
-        app.user_selections = UserSelections()
-        # process_and_execute(app.opt, app.user_selections)
+        app.install_selections = InstallSelections()
+        # process_and_execute(app.opt, app.install_selections)

     def on_back(self):
         self.parentApp.switchFormPrevious()
@ -453,7 +443,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         self.parentApp.setNextForm(None)
         self.parentApp.user_cancelled = False
         self.editing = False

     ########## This routine monitors the child process that is performing model installation and removal #####
     def while_waiting(self):
         '''Called during idle periods. Main task is to update the Log Messages box with messages
@ -532,73 +522,24 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         .autoscan_on_startup: True if invokeai should scan and import at startup time
         .import_model_paths: list of URLs, repo_ids and file paths to import
         """
-        # we're using a global here rather than storing the result in the parentapp
-        # due to some bug in npyscreen that is causing attributes to be lost
-        selections = self.parentApp.user_selections
+        selections = self.parentApp.install_selections
+        all_models = self.all_models

-        # Starter models to install/remove
-        # TO DO - turn these into a dict so we don't have to hard-code the attributes
-        print(f'installed={[x for x in self.all_models if self.all_models[x].installed]}',file=f)
-        for section in [self.starter_pipelines, self.pipeline_models,
-                        self.controlnet_models, self.lora_models, self.ti_models]:
+        # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove
+        ui_sections = [self.starter_pipelines, self.pipeline_models,
+                       self.controlnet_models, self.lora_models, self.ti_models]
+        for section in ui_sections:
             selected = set([section['models'][x] for x in section['models_selected'].value])
             models_to_install = [x for x in selected if not self.all_models[x].installed]
             models_to_remove = [x for x in section['models'] if x not in selected and self.all_models[x].installed]
+            selections.remove_models.extend(models_to_remove)
+            selections.install_models.extend(all_models[x].path or all_models[x].repo_id \
+                                             for x in models_to_install if all_models[x].path or all_models[x].repo_id)

-        # "More" models
-        selections.import_model_paths = self.pipeline_models['download_ids'].value.split()
-        if diffusers_selected := self.pipeline_models.get('models_selected'):
-            selections.remove_models.extend([x
-                                             for x in diffusers_selected.values
-                                             if self.installed_pipeline_models[x]
-                                             and diffusers_selected.values.index(x) not in diffusers_selected.value
-                                             ]
-                                            )
-
-        # TODO: REFACTOR THIS REPETITIVE CODE
-        if cn_models_selected := self.controlnet_models.get('models_selected'):
-            selections.install_cn_models = [cn_models_selected.values[x]
-                                            for x in cn_models_selected.value
-                                            if not self.installed_cn_models[cn_models_selected.values[x]]
-                                            ]
-            selections.remove_cn_models = [x
-                                           for x in cn_models_selected.values
-                                           if self.installed_cn_models[x]
-                                           and cn_models_selected.values.index(x) not in cn_models_selected.value
-                                           ]
-        if (additional_cns := self.controlnet_models['download_ids'].value.split()):
-            valid_cns = [x for x in additional_cns if '/' in x]
-            selections.install_cn_models.extend(valid_cns)
-
-        # same thing, for LoRAs
-        if loras_selected := self.lora_models.get('models_selected'):
-            selections.install_lora_models = [loras_selected.values[x]
-                                              for x in loras_selected.value
-                                              if not self.installed_lora_models[loras_selected.values[x]]
-                                              ]
-            selections.remove_lora_models = [x
-                                             for x in loras_selected.values
-                                             if self.installed_lora_models[x]
-                                             and loras_selected.values.index(x) not in loras_selected.value
-                                             ]
-        if (additional_loras := self.lora_models['download_ids'].value.split()):
-            selections.install_lora_models.extend(additional_loras)
-
-        # same thing, for TIs
-        # TODO: refactor
-        if tis_selected := self.ti_models.get('models_selected'):
-            selections.install_ti_models = [tis_selected.values[x]
-                                            for x in tis_selected.value
-                                            if not self.installed_ti_models[tis_selected.values[x]]
-                                            ]
-            selections.remove_ti_models = [x
-                                           for x in tis_selected.values
-                                           if self.installed_ti_models[x]
-                                           and tis_selected.values.index(x) not in tis_selected.value
-                                           ]
-
-        if (additional_tis := self.ti_models['download_ids'].value.split()):
-            selections.install_ti_models.extend(additional_tis)
+        # models located in the 'download_ids' section
+        for section in ui_sections:
+            if downloads := section.get('download_ids'):
+                selections.install_models.extend(downloads.value.split())

         # load directory and whether to scan on startup
         selections.scan_directory = self.pipeline_models['autoload_directory'].value
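Note how the rewritten marshalling code above no longer keeps per-type lists for controlnets, LoRAs and TIs; everything funnels into the parentApp's single InstallSelections. The dataclass itself is defined in model_install_backend.py and is not part of this hunk; a minimal sketch of the shape the code above relies on (field names taken from the calls here, defaults assumed) is:

    # A minimal sketch, assuming defaults; the real definition lives in
    # invokeai.backend.install.model_install_backend.
    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class InstallSelections:
        install_models: List[str] = field(default_factory=list)  # paths, URLs or repo_ids
        remove_models: List[str] = field(default_factory=list)   # names of installed models
        scan_directory: Optional[str] = None    # autoload directory chosen in the TUI
        autoscan_on_startup: bool = False       # rescan that directory at each startup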
@ -609,7 +550,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
         super().__init__()
         self.program_opts = opt
         self.user_cancelled = False
-        self.user_selections = UserSelections()
+        self.install_selections = InstallSelections()

     def onStart(self):
         npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@ -628,21 +569,17 @@ class StderrToMessage():
         pass

 # --------------------------------------------------------
-def ask_user_for_config_file(model_path: Path,
+def ask_user_for_prediction_type(model_path: Path,
                              tui_conn: Connection=None
                              )->Path:
     if tui_conn:
         logger.debug('Waiting for user response...')
-        return _ask_user_for_cf_tui(model_path, tui_conn)
+        return _ask_user_for_pt_tui(model_path, tui_conn)
     else:
-        return _ask_user_for_cf_cmdline(model_path)
+        return _ask_user_for_pt_cmdline(model_path)

-def _ask_user_for_cf_cmdline(model_path):
-    choices = [
-        config.legacy_conf_path / x
-        for x in ['v2-inference.yaml','v2-inference-v.yaml']
-    ]
-    choices.extend([None])
+def _ask_user_for_pt_cmdline(model_path):
+    choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
     print(
 f"""
 Please select the type of the V2 checkpoint named {model_path.name}:
@ -664,7 +601,7 @@ Please select the type of the V2 checkpoint named {model_path.name}:
             return
     return choice

-def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path:
+def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->Path:
     try:
         tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
         # note that we don't do any status checking here
@ -672,20 +609,20 @@ def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path:
         if response is None:
             return None
         elif response == 'epsilon':
-            return config.legacy_conf_path / 'v2-inference.yaml'
+            return SchedulerPredictionType.Epsilon
         elif response == 'v':
-            return config.legacy_conf_path / 'v2-inference-v.yaml'
+            return SchedulerPredictionType.VPrediction
         elif response == 'abort':
             logger.info('Conversion aborted')
             return None
         else:
-            return Path(response)
+            return response
     except:
         return None

 # --------------------------------------------------------
 def process_and_execute(opt: Namespace,
-                        selections: UserSelections,
+                        selections: InstallSelections,
                         conn_out: Connection=None,
                         ):
     # set up so that stderr is sent to conn_out
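The two `_ask_user_for_pt_*` helpers above implement the same question over different channels. In the TUI case the installer subprocess and the form talk over a multiprocessing Connection with a tiny string protocol; condensed, the exchange implemented above looks like this (the reply strings come from this diff; the receiving call falls outside the hunk and is an assumption):

    # A condensed sketch of the subprocess <-> TUI handshake.
    tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
    response = tui_conn.recv_bytes().decode('utf-8')   # assumed: 'epsilon' | 'v' | 'abort'
    reply_map = {
        'epsilon': SchedulerPredictionType.Epsilon,
        'v': SchedulerPredictionType.VPrediction,
        'abort': None,
    }
    prediction_type = reply_map.get(response)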
@ -696,34 +633,14 @@ def process_and_execute(opt: Namespace,
         logger = InvokeAILogger.getLogger()
         logger.handlers.clear()
         logger.addHandler(logging.StreamHandler(translator))

-    models_to_install = selections.install_models
-    models_to_remove = selections.remove_models
-    directory_to_scan = selections.scan_directory
-    scan_at_startup = selections.autoscan_on_startup
-    potential_models_to_install = selections.import_model_paths
-    name_map = selections.model_name_map
-
-    install_requested_models(
-        diffusers = ModelInstallList(models_to_install, [name_map[ModelType.Pipeline][x] for x in models_to_remove]),
-        controlnet = ModelInstallList(selections.install_cn_models, [name_map[ModelType.ControlNet][x] for x in selections.remove_cn_models]),
-        lora = ModelInstallList(selections.install_lora_models, [name_map[ModelType.Lora][x] for x in selections.remove_lora_models]),
-        ti = ModelInstallList(selections.install_ti_models, [name_map[ModelType.TextualInversion][x] for x in selections.remove_ti_models]),
-        scan_directory=Path(directory_to_scan) if directory_to_scan else None,
-        external_models=potential_models_to_install,
-        scan_at_startup=scan_at_startup,
-        precision="float32"
-        if opt.full_precision
-        else choose_precision(torch.device(choose_torch_device())),
-        config_file_path=Path(opt.config_file) if opt.config_file else config.model_conf_path,
-        model_config_file_callback = lambda x: ask_user_for_config_file(x,conn_out)
-    )
+    installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x,conn_out))
+    installer.install(selections)

     if conn_out:
         conn_out.send_bytes('*done*'.encode('utf-8'))
         conn_out.close()


 def do_listings(opt)->bool:
     """List installed models of various sorts, and return
     True if any were requested."""
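With the backend doing the heavy lifting, process_and_execute reduces to constructing a ModelInstall and handing it the selections. The same two-call pattern works outside the TUI as well; a minimal sketch (constructor and install() usage as shown in this diff, the repo_id is a placeholder):

    # A minimal sketch of driving the new installer API directly.
    installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
    selections = InstallSelections(
        install_models=['runwayml/stable-diffusion-v1-5'],  # placeholder repo_id
        remove_models=[],
    )
    installer.install(selections)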
@ -754,38 +671,34 @@ def select_and_download_models(opt: Namespace):
         if opt.full_precision
         else choose_precision(torch.device(choose_torch_device()))
     )
+    config.precision = precision
+    helper = lambda x: ask_user_for_prediction_type(x)
+    # if do_listings(opt):
+    #    pass

-    if do_listings(opt):
-        pass
-    # this processes command line additions/removals
-    elif opt.diffusers or opt.controlnets or opt.textual_inversions or opt.loras:
-        action = 'remove_models' if opt.delete else 'install_models'
-        diffusers_args = {'diffusers':ModelInstallList(remove_models=opt.diffusers or [])} \
-            if opt.delete \
-            else {'external_models':opt.diffusers or []}
-        install_requested_models(
-            **diffusers_args,
-            controlnet=ModelInstallList(**{action:opt.controlnets or []}),
-            ti=ModelInstallList(**{action:opt.textual_inversions or []}),
-            lora=ModelInstallList(**{action:opt.loras or []}),
-            precision=precision,
-            model_config_file_callback=lambda x: ask_user_for_config_file(x),
+    installer = ModelInstall(config, prediction_type_helper=helper)
+    if opt.add or opt.delete:
+        selections = InstallSelections(
+            install_models = opt.add or [],
+            remove_models = opt.delete or []
         )
+        installer.install(selections)
     elif opt.default_only:
-        install_requested_models(
-            diffusers=ModelInstallList(install_models=default_dataset()),
-            precision=precision,
+        selections = InstallSelections(
+            install_models = installer.default_model()
         )
+        installer.install(selections)
     elif opt.yes_to_all:
-        install_requested_models(
-            diffusers=ModelInstallList(install_models=recommended_datasets()),
-            precision=precision,
+        selections = InstallSelections(
+            install_models = installer.recommended_models()
         )
+        installer.install(selections)

     # this is where the TUI is called
     else:
         # needed because the torch library is loaded, even though we don't use it
-        torch.multiprocessing.set_start_method("spawn")
+        # currently commented out because it has started generating errors (?)
+        # torch.multiprocessing.set_start_method("spawn")

         # the third argument is needed in the Windows 11 environment in
         # order to launch and resize a console window running this program
@ -801,35 +714,20 @@ def select_and_download_models(opt: Namespace):
             installApp.main_form.subprocess.terminate()
             installApp.main_form.subprocess = None
         raise e
-    process_and_execute(opt, installApp.user_selections)
+    process_and_execute(opt, installApp.install_selections)

 # -------------------------------------
 def main():
     parser = argparse.ArgumentParser(description="InvokeAI model downloader")
     parser.add_argument(
-        "--diffusers",
+        "--add",
         nargs="*",
-        help="List of URLs or repo_ids of diffusers to install/delete",
-    )
-    parser.add_argument(
-        "--loras",
-        nargs="*",
-        help="List of URLs or repo_ids of LoRA/LyCORIS models to install/delete",
-    )
-    parser.add_argument(
-        "--controlnets",
-        nargs="*",
-        help="List of URLs or repo_ids of controlnet models to install/delete",
-    )
-    parser.add_argument(
-        "--textual-inversions",
-        nargs="*",
-        help="List of URLs or repo_ids of textual inversion embeddings to install/delete",
+        help="List of URLs, local paths or repo_ids of models to install",
     )
     parser.add_argument(
         "--delete",
-        action="store_true",
-        help="Delete models listed on command line rather than installing them",
+        nargs="*",
+        help="List of names of models to delete",
     )
     parser.add_argument(
         "--full-precision",
@ -849,7 +747,7 @@ def main():
     parser.add_argument(
         "--default_only",
         action="store_true",
-        help="only install the default model",
+        help="Only install the default model",
     )
     parser.add_argument(
         "--list-models",
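The net effect on the command line is that four model-type-specific flags collapse into a symmetric --add/--delete pair that accepts any model kind; assuming the console script is installed under its usual entry-point name, an invocation now looks like `invokeai-model-install --add <url-or-repo_id> --delete <model-name>`.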
@ -17,8 +17,8 @@ from shutil import get_terminal_size
 from curses import BUTTON2_CLICKED,BUTTON3_CLICKED

 # minimum size for UIs
-MIN_COLS = 120
-MIN_LINES = 50
+MIN_COLS = 180
+MIN_LINES = 55

 # -------------------------------------
 def set_terminal_size(columns: int, lines: int, launch_command: str=None):
@ -384,7 +384,6 @@ def select_stable_diffusion_config_file(
         "An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)",
         "An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)",
         "Skip installation for now and come back later",
-        "Enter config file path manually",
     ]

     F = ConfirmCancelPopup(
@ -406,35 +405,17 @@ def select_stable_diffusion_config_file(
     mlw.values = message

     choice = F.add(
-        SingleSelectWithChanged,
+        npyscreen.SelectOne,
         values = options,
         value = [0],
         max_height = len(options)+1,
         scroll_exit=True,
     )
-    file = F.add(
-        FileBox,
-        name='Path to config file',
-        max_height=3,
-        hidden=True,
-        must_exist=True,
-        scroll_exit=True
-    )
-
-    def toggle_visible(value):
-        value = value[0]
-        if value==3:
-            file.hidden=False
-        else:
-            file.hidden=True
-        F.display()
-
-    choice.on_changed = toggle_visible
-
     F.editw = 1
     F.edit()
     if not F.value:
         return None
-    assert choice.value[0] in range(0,4),'invalid choice'
-    choices = ['epsilon','v','abort',file.value]
+    assert choice.value[0] in range(0,3),'invalid choice'
+    choices = ['epsilon','v','abort']
     return choices[choice.value[0]]
@ -26,7 +26,7 @@ from transformers import (
 import invokeai.backend.util.logging as logger
 from invokeai.backend.model_management import ModelManager
 from invokeai.backend.model_management.model_probe import (
-    ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelVariantInfo
+    ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelProbeInfo
 )

 warnings.filterwarnings("ignore")
@ -171,13 +171,13 @@ def migrate_tuning_models(dest: Path):
         logger.info(f'Scanning {subdir}')
         migrate_models(src, dest)

-def write_yaml(model_name: str, path:Path, info:ModelVariantInfo, dest_yaml: io.TextIOBase):
+def write_yaml(model_name: str, path:Path, info:ModelProbeInfo, dest_yaml: io.TextIOBase):
     name = unique_name(model_name, info)
     stanza = {
         f'{info.base_type.value}/{info.model_type.value}/{name}': {
             'name': model_name,
             'path': str(path),
-            'description': f'diffusers model {model_name}',
+            'description': f'A {info.base_type.value} {info.model_type.value} model',
             'format': 'diffusers',
             'image_size': info.image_size,
             'base': info.base_type.value,
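For a concrete sense of the output, this is roughly the dict write_yaml assembles for a hypothetical SD-1 diffusers model (all values illustrative, not taken from a real run):

    # A hypothetical example of the stanza built above; key layout from the
    # code, concrete values invented for illustration.
    stanza = {
        'sd-1/pipeline/anything-v3': {
            'name': 'anything-v3',
            'path': '/opt/invokeai/models/sd-1/pipeline/anything-v3',
            'description': 'A sd-1 pipeline model',
            'format': 'diffusers',
            'image_size': 512,
            'base': 'sd-1',
        }
    }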
@ -266,7 +266,7 @@ def migrate_checkpoints(dest_dir: Path, dest_yaml: io.TextIOBase):
                 {
                     'name': model_name,
                     'path': str(weights),
-                    'description': f'checkpoint model {model_name}',
+                    'description': f'{info.base_type.value}-based checkpoint',
                     'format': 'checkpoint',
                     'image_size': info.image_size,
                     'base': info.base_type.value,