fix(config): revised config methods

- `write_file` requires a destination file path
- `read_config` -> `merge_from_file`; if no path is provided, it reads from `self.init_file_path`
- update app, tests to use new methods
- fix configurator, which was unexpectedly overwriting config file data
This commit is contained in:
psychedelicious 2024-03-12 01:45:12 +11:00
parent 5e39e46954
commit f69938c6a8
4 changed files with 36 additions and 38 deletions

View File

@ -5,7 +5,7 @@ from invokeai.app.services.config.config_default import get_config
app_config = get_config()
app_config.parse_args()
app_config.read_config()
app_config.merge_from_file()
if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio

View File

@ -195,27 +195,33 @@ class InvokeAIAppConfig(BaseSettings):
if new_value != current_value:
setattr(self, field_name, new_value)
def write_file(self, exclude_defaults: bool) -> None:
"""Write the current configuration to the `invokeai.yaml` file. This will overwrite the existing file.
def write_file(self, dest_path: Path) -> None:
"""Write the current configuration to file. This will overwrite the existing file.
A `meta` stanza is added to the top of the file, containing metadata about the config file. This is not stored in the config object.
Args:
exclude_defaults: If `True`, only include settings that were explicitly set. If `False`, include all settings, including defaults.
dest_path: Path to write the config to.
"""
with open(self.init_file_path, "w") as file:
with open(dest_path, "w") as file:
meta_dict = {"meta": ConfigMeta().model_dump()}
config_dict = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=exclude_defaults)
config_dict = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=True)
file.write("# Internal metadata\n")
file.write(yaml.dump(meta_dict, sort_keys=False))
file.write("\n")
file.write("# User settings\n")
file.write(yaml.dump(config_dict, sort_keys=False))
def set_root(self, root: Path) -> None:
"""Set the runtime root directory. This is typically set using a CLI arg."""
assert isinstance(root, Path)
self._root = root
def merge_from_file(self, source_path: Optional[Path] = None) -> None:
"""Read the config from the `invokeai.yaml` file, migrating it if necessary and merging it into the singleton config.
This function will write to the `invokeai.yaml` file if the config is migrated.
Args:
source_path: Path to the config file. If not provided, the default path is used.
"""
config_from_file = load_and_migrate_config(source_path or self.init_file_path)
self.update_config(config_from_file)
def parse_args(self) -> None:
"""Parse the CLI args and set the runtime root directory."""
@ -223,10 +229,10 @@ class InvokeAIAppConfig(BaseSettings):
if root := getattr(opt, "root", None):
self.set_root(Path(root))
def read_config(self) -> None:
"""Read the config from the `invokeai.yaml` file, merging it into the singleton config."""
config_from_file = load_and_migrate_config(self.init_file_path)
self.update_config(config_from_file)
def set_root(self, root: Path) -> None:
"""Set the runtime root directory. This is typically set using a CLI arg."""
assert isinstance(root, Path)
self._root = root
def _resolve(self, partial_path: Path) -> Path:
return (self.root_path / partial_path).resolve()
@ -330,14 +336,6 @@ def generate_config_docstrings() -> str:
return docstring
def load_config_from_file(config_path: Path) -> InvokeAIAppConfig:
"""Parse a config file into an InvokeAIAppConfig object. The file should be in YAML format."""
assert config_path.suffix == ".yaml"
with open(config_path) as file:
loaded_config = InvokeAIAppConfig.model_validate(yaml.safe_load(file))
return loaded_config
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to the latest version.
@ -383,7 +381,7 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
config_path.rename(config_path.with_suffix(".yaml.bak"))
# By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set
config.write_file(exclude_defaults=True)
config.write_file(config_path)
return config
else:
# Attempt to load as a v4 config file

View File

@ -711,7 +711,8 @@ def run_console_ui(
) -> Tuple[Optional[Namespace], Optional[InstallSelections]]:
first_time = not (config.root_path / "invokeai.yaml").exists()
invokeai_opts = default_startup_options(initfile) if first_time else config
invokeai_opts.set_root(program_opts.root)
if program_opts.root:
invokeai_opts.set_root(Path(program_opts.root))
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
raise WindowTooSmallException(
@ -731,15 +732,13 @@ def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None:
"""
Update the invokeai.yaml file with values from current settings.
"""
# this will load current settings
new_config = get_config()
new_config.set_root(config.root_path)
for key, value in vars(opts).items():
if hasattr(new_config, key):
setattr(new_config, key, value)
new_config.write_file(exclude_defaults=True)
# Remove any fields that are not in the model_fields list, like `hf_token`
cleaned_opts = {k: v for k, v in vars(opts).items() if k in new_config.model_fields}
new_config.update_config(cleaned_opts)
new_config.write_file(init_file)
if hasattr(opts, "hf_token") and opts.hf_token:
HfLogin(opts.hf_token)
@ -819,12 +818,13 @@ def main() -> None:
help="path to root of install directory",
)
opt = parser.parse_args()
invoke_args: dict[str, Any] = {}
updates: dict[str, Any] = {}
if opt.root:
invoke_args["root"] = opt.root
config.set_root(Path(opt.root))
if opt.full_precision:
invoke_args["precision"] = "float32"
config.update_config(invoke_args)
updates["precision"] = "float32"
config.merge_from_file()
config.update_config(updates)
logger = InvokeAILogger().get_logger(config=config)
errors = set()

View File

@ -104,12 +104,12 @@ def test_write_config_to_file():
with TemporaryDirectory() as tmpdir:
temp_config_path = Path(tmpdir) / "invokeai.yaml"
config = InvokeAIAppConfig(host="192.168.1.1", port=8080)
config.set_root(Path(tmpdir))
config.write_file(exclude_defaults=False)
config.write_file(temp_config_path)
# Load the file and check contents
with open(temp_config_path, "r") as file:
content = file.read()
# This is a default value, so it should not be in the file
assert "pil_compress_level" not in content
assert "host: 192.168.1.1" in content
assert "port: 8080" in content
@ -179,7 +179,7 @@ def test_deny_nodes(patch_rootdir):
)
# must parse config before importing Graph, so its nodes union uses the config
conf = get_config()
conf.read_config(conf=allow_deny_nodes_conf, argv=[])
conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[])
from invokeai.app.services.shared.graph import Graph
# confirm graph validation fails when using denied node