simplify passing of config options

This commit is contained in:
Lincoln Stein
2023-03-11 11:32:57 -05:00
parent c14241436b
commit 580f9ecded
4 changed files with 33 additions and 50 deletions

View File

@ -37,12 +37,12 @@ class ApiDependencies:
invoker: Invoker = None invoker: Invoker = None
@staticmethod @staticmethod
def initialize(args, config, event_handler_id: int): def initialize(config, event_handler_id: int):
Globals.try_patchmatch = args.patchmatch Globals.try_patchmatch = config.patchmatch
Globals.always_use_cpu = args.always_use_cpu Globals.always_use_cpu = config.always_use_cpu
Globals.internet_available = args.internet_available and check_internet() Globals.internet_available = config.internet_available and check_internet()
Globals.disable_xformers = not args.xformers Globals.disable_xformers = not config.xformers
Globals.ckpt_convert = args.ckpt_convert Globals.ckpt_convert = config.ckpt_convert
# TODO: Use a logger # TODO: Use a logger
print(f">> Internet connectivity is {Globals.internet_available}") print(f">> Internet connectivity is {Globals.internet_available}")
@ -59,7 +59,7 @@ class ApiDependencies:
db_location = os.path.join(output_folder, "invokeai.db") db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices( services = InvocationServices(
model_manager=get_model_manager(args, config), model_manager=get_model_manager(config),
events=events, events=events,
images=images, images=images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),

View File

@ -53,11 +53,11 @@ config = {}
# Add startup event to load dependencies # Add startup event to load dependencies
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
args = Args() config = Args()
config = args.parse_args() config.parse_args()
ApiDependencies.initialize( ApiDependencies.initialize(
args=args, config=config, event_handler_id=event_handler_id config=config, event_handler_id=event_handler_id
) )

View File

@ -17,7 +17,7 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
from .invocations import * from .invocations import *
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase from .services.events import EventServiceBase
from .services.generate_initializer import get_model_manager from .services.model_manager_initializer import get_model_manager
from .services.graph import EdgeConnection, GraphExecutionState from .services.graph import EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_queue import MemoryInvocationQueue
@ -126,10 +126,9 @@ def invoke_all(context: CliContext):
def invoke_cli(): def invoke_cli():
args = Args() config = Args()
config = args.parse_args() config.parse_args()
model_manager = get_model_manager(config)
model_manager = get_model_manager(args, config)
events = EventServiceBase() events = EventServiceBase()

View File

@ -2,6 +2,7 @@ import os
import sys import sys
import torch import torch
from argparse import Namespace from argparse import Namespace
from invokeai.backend import Args
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pathlib import Path from pathlib import Path
@ -11,12 +12,12 @@ from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals from ...backend import Globals
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated # TODO: most of this code should be split into individual services as the Generate.py code is deprecated
def get_model_manager(args, config) -> ModelManager: def get_model_manager(config: Args) -> ModelManager:
if not args.conf: if not config.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml") config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file): if not os.path.exists(config_file):
report_model_error( report_model_error(
args, FileNotFoundError(f"The file {config_file} could not be found.") config, FileNotFoundError(f"The file {config_file} could not be found.")
) )
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}") print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
@ -32,64 +33,47 @@ def get_model_manager(args, config) -> ModelManager:
diffusers.logging.set_verbosity_error() diffusers.logging.set_verbosity_error()
# normalize the config directory relative to root # normalize the config directory relative to root
if not os.path.isabs(args.conf): if not os.path.isabs(config.conf):
args.conf = os.path.normpath(os.path.join(Globals.root, args.conf)) config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
if args.embeddings: if config.embeddings:
if not os.path.isabs(args.embedding_path): if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath( embedding_path = os.path.normpath(
os.path.join(Globals.root, args.embedding_path) os.path.join(Globals.root, config.embedding_path)
) )
else: else:
embedding_path = args.embedding_path embedding_path = config.embedding_path
else: else:
embedding_path = None embedding_path = None
# migrate legacy models # migrate legacy models
ModelManager.migrate_models() ModelManager.migrate_models()
# load the infile as a list of lines
if args.infile:
try:
if os.path.isfile(args.infile):
infile = open(args.infile, "r", encoding="utf-8")
elif args.infile == "-": # stdin
infile = sys.stdin
else:
raise FileNotFoundError(f"{args.infile} not found.")
except (FileNotFoundError, IOError) as e:
print(f"{e}. Aborting.")
sys.exit(-1)
# creating the model manager # creating the model manager
try: try:
device = torch.device(choose_torch_device()) device = torch.device(choose_torch_device())
precision = 'float16' if args.precision=='float16' \ precision = 'float16' if config.precision=='float16' \
else 'float32' if args.precision=='float32' \ else 'float32' if config.precision=='float32' \
else choose_precision(device) else choose_precision(device)
model_manager = ModelManager( model_manager = ModelManager(
OmegaConf.load(args.conf), OmegaConf.load(config.conf),
precision=precision, precision=precision,
device_type=device, device_type=device,
max_loaded_models=args.max_loaded_models, max_loaded_models=config.max_loaded_models,
embedding_path = Path(embedding_path), embedding_path = Path(embedding_path),
) )
except (FileNotFoundError, TypeError, AssertionError) as e: except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(args, e) report_model_error(config, e)
except (IOError, KeyError) as e: except (IOError, KeyError) as e:
print(f"{e}. Aborting.") print(f"{e}. Aborting.")
sys.exit(-1) sys.exit(-1)
if args.seamless:
#TODO: do something here ?
print(">> changed to seamless tiling mode")
# try to autoconvert new models # try to autoconvert new models
# autoimport new .ckpt files # autoimport new .ckpt files
if path := args.autoconvert: if path := config.autoconvert:
model_manager.autoconvert_weights( model_manager.autoconvert_weights(
conf_path=args.conf, conf_path=config.conf,
weights_directory=path, weights_directory=path,
) )
@ -118,10 +102,10 @@ def report_model_error(opt: Namespace, e: Exception):
# only the arguments accepted by the configuration script are parsed # only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else [] root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else [] config = ["--config", opt.conf] if opt.conf is not None else []
previous_args = sys.argv previous_config = sys.argv
sys.argv = ["invokeai-configure"] sys.argv = ["invokeai-configure"]
sys.argv.extend(root_dir) sys.argv.extend(root_dir)
sys.argv.extend(config) sys.argv.extend(config.to_dict())
if yes_to_all is not None: if yes_to_all is not None:
for arg in yes_to_all.split(): for arg in yes_to_all.split():
sys.argv.append(arg) sys.argv.append(arg)