diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index a863a98a8a..f379eb2569 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -33,12 +33,12 @@ from tqdm import tqdm from transformers import AutoFeatureExtractor import invokeai.configs as model_configs -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import InvokeAIAppConfig, get_config from invokeai.backend.install.install_helper import InstallHelper, InstallSelections from invokeai.backend.model_manager import ModelType from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger +from invokeai.frontend.cli.arg_parser import InvokeAIArgs from invokeai.frontend.install.model_install import addModelsForm # TO DO - Move all the frontend code into invokeai.frontend.install @@ -63,8 +63,7 @@ def get_literal_fields(field: str) -> Tuple[Any]: # --------------------------globals----------------------- - -config = get_config() +config = None PRECISION_CHOICES = get_literal_fields("precision") DEVICE_CHOICES = get_literal_fields("device") @@ -745,6 +744,8 @@ def is_v2_install(root: Path) -> bool: # ------------------------------------- def main() -> None: global FORCE_FULL_PRECISION # FIXME + global config + parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--skip-sd-weights", @@ -787,10 +788,12 @@ def main() -> None: default=None, help="path to root of install directory", ) + opt = parser.parse_args() updates: dict[str, Any] = {} - if opt.root: - config.set_root(Path(opt.root)) + + InvokeAIArgs.args = opt + config = get_config() if opt.full_precision: updates["precision"] = "float32" diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index b99a5932fd..e02015614e 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -12,11 +12,11 @@ from invokeai.app.services.config.config_default import get_config CPU_DEVICE = torch.device("cpu") CUDA_DEVICE = torch.device("cuda") MPS_DEVICE = torch.device("mps") -config = get_config() def choose_torch_device() -> torch.device: """Convenience routine for guessing which GPU device to run model on""" + config = get_config() if config.device == "auto": if torch.cuda.is_available(): return torch.device("cuda") @@ -34,7 +34,7 @@ def choose_precision( device: torch.device, app_config: Optional[InvokeAIAppConfig] = None ) -> Literal["float32", "float16", "bfloat16"]: """Return an appropriate precision for the given torch device.""" - app_config = app_config or config + app_config = app_config or get_config() if device.type == "cuda": device_name = torch.cuda.get_device_name(device) if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name): diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index 9f5381db34..968604eb3d 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -339,7 +339,8 @@ class InvokeAILogger(object): # noqa D102 loggers: Dict[str, logging.Logger] = {} @classmethod - def get_logger(cls, name: str = "InvokeAI", config: InvokeAIAppConfig = get_config()) -> logging.Logger: # noqa D102 + def get_logger(cls, name: str = "InvokeAI", config: Optional[InvokeAIAppConfig] = None) -> logging.Logger: # noqa D102 + config = config or get_config() if name in cls.loggers: return cls.loggers[name] diff --git a/invokeai/frontend/cli/arg_parser.py b/invokeai/frontend/cli/arg_parser.py index 0df2b67b7c..452cb38246 100644 --- a/invokeai/frontend/cli/arg_parser.py +++ b/invokeai/frontend/cli/arg_parser.py @@ -30,11 +30,11 @@ class InvokeAIArgs: Example: ``` # In a CLI wrapper - from invokeai.frontend.cli.app_arg_parser import InvokeAIArgs + from invokeai.frontend.cli.arg_parser import InvokeAIArgs InvokeAIArgs.parse_args() # In the application - from invokeai.frontend.cli.app_arg_parser import InvokeAIArgs + from invokeai.frontend.cli.arg_parser import InvokeAIArgs args = InvokeAIArgs.args """ diff --git a/tests/test_config.py b/tests/test_config.py index 617e28785d..c1dfd01f0a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,4 @@ +import sys from pathlib import Path from tempfile import TemporaryDirectory from typing import Any @@ -7,6 +8,7 @@ from omegaconf import OmegaConf from pydantic import ValidationError from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config +from invokeai.frontend.cli.arg_parser import InvokeAIArgs v4_config = """ schema_version: 4 @@ -76,6 +78,13 @@ def test_read_config_from_file(tmp_path: Path): assert config.port == 8080 +def test_arg_parsing(): + sys.argv = ["test_config.py", "--root", "/tmp"] + InvokeAIArgs.parse_args() + config = get_config() + assert config.root_path == Path("/tmp") + + def test_migrate_v3_config_from_file(tmp_path: Path): """Test reading configuration from a file.""" temp_config_file = tmp_path / "temp_invokeai.yaml"