Mirror of https://github.com/invoke-ai/InvokeAI

simplify passing of config options
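
The change is mechanical but easy to misread in diff form: callers previously threaded two objects, args and config, through every initializer, and after this commit they parse options once into a single config object and pass only that. A minimal sketch of the new convention, using a parse_opts stand-in rather than InvokeAI's actual Args class:

    from argparse import ArgumentParser, Namespace

    def parse_opts(argv: list[str]) -> Namespace:
        # Illustrative stand-in for invokeai.backend.Args; the real class
        # exposes many more options (patchmatch, xformers, conf, ...).
        parser = ArgumentParser()
        parser.add_argument("--conf", default="configs/models.yaml")
        parser.add_argument("--precision", default="auto")
        return parser.parse_args(argv)

    def get_model_manager(config: Namespace) -> None:
        # After this commit, initializers receive the single parsed
        # config object instead of the old (args, config) pair.
        print(f"models: {config.conf}, precision: {config.precision}")

    config = parse_opts([])
    get_model_manager(config)
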
@@ -37,12 +37,12 @@ class ApiDependencies:
     invoker: Invoker = None
 
     @staticmethod
-    def initialize(args, config, event_handler_id: int):
-        Globals.try_patchmatch = args.patchmatch
-        Globals.always_use_cpu = args.always_use_cpu
-        Globals.internet_available = args.internet_available and check_internet()
-        Globals.disable_xformers = not args.xformers
-        Globals.ckpt_convert = args.ckpt_convert
+    def initialize(config, event_handler_id: int):
+        Globals.try_patchmatch = config.patchmatch
+        Globals.always_use_cpu = config.always_use_cpu
+        Globals.internet_available = config.internet_available and check_internet()
+        Globals.disable_xformers = not config.xformers
+        Globals.ckpt_convert = config.ckpt_convert
 
         # TODO: Use a logger
         print(f">> Internet connectivity is {Globals.internet_available}")

@@ -59,7 +59,7 @@ class ApiDependencies:
         db_location = os.path.join(output_folder, "invokeai.db")
 
         services = InvocationServices(
-            model_manager=get_model_manager(args, config),
+            model_manager=get_model_manager(config),
             events=events,
             images=images,
             queue=MemoryInvocationQueue(),

@@ -53,11 +53,11 @@ config = {}
 # Add startup event to load dependencies
 @app.on_event("startup")
 async def startup_event():
-    args = Args()
-    config = args.parse_args()
+    config = Args()
+    config.parse_args()
 
     ApiDependencies.initialize(
-        args=args, config=config, event_handler_id=event_handler_id
+        config=config, event_handler_id=event_handler_id
     )
 
 

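For readers unfamiliar with the hook in the hunk above: FastAPI runs a coroutine decorated with @app.on_event("startup") once when the server boots, which is why option parsing and ApiDependencies.initialize happen there. A self-contained sketch, with a print standing in for the real initialization:

    from fastapi import FastAPI

    app = FastAPI()

    @app.on_event("startup")
    async def startup_event():
        # Runs once at server start; the diff above parses options and
        # builds ApiDependencies at this point.
        print("startup: dependencies would be initialized here")
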
@@ -17,7 +17,7 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
 from .invocations import *
 from .invocations.baseinvocation import BaseInvocation
 from .services.events import EventServiceBase
-from .services.generate_initializer import get_model_manager
+from .services.model_manager_initializer import get_model_manager
 from .services.graph import EdgeConnection, GraphExecutionState
 from .services.image_storage import DiskImageStorage
 from .services.invocation_queue import MemoryInvocationQueue

@@ -126,10 +126,9 @@ def invoke_all(context: CliContext):
 
 
 def invoke_cli():
-    args = Args()
-    config = args.parse_args()
-
-    model_manager = get_model_manager(args, config)
+    config = Args()
+    config.parse_args()
+    model_manager = get_model_manager(config)
 
     events = EventServiceBase()
 

@@ -2,6 +2,7 @@ import os
 import sys
 import torch
 from argparse import Namespace
+from invokeai.backend import Args
 from omegaconf import OmegaConf
 from pathlib import Path
 

@@ -11,12 +12,12 @@ from ...backend.util import choose_precision, choose_torch_device
 from ...backend import Globals
 
 # TODO: most of this code should be split into individual services as the Generate.py code is deprecated
-def get_model_manager(args, config) -> ModelManager:
-    if not args.conf:
+def get_model_manager(config: Args) -> ModelManager:
+    if not config.conf:
         config_file = os.path.join(Globals.root, "configs", "models.yaml")
         if not os.path.exists(config_file):
             report_model_error(
-                args, FileNotFoundError(f"The file {config_file} could not be found.")
+                config, FileNotFoundError(f"The file {config_file} could not be found.")
             )
 
     print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")

@@ -32,64 +33,47 @@ def get_model_manager(args, config) -> ModelManager:
     diffusers.logging.set_verbosity_error()
 
     # normalize the config directory relative to root
-    if not os.path.isabs(args.conf):
-        args.conf = os.path.normpath(os.path.join(Globals.root, args.conf))
+    if not os.path.isabs(config.conf):
+        config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
 
-    if args.embeddings:
-        if not os.path.isabs(args.embedding_path):
+    if config.embeddings:
+        if not os.path.isabs(config.embedding_path):
             embedding_path = os.path.normpath(
-                os.path.join(Globals.root, args.embedding_path)
+                os.path.join(Globals.root, config.embedding_path)
             )
         else:
-            embedding_path = args.embedding_path
+            embedding_path = config.embedding_path
     else:
         embedding_path = None
 
     # migrate legacy models
     ModelManager.migrate_models()
 
-    # load the infile as a list of lines
-    if args.infile:
-        try:
-            if os.path.isfile(args.infile):
-                infile = open(args.infile, "r", encoding="utf-8")
-            elif args.infile == "-":  # stdin
-                infile = sys.stdin
-            else:
-                raise FileNotFoundError(f"{args.infile} not found.")
-        except (FileNotFoundError, IOError) as e:
-            print(f"{e}. Aborting.")
-            sys.exit(-1)
-
     # creating the model manager
     try:
         device = torch.device(choose_torch_device())
-        precision = 'float16' if args.precision=='float16' \
-            else 'float32' if args.precision=='float32' \
+        precision = 'float16' if config.precision=='float16' \
+            else 'float32' if config.precision=='float32' \
             else choose_precision(device)
 
         model_manager = ModelManager(
-            OmegaConf.load(args.conf),
+            OmegaConf.load(config.conf),
             precision=precision,
             device_type=device,
-            max_loaded_models=args.max_loaded_models,
+            max_loaded_models=config.max_loaded_models,
             embedding_path = Path(embedding_path),
         )
     except (FileNotFoundError, TypeError, AssertionError) as e:
-        report_model_error(args, e)
+        report_model_error(config, e)
     except (IOError, KeyError) as e:
         print(f"{e}. Aborting.")
         sys.exit(-1)
 
-    if args.seamless:
-        #TODO: do something here ?
-        print(">> changed to seamless tiling mode")
-
     # try to autoconvert new models
     # autoimport new .ckpt files
-    if path := args.autoconvert:
+    if path := config.autoconvert:
         model_manager.autoconvert_weights(
-            conf_path=args.conf,
+            conf_path=config.conf,
             weights_directory=path,
         )
 

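The precision assignment in this hunk is a chained conditional expression: honor an explicit 'float16' or 'float32' request, otherwise defer to choose_precision(device). The same logic as a standalone sketch; this choose_precision is a hypothetical stand-in for the helper imported from ...backend.util:

    import torch

    def choose_precision(device: torch.device) -> str:
        # Hypothetical stand-in for the backend helper: prefer half
        # precision only on CUDA devices.
        return "float16" if device.type == "cuda" else "float32"

    def resolve_precision(requested: str, device: torch.device) -> str:
        # Mirrors the chained conditional in get_model_manager above.
        return (
            "float16" if requested == "float16"
            else "float32" if requested == "float32"
            else choose_precision(device)
        )

    print(resolve_precision("auto", torch.device("cpu")))  # -> float32
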
@@ -118,10 +102,10 @@ def report_model_error(opt: Namespace, e: Exception):
     # only the arguments accepted by the configuration script are parsed
     root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
     config = ["--config", opt.conf] if opt.conf is not None else []
-    previous_args = sys.argv
+    previous_config = sys.argv
     sys.argv = ["invokeai-configure"]
     sys.argv.extend(root_dir)
-    sys.argv.extend(config)
+    sys.argv.extend(config.to_dict())
     if yes_to_all is not None:
         for arg in yes_to_all.split():
             sys.argv.append(arg)

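report_model_error recovers by rewriting sys.argv and handing control to the invokeai-configure entry point. A sketch of that save, substitute, restore pattern in isolation; the print stands in for actually invoking the configure script:

    import sys

    def rerun_configure(extra_args: list[str]) -> None:
        # Save the caller's argv, substitute the arguments the configure
        # script accepts, and restore argv afterwards.
        previous_argv = sys.argv
        try:
            sys.argv = ["invokeai-configure", *extra_args]
            print(f"would run: {' '.join(sys.argv)}")
        finally:
            sys.argv = previous_argv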