fix(config): revised config methods

- `write_file` requires an destination file path
- `read_config` -> `merge_from_file`, if no path is provided, reads from `self.init_file_path`
- update app, tests to use new methods
- fix configurator, was overwriting config file data unexpectedly
This commit is contained in:
psychedelicious
2024-03-12 01:45:12 +11:00
parent 5e39e46954
commit f69938c6a8
4 changed files with 36 additions and 38 deletions

View File

@ -711,7 +711,8 @@ def run_console_ui(
) -> Tuple[Optional[Namespace], Optional[InstallSelections]]:
first_time = not (config.root_path / "invokeai.yaml").exists()
invokeai_opts = default_startup_options(initfile) if first_time else config
invokeai_opts.set_root(program_opts.root)
if program_opts.root:
invokeai_opts.set_root(Path(program_opts.root))
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
raise WindowTooSmallException(
@ -731,15 +732,13 @@ def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None:
"""
Update the invokeai.yaml file with values from current settings.
"""
# this will load current settings
new_config = get_config()
new_config.set_root(config.root_path)
for key, value in vars(opts).items():
if hasattr(new_config, key):
setattr(new_config, key, value)
new_config.write_file(exclude_defaults=True)
# Remove any fields that are not in the model_fields list, like `hf_token`
cleaned_opts = {k: v for k, v in vars(opts).items() if k in new_config.model_fields}
new_config.update_config(cleaned_opts)
new_config.write_file(init_file)
if hasattr(opts, "hf_token") and opts.hf_token:
HfLogin(opts.hf_token)
@ -819,12 +818,13 @@ def main() -> None:
help="path to root of install directory",
)
opt = parser.parse_args()
invoke_args: dict[str, Any] = {}
updates: dict[str, Any] = {}
if opt.root:
invoke_args["root"] = opt.root
config.set_root(Path(opt.root))
if opt.full_precision:
invoke_args["precision"] = "float32"
config.update_config(invoke_args)
updates["precision"] = "float32"
config.merge_from_file()
config.update_config(updates)
logger = InvokeAILogger().get_logger(config=config)
errors = set()