Download all model types. (#3944)

This commit is contained in:
Lincoln Stein 2023-07-26 10:24:37 -04:00 committed by GitHub
commit 385483ff8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 8 deletions

View File

@ -558,7 +558,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
# ------------------------------------- # -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False): def initialize_rootdir(root: Path, yes_to_all: bool = False):
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **") logger.info("Initializing InvokeAI runtime directory")
for name in ( for name in (
"models", "models",
"databases", "databases",
@ -788,15 +788,14 @@ def main():
sys.exit(0) sys.exit(0)
if opt.skip_support_models: if opt.skip_support_models:
logger.info("SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST") logger.info("Skipping support models at user's request")
else: else:
logger.info("CHECKING/UPDATING SUPPORT MODELS") logger.info("Installing support models")
download_support_models() download_support_models()
if opt.skip_sd_weights: if opt.skip_sd_weights:
logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST") logger.warning("Skipping diffusion weights download per user request")
elif models_to_download: elif models_to_download:
logger.info("DOWNLOADING DIFFUSION WEIGHTS")
process_and_execute(opt, models_to_download) process_and_execute(opt, models_to_download)
postscript(errors=errors) postscript(errors=errors)

View File

@ -149,16 +149,17 @@ class ModelInstall(object):
for i in installed: for i in installed:
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}") print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
def starter_models(self)->Set[str]: # logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool=False)->Set[str]:
models = set() models = set()
for key, value in self.datasets.items(): for key, value in self.datasets.items():
name,base,model_type = ModelManager.parse_key(key) name,base,model_type = ModelManager.parse_key(key)
if model_type==ModelType.Main: if all_models or model_type==ModelType.Main:
models.add(key) models.add(key)
return models return models
def recommended_models(self)->Set[str]: def recommended_models(self)->Set[str]:
starters = self.starter_models() starters = self.starter_models(all_models=True)
return set([x for x in starters if self.datasets[x].get('recommended',False)]) return set([x for x in starters if self.datasets[x].get('recommended',False)])
def default_model(self)->str: def default_model(self)->str: