InvokeAI/invokeai/app/services/model_manager_initializer.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

121 lines
4.1 KiB
Python
Raw Normal View History

import os
import sys
import torch
from argparse import Namespace
2023-03-11 16:32:57 +00:00
from invokeai.backend import Args
from omegaconf import OmegaConf
from pathlib import Path
import invokeai.version
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
2023-03-11 23:08:45 +00:00
# TODO: Replace with an abstract base class, ModelManagerBase
2023-03-11 16:32:57 +00:00
def get_model_manager(config: Args) -> ModelManager:
    """Create the ModelManager described by the parsed application arguments.

    Steps performed, in order:

    1. If ``config.conf`` is empty, fall back to ``<root>/configs/models.yaml``
       (reporting an error first when that file does not exist).
    2. Print the application banner and silence the ``transformers`` and
       ``diffusers`` loggers.
    3. Resolve ``config.conf`` and the embedding path relative to
       ``Globals.root`` when they are not absolute.  ``config.conf`` is
       updated in place.
    4. Call ``ModelManager.migrate_models()`` and construct the manager.
    5. If ``config.autoconvert`` is set, call ``autoconvert_weights`` on that
       directory.

    Args:
        config: Parsed arguments.  The attributes read here are ``conf``,
            ``embeddings``, ``embedding_path``, ``precision``,
            ``max_loaded_models`` and ``autoconvert``.

    Returns:
        The initialized ModelManager.

    Raises:
        SystemExit: If the model manager cannot be constructed.
    """
    if not config.conf:
        config_file = os.path.join(Globals.root, "configs", "models.yaml")
        if not os.path.exists(config_file):
            report_model_error(
                config, FileNotFoundError(f"The file {config_file} could not be found.")
            )
        # Use the default from here on; otherwise the path normalization below
        # would operate on an empty/None value.
        config.conf = config_file

    print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
    print(f'>> InvokeAI runtime directory is "{Globals.root}"')

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers  # type: ignore

    transformers.logging.set_verbosity_error()
    import diffusers

    diffusers.logging.set_verbosity_error()

    # normalize the config directory relative to root
    if not os.path.isabs(config.conf):
        config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))

    if config.embeddings:
        if not os.path.isabs(config.embedding_path):
            embedding_path = os.path.normpath(
                os.path.join(Globals.root, config.embedding_path)
            )
        else:
            embedding_path = config.embedding_path
    else:
        embedding_path = None

    # migrate legacy models
    ModelManager.migrate_models()

    # creating the model manager
    try:
        device = torch.device(choose_torch_device())
        if config.precision in ("float16", "float32"):
            precision = config.precision
        else:
            precision = choose_precision(device)

        model_manager = ModelManager(
            OmegaConf.load(config.conf),
            precision=precision,
            device_type=device,
            max_loaded_models=config.max_loaded_models,
            # Path(None) raises TypeError, so only wrap a real path.
            embedding_path=Path(embedding_path) if embedding_path is not None else None,
        )
    except (FileNotFoundError, TypeError, AssertionError) as e:
        report_model_error(config, e)
        # report_model_error returns without restarting, so there is no
        # manager to hand back; exit cleanly instead of failing later on an
        # unbound local.
        print("** The model manager could not be initialized. Aborting.")
        sys.exit(-1)
    except (IOError, KeyError) as e:
        print(f"{e}. Aborting.")
        sys.exit(-1)

    # try to autoconvert new models
    # autoimport new .ckpt files
    if path := config.autoconvert:
        model_manager.autoconvert_weights(
            conf_path=config.conf,
            weights_directory=path,
        )

    return model_manager
def report_model_error(opt: Namespace, e: Exception) -> None:
    """Report a model initialization failure and offer to run invokeai-configure.

    The error is printed, then one of two things happens:

    * If the ``INVOKE_MODEL_RECONFIGURE`` environment variable is set to a
      non-empty value, reconfiguration proceeds without prompting and the
      variable's whitespace-separated words are appended to the configure
      script's arguments.
    * Otherwise the user is prompted; a reply starting with ``n``/``N``
      returns immediately without launching anything.

    ``sys.argv`` is replaced with the configure script's arguments while it
    runs and restored afterwards.

    Args:
        opt: Parsed arguments.  ``opt.root_dir`` and ``opt.conf`` are passed
            on as ``--root`` / ``--config`` when they are not None.
        e: The exception that caused initialization to fail.
    """
    print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
    print(
        "** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
    )
    yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
    if yes_to_all:
        print(
            "** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
        )
    else:
        response = input(
            "Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
        )
        if response.startswith(("n", "N")):
            return

    print("invokeai-configure is launching....\n")

    # Match arguments that were set on the CLI
    # only the arguments accepted by the configuration script are parsed
    configure_argv = ["invokeai-configure"]
    if opt.root_dir is not None:
        configure_argv += ["--root", opt.root_dir]
    if opt.conf is not None:
        configure_argv += ["--config", opt.conf]
    if yes_to_all is not None:
        configure_argv.extend(yes_to_all.split())

    previous_argv = sys.argv
    sys.argv = configure_argv
    try:
        # Imported only after sys.argv is set, matching the original order.
        from invokeai.frontend.install import invokeai_configure

        invokeai_configure()
    finally:
        # Do not leave the caller's argv clobbered once configuration is done.
        sys.argv = previous_argv

    # TODO: Figure out how to restart InvokeAI after reconfiguration
    # (ideally an os.exec-style relaunch with the original arguments).