fix(config): revised config methods

- `write_file` requires an destination file path
- `read_config` -> `merge_from_file`, if no path is provided, reads from `self.init_file_path`
- update app, tests to use new methods
- fix configurator, was overwriting config file data unexpectedly
This commit is contained in:
psychedelicious 2024-03-12 01:45:12 +11:00
parent 5e39e46954
commit f69938c6a8
4 changed files with 36 additions and 38 deletions

View File

@ -5,7 +5,7 @@ from invokeai.app.services.config.config_default import get_config
app_config = get_config() app_config = get_config()
app_config.parse_args() app_config.parse_args()
app_config.read_config() app_config.merge_from_file()
if True: # hack to make flake8 happy with imports coming after setting up the config if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio import asyncio

View File

@ -195,27 +195,33 @@ class InvokeAIAppConfig(BaseSettings):
if new_value != current_value: if new_value != current_value:
setattr(self, field_name, new_value) setattr(self, field_name, new_value)
def write_file(self, exclude_defaults: bool) -> None: def write_file(self, dest_path: Path) -> None:
"""Write the current configuration to the `invokeai.yaml` file. This will overwrite the existing file. """Write the current configuration to file. This will overwrite the existing file.
A `meta` stanza is added to the top of the file, containing metadata about the config file. This is not stored in the config object. A `meta` stanza is added to the top of the file, containing metadata about the config file. This is not stored in the config object.
Args: Args:
exclude_defaults: If `True`, only include settings that were explicitly set. If `False`, include all settings, including defaults. dest_path: Path to write the config to.
""" """
with open(self.init_file_path, "w") as file: with open(dest_path, "w") as file:
meta_dict = {"meta": ConfigMeta().model_dump()} meta_dict = {"meta": ConfigMeta().model_dump()}
config_dict = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=exclude_defaults) config_dict = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=True)
file.write("# Internal metadata\n") file.write("# Internal metadata\n")
file.write(yaml.dump(meta_dict, sort_keys=False)) file.write(yaml.dump(meta_dict, sort_keys=False))
file.write("\n") file.write("\n")
file.write("# User settings\n") file.write("# User settings\n")
file.write(yaml.dump(config_dict, sort_keys=False)) file.write(yaml.dump(config_dict, sort_keys=False))
def set_root(self, root: Path) -> None: def merge_from_file(self, source_path: Optional[Path] = None) -> None:
"""Set the runtime root directory. This is typically set using a CLI arg.""" """Read the config from the `invokeai.yaml` file, migrating it if necessary and merging it into the singleton config.
assert isinstance(root, Path)
self._root = root This function will write to the `invokeai.yaml` file if the config is migrated.
Args:
source_path: Path to the config file. If not provided, the default path is used.
"""
config_from_file = load_and_migrate_config(source_path or self.init_file_path)
self.update_config(config_from_file)
def parse_args(self) -> None: def parse_args(self) -> None:
"""Parse the CLI args and set the runtime root directory.""" """Parse the CLI args and set the runtime root directory."""
@ -223,10 +229,10 @@ class InvokeAIAppConfig(BaseSettings):
if root := getattr(opt, "root", None): if root := getattr(opt, "root", None):
self.set_root(Path(root)) self.set_root(Path(root))
def read_config(self) -> None: def set_root(self, root: Path) -> None:
"""Read the config from the `invokeai.yaml` file, merging it into the singleton config.""" """Set the runtime root directory. This is typically set using a CLI arg."""
config_from_file = load_and_migrate_config(self.init_file_path) assert isinstance(root, Path)
self.update_config(config_from_file) self._root = root
def _resolve(self, partial_path: Path) -> Path: def _resolve(self, partial_path: Path) -> Path:
return (self.root_path / partial_path).resolve() return (self.root_path / partial_path).resolve()
@ -330,14 +336,6 @@ def generate_config_docstrings() -> str:
return docstring return docstring
def load_config_from_file(config_path: Path) -> InvokeAIAppConfig:
"""Parse a config file into an InvokeAIAppConfig object. The file should be in YAML format."""
assert config_path.suffix == ".yaml"
with open(config_path) as file:
loaded_config = InvokeAIAppConfig.model_validate(yaml.safe_load(file))
return loaded_config
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to the latest version. """Migrate a v3 config dictionary to the latest version.
@ -383,7 +381,7 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
config_path.rename(config_path.with_suffix(".yaml.bak")) config_path.rename(config_path.with_suffix(".yaml.bak"))
# By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set # By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set
config.write_file(exclude_defaults=True) config.write_file(config_path)
return config return config
else: else:
# Attempt to load as a v4 config file # Attempt to load as a v4 config file

View File

@ -711,7 +711,8 @@ def run_console_ui(
) -> Tuple[Optional[Namespace], Optional[InstallSelections]]: ) -> Tuple[Optional[Namespace], Optional[InstallSelections]]:
first_time = not (config.root_path / "invokeai.yaml").exists() first_time = not (config.root_path / "invokeai.yaml").exists()
invokeai_opts = default_startup_options(initfile) if first_time else config invokeai_opts = default_startup_options(initfile) if first_time else config
invokeai_opts.set_root(program_opts.root) if program_opts.root:
invokeai_opts.set_root(Path(program_opts.root))
if not set_min_terminal_size(MIN_COLS, MIN_LINES): if not set_min_terminal_size(MIN_COLS, MIN_LINES):
raise WindowTooSmallException( raise WindowTooSmallException(
@ -731,15 +732,13 @@ def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None:
""" """
Update the invokeai.yaml file with values from current settings. Update the invokeai.yaml file with values from current settings.
""" """
# this will load current settings
new_config = get_config() new_config = get_config()
new_config.set_root(config.root_path) new_config.set_root(config.root_path)
for key, value in vars(opts).items(): # Remove any fields that are not in the model_fields list, like `hf_token`
if hasattr(new_config, key): cleaned_opts = {k: v for k, v in vars(opts).items() if k in new_config.model_fields}
setattr(new_config, key, value) new_config.update_config(cleaned_opts)
new_config.write_file(init_file)
new_config.write_file(exclude_defaults=True)
if hasattr(opts, "hf_token") and opts.hf_token: if hasattr(opts, "hf_token") and opts.hf_token:
HfLogin(opts.hf_token) HfLogin(opts.hf_token)
@ -819,12 +818,13 @@ def main() -> None:
help="path to root of install directory", help="path to root of install directory",
) )
opt = parser.parse_args() opt = parser.parse_args()
invoke_args: dict[str, Any] = {} updates: dict[str, Any] = {}
if opt.root: if opt.root:
invoke_args["root"] = opt.root config.set_root(Path(opt.root))
if opt.full_precision: if opt.full_precision:
invoke_args["precision"] = "float32" updates["precision"] = "float32"
config.update_config(invoke_args) config.merge_from_file()
config.update_config(updates)
logger = InvokeAILogger().get_logger(config=config) logger = InvokeAILogger().get_logger(config=config)
errors = set() errors = set()

View File

@ -104,12 +104,12 @@ def test_write_config_to_file():
with TemporaryDirectory() as tmpdir: with TemporaryDirectory() as tmpdir:
temp_config_path = Path(tmpdir) / "invokeai.yaml" temp_config_path = Path(tmpdir) / "invokeai.yaml"
config = InvokeAIAppConfig(host="192.168.1.1", port=8080) config = InvokeAIAppConfig(host="192.168.1.1", port=8080)
config.set_root(Path(tmpdir)) config.write_file(temp_config_path)
config.write_file(exclude_defaults=False)
# Load the file and check contents # Load the file and check contents
with open(temp_config_path, "r") as file: with open(temp_config_path, "r") as file:
content = file.read() content = file.read()
# This is a default value, so it should not be in the file
assert "pil_compress_level" not in content
assert "host: 192.168.1.1" in content assert "host: 192.168.1.1" in content
assert "port: 8080" in content assert "port: 8080" in content
@ -179,7 +179,7 @@ def test_deny_nodes(patch_rootdir):
) )
# must parse config before importing Graph, so its nodes union uses the config # must parse config before importing Graph, so its nodes union uses the config
conf = get_config() conf = get_config()
conf.read_config(conf=allow_deny_nodes_conf, argv=[]) conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[])
from invokeai.app.services.shared.graph import Graph from invokeai.app.services.shared.graph import Graph
# confirm graph validation fails when using denied node # confirm graph validation fails when using denied node