fix(config): do not write env vars to config files

Add class `DefaultInvokeAIAppConfig`, which inherits from `InvokeAIAppConfig`. When instantiated, this class does not parse environment variables, so it outputs a "clean" default config. That's the only difference.

Then, we can use this new class in the 3 places:
- When creating the example config file (no env vars should be here)
- When migrating a v3 config (we want to instantiate the migrated config without env vars, so that when we write it out, they are not written to disk)
- When creating a fresh config file (i.e. on first run with an uninitialized root or new config file path - no env vars here!)
This commit is contained in:
psychedelicious 2024-03-21 11:55:49 +11:00
parent d0a936ebd4
commit f538ed54fb

View File

@ -13,7 +13,7 @@ from typing import Any, Literal, Optional
import psutil
import yaml
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
import invokeai.configs as model_configs
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
@ -332,6 +332,27 @@ class InvokeAIAppConfig(BaseSettings):
return root
class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
"""A version of `InvokeAIAppConfig` that does not automatically parse any settings from environment variables
or any file.
This is useful for writing out a default config file.
Note that init settings are set if provided.
"""
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (init_settings,)
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to a current config object.
@ -367,7 +388,8 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
config = InvokeAIAppConfig.model_validate(parsed_config_dict)
# When migrating the config file, we should not include currently-set environment variables.
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
@ -391,14 +413,13 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
# This is a v3 config file, attempt to migrate it
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# This could be the wrong shape, but we will catch all exceptions below
config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
# By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set
config.write_file(config_path)
return config
migrated_config.write_file(config_path)
return migrated_config
else:
# Attempt to load as a v4 config file
try:
@ -426,6 +447,7 @@ def get_config() -> InvokeAIAppConfig:
On subsequent calls, the object is returned from the cache.
"""
# This object includes environment variables, as parsed by pydantic-settings
config = InvokeAIAppConfig()
args = InvokeAIArgs.args
@ -441,8 +463,8 @@ def get_config() -> InvokeAIAppConfig:
if config_file := getattr(args, "config_file", None):
config._config_file = Path(config_file)
# Create the example file from a deep copy, with some extra values provided
example_config = config.model_copy(deep=True)
# Create the example config file, with some extra example values provided
example_config = DefaultInvokeAIAppConfig()
example_config.remote_api_tokens = [
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
@ -454,10 +476,12 @@ def get_config() -> InvokeAIAppConfig:
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
if config.config_file_path.exists():
incoming_config = load_and_migrate_config(config.config_file_path)
config_from_file = load_and_migrate_config(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(incoming_config, clobber=False)
config.update_config(config_from_file, clobber=False)
else:
config.write_file(config.config_file_path)
# We should never write env vars to the config file
default_config = DefaultInvokeAIAppConfig()
default_config.write_file(config.config_file_path, as_example=False)
return config