diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 5a17898bcb..9fae89af79 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -5,6 +5,7 @@ from __future__ import annotations import os import re +import shutil from functools import lru_cache from pathlib import Path from typing import Any, Literal, Optional @@ -46,12 +47,6 @@ class URLRegexTokenPair(BaseModel): return v -class ConfigMeta(BaseModel): - """Metadata for the config file. This is not stored in the config object.""" - - schema_version: int = CONFIG_SCHEMA_VERSION - - class InvokeAIAppConfig(BaseSettings): """Invoke's global app configuration. @@ -109,6 +104,10 @@ class InvokeAIAppConfig(BaseSettings): # fmt: off + # INTERNAL + schema_version: int = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.") + legacy_models_yaml_path: Optional[Path] = Field(default=None, description="Path to the legacy models.yaml file. This is not a user-configurable setting.") + # WEB host: str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.") port: int = Field(default=9090, description="Port to bind to.") @@ -175,11 +174,6 @@ class InvokeAIAppConfig(BaseSettings): hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.") remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.") - # HIDDEN FIELDS - # v4 (MM2) doesn't use `models.yaml` files, but users were able to set paths in the v3 config. When we migrate a - # v3 config, we need to save the path to the models.yaml. This is only used during migration. - legacy_models_yaml_path: Optional[Path] = Field(default=None, description="The `conf_path` setting from a v3 `invokeai.yaml` file. Only present this app session migrated a config file, and it had `conf_test` on it.", exclude=True) - # fmt: on model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True) @@ -217,8 +211,20 @@ class InvokeAIAppConfig(BaseSettings): dest_path: Path to write the config to. """ with open(dest_path, "w") as file: - meta_dict = {"meta": ConfigMeta().model_dump()} - config_dict = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=True) + # Meta fields should be written in a separate stanza + meta_dict = self.model_dump(mode="json", include={"schema_version"}) + # Only include the legacy_models_yaml_path if it's set + if self.legacy_models_yaml_path: + meta_dict.update(self.model_dump(mode="json", include={"legacy_models_yaml_path"})) + + # User settings + config_dict = self.model_dump( + mode="json", + exclude_unset=True, + exclude_defaults=True, + exclude={"schema_version", "legacy_models_yaml_path"}, + ) + file.write("# Internal metadata - do not edit:\n") file.write(yaml.dump(meta_dict, sort_keys=False)) file.write("\n") @@ -370,11 +376,12 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: if "InvokeAI" in loaded_config_dict: # This is a v3 config file, attempt to migrate it + shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) try: config = migrate_v3_config_dict(loaded_config_dict) except Exception as e: + shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e - config_path.rename(config_path.with_suffix(".yaml.bak")) # By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set config.write_file(config_path) return config @@ -382,11 +389,11 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: # Attempt to load as a v4 config file try: # Meta is not included in the model fields, so we need to validate it separately - config_meta = ConfigMeta.model_validate(loaded_config_dict.pop("meta")) + config = InvokeAIAppConfig.model_validate(loaded_config_dict) assert ( - config_meta.schema_version == CONFIG_SCHEMA_VERSION - ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config_meta.schema_version}" - return InvokeAIAppConfig.model_validate(loaded_config_dict) + config.schema_version == CONFIG_SCHEMA_VERSION + ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}" + return config except Exception as e: raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 807c18bcff..6330cc0969 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -292,45 +292,46 @@ class ModelInstallService(ModelInstallServiceBase): self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml" ) - if not legacy_models_yaml_path.exists(): - # No yaml to migrate - return + if legacy_models_yaml_path.exists(): + legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text()) - legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text()) + yaml_metadata = legacy_models_yaml.pop("__metadata__") + yaml_version = yaml_metadata.get("version") - yaml_metadata = legacy_models_yaml.pop("__metadata__") - yaml_version = yaml_metadata.get("version") + if yaml_version != "3.0.0": + raise ValueError( + f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting." + ) - if yaml_version != "3.0.0": - raise ValueError( - f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting." + self._logger.info( + f"Starting one-time migration of {len(legacy_models_yaml.items())} models from {str(legacy_models_yaml_path)}. This may take a few minutes." ) - self._logger.info( - f"Starting one-time migration of {len(legacy_models_yaml.items())} models from {str(legacy_models_yaml_path)}. This may take a few minutes." - ) + if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0: + for model_key, stanza in legacy_models_yaml.items(): + _, _, model_name = str(model_key).split("/") + model_path = Path(stanza["path"]) + if not model_path.is_absolute(): + model_path = self._app_config.models_path / model_path + model_path = model_path.resolve() - if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0: - for model_key, stanza in legacy_models_yaml.items(): - _, _, model_name = str(model_key).split("/") - model_path = Path(stanza["path"]) - if not model_path.is_absolute(): - model_path = self._app_config.models_path / model_path - model_path = model_path.resolve() + config: dict[str, Any] = {} + config["name"] = model_name + config["description"] = stanza.get("description") + config["config_path"] = stanza.get("config") - config: dict[str, Any] = {} - config["name"] = model_name - config["description"] = stanza.get("description") - config["config_path"] = stanza.get("config") + try: + id = self.register_path(model_path=model_path, config=config) + self._logger.info(f"Migrated {model_name} with id {id}") + except Exception as e: + self._logger.warning(f"Model at {model_path} could not be migrated: {e}") - try: - id = self.register_path(model_path=model_path, config=config) - self._logger.info(f"Migrated {model_name} with id {id}") - except Exception as e: - self._logger.warning(f"Model at {model_path} could not be migrated: {e}") + # Rename `models.yaml` to `models.yaml.bak` to prevent re-migration + legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak")) - # Rename `models.yaml` to `models.yaml.bak` to prevent re-migration - legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak")) + # Remove `legacy_models_yaml_path` from the config file - we are done with it either way + self._app_config.legacy_models_yaml_path = None + self._app_config.write_file(self._app_config.init_file_path) def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()} diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index ed6b6512de..a863a98a8a 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -34,6 +34,7 @@ from transformers import AutoFeatureExtractor import invokeai.configs as model_configs from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.config.config_default import get_config from invokeai.backend.install.install_helper import InstallHelper, InstallSelections from invokeai.backend.model_manager import ModelType from invokeai.backend.util import choose_precision, choose_torch_device @@ -63,8 +64,7 @@ def get_literal_fields(field: str) -> Tuple[Any]: # --------------------------globals----------------------- -# Start from a fresh config object - we will read the user's config from file later, and update it with their choices -config = InvokeAIAppConfig() +config = get_config() PRECISION_CHOICES = get_literal_fields("precision") DEVICE_CHOICES = get_literal_fields("device") diff --git a/scripts/update_config_docstring.py b/scripts/update_config_docstring.py index da9bad2734..081d03d62b 100644 --- a/scripts/update_config_docstring.py +++ b/scripts/update_config_docstring.py @@ -3,6 +3,8 @@ from typing import Literal, get_args, get_type_hints from invokeai.app.services.config.config_default import InvokeAIAppConfig +_excluded = {"schema_version", "legacy_models_yaml_path"} + def generate_config_docstrings() -> str: """Helper function for mkdocs. Generates a docstring for the InvokeAIAppConfig class. @@ -20,7 +22,7 @@ def generate_config_docstrings() -> str: type_hints = get_type_hints(InvokeAIAppConfig) for k, v in InvokeAIAppConfig.model_fields.items(): - if v.exclude: + if v.exclude or k in _excluded: continue field_type = type_hints.get(k) extra = "" diff --git a/tests/test_config.py b/tests/test_config.py index b74e26debe..617e28785d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,16 +9,14 @@ from pydantic import ValidationError from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config v4_config = """ -meta: - schema_version: 4 +schema_version: 4 host: "192.168.1.1" port: 8080 """ invalid_v5_config = """ -meta: - schema_version: 5 +schema_version: 5 host: "192.168.1.1" port: 8080 @@ -44,6 +42,12 @@ InvokeAI: max_vram_cache_size: 50 """ +v3_config_with_bad_values = """ +InvokeAI: + Web Server: + port: "ice cream" +""" + invalid_config = """ i like turtles """ @@ -88,6 +92,29 @@ def test_migrate_v3_config_from_file(tmp_path: Path): assert not hasattr(config, "esrgan") +def test_migrate_v3_backup(tmp_path: Path): + """Test the backup of the config file.""" + temp_config_file = tmp_path / "temp_invokeai.yaml" + temp_config_file.write_text(v3_config) + + load_and_migrate_config(temp_config_file) + assert temp_config_file.with_suffix(".yaml.bak").exists() + assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config + + +def test_failed_migrate_backup(tmp_path: Path): + """Test the failed migration of the config file.""" + temp_config_file = tmp_path / "temp_invokeai.yaml" + temp_config_file.write_text(v3_config_with_bad_values) + + with pytest.raises(RuntimeError): + load_and_migrate_config(temp_config_file) + assert temp_config_file.with_suffix(".yaml.bak").exists() + assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config_with_bad_values + assert temp_config_file.exists() + assert temp_config_file.read_text() == v3_config_with_bad_values + + def test_bails_on_invalid_config(tmp_path: Path): """Test reading configuration from a file.""" temp_config_file = tmp_path / "temp_invokeai.yaml"