From 6c558279dd4aa487815f364121dfae2dc2cf7e0f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 19 Mar 2024 16:24:13 +1100 Subject: [PATCH] feat(config): add CLI arg to specify config file This allows users to create simple "profiles" via separate `invokeai.yaml` files. - Remove `InvokeAIAppConfig.set_root()`, it's extraneous - Remove `InvokeAIAppConfig.merge_from_file()`, it's extraneous - Add `--config` to the app arg parser, add `InvokeAIAppConfig._config_file`, and consume in the config singleton getter - `InvokeAIAppConfig.init_file_path` -> `InvokeAIAppConfig.config_file_path` --- .../app/services/config/config_default.py | 37 ++++++------------- .../model_install/model_install_default.py | 2 +- invokeai/frontend/cli/arg_parser.py | 6 ++- .../model_records/test_model_records_sql.py | 2 +- .../model_manager/model_manager_fixtures.py | 2 +- tests/test_config.py | 2 +- 6 files changed, 20 insertions(+), 31 deletions(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 10d8ba308f..e09567cc8c 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -121,6 +121,7 @@ class InvokeAIAppConfig(BaseSettings): """ _root: Optional[Path] = PrivateAttr(default=None) + _config_file: Optional[Path] = PrivateAttr(default=None) # fmt: off @@ -251,24 +252,6 @@ class InvokeAIAppConfig(BaseSettings): if len(config_dict) > 0: file.write(yaml.dump(config_dict, sort_keys=False)) - def merge_from_file(self, source_path: Optional[Path] = None) -> None: - """Read the config from the `invokeai.yaml` file, migrating it if necessary and merging it into the singleton config. - - This function will write to the `invokeai.yaml` file if the config is migrated. - - Args: - source_path: Path to the config file. If not provided, the default path is used. - """ - path = source_path or self.init_file_path - config_from_file = load_and_migrate_config(path) - # Clobbering here will overwrite any settings that were set via environment variables - self.update_config(config_from_file, clobber=False) - - def set_root(self, root: Path) -> None: - """Set the runtime root directory. This is typically set using a CLI arg.""" - assert isinstance(root, Path) - self._root = root - def _resolve(self, partial_path: Path) -> Path: return (self.root_path / partial_path).resolve() @@ -283,9 +266,9 @@ class InvokeAIAppConfig(BaseSettings): return root.resolve() @property - def init_file_path(self) -> Path: + def config_file_path(self) -> Path: """Path to invokeai.yaml, resolved to an absolute path..""" - resolved_path = self._resolve(INIT_FILE) + resolved_path = self._resolve(self._config_file or INIT_FILE) assert resolved_path is not None return resolved_path @@ -441,9 +424,11 @@ def get_config() -> InvokeAIAppConfig: if not InvokeAIArgs.did_parse: return config - # CLI args trump environment variables + # Set CLI args if root := getattr(args, "root", None): - config.set_root(Path(root)) + config._root = Path(root) + if config_file := getattr(args, "config_file", None): + config._config_file = Path(config_file) # Log in to HF hf_login() @@ -452,9 +437,11 @@ def get_config() -> InvokeAIAppConfig: configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue] shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True) - if config.init_file_path.exists(): - config.merge_from_file() + if config.config_file_path.exists(): + incoming_config = load_and_migrate_config(config.config_file_path) + # Clobbering here will overwrite any settings that were set via environment variables + config.update_config(incoming_config, clobber=False) else: - config.write_file(config.init_file_path) + config.write_file(config.config_file_path) return config diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 992f5a11fe..a80149c7f0 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -343,7 +343,7 @@ class ModelInstallService(ModelInstallServiceBase): # Remove `legacy_models_yaml_path` from the config file - we are done with it either way self._app_config.legacy_models_yaml_path = None - self._app_config.write_file(self._app_config.init_file_path) + self._app_config.write_file(self._app_config.config_file_path) def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()} diff --git a/invokeai/frontend/cli/arg_parser.py b/invokeai/frontend/cli/arg_parser.py index 34f7c462bb..72da8f7656 100644 --- a/invokeai/frontend/cli/arg_parser.py +++ b/invokeai/frontend/cli/arg_parser.py @@ -3,14 +3,16 @@ from typing import Optional from invokeai.version import __version__ -_root_help = r"""Sets a root directory for the app. -If omitted, the app will search for the root directory in the following order: +_root_help = r"""Path to the runtime root directory. If omitted, the app will search for the root directory in the following order: - The `$INVOKEAI_ROOT` environment variable - The currently active virtual environment's parent directory - `$HOME/invokeai`""" +_config_file_help = r"""Path to the invokeai.yaml configuration file. If omitted, the app will search for the file in the root directory.""" + _parser = ArgumentParser(description="Invoke Studio", formatter_class=RawTextHelpFormatter) _parser.add_argument("--root", type=str, help=_root_help) +_parser.add_argument("--config", dest="config_file", type=str, help=_config_file_help) _parser.add_argument("--version", action="version", version=__version__, help="Displays the version and exits.") diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 7009a81025..1bf9b3e0e3 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -35,7 +35,7 @@ def store( datadir: Any, ) -> ModelRecordServiceSQL: config = InvokeAIAppConfig() - config.set_root(datadir) + config._root = datadir logger = InvokeAILogger.get_logger(config=config) db = create_mock_sqlite_database(config, logger) return ModelRecordServiceSQL(db) diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 8d4ccf196c..6070f2b653 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -94,7 +94,7 @@ def diffusers_dir(mm2_model_files: Path) -> Path: @pytest.fixture def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: app_config = InvokeAIAppConfig(models_dir=mm2_root_dir / "models", log_level="info") - app_config.set_root(mm2_root_dir) + app_config._root = mm2_root_dir return app_config diff --git a/tests/test_config.py b/tests/test_config.py index 617e28785d..e0e0050bf5 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -170,7 +170,7 @@ def test_set_and_resolve_paths(): """Test setting root and resolving paths based on it.""" with TemporaryDirectory() as tmpdir: config = InvokeAIAppConfig() - config.set_root(Path(tmpdir)) + config._root = Path(tmpdir) assert config.models_path == Path(tmpdir).resolve() / "models" assert config.db_path == Path(tmpdir).resolve() / "databases" / "invokeai.db"