fix(config): edge cases in models.yaml migration

When running the configurator, the `legacy_models_yaml_path` setting was stripped when saving the config file. As a result, the migration logic didn't fire correctly, and models listed in a custom models.yaml path weren't migrated into the db.

- Rework the logic to migrate this path by adding it to the config object as a normal field that is not excluded from serialization.
- Rearrange the models.yaml migration logic to remove the legacy path after migrating, then write the config file. This way, the legacy path doesn't stick around.
- Move the schema version into the config object.
- Back up the config file before attempting migration.
- Add tests to cover this edge case.
This commit is contained in:
psychedelicious 2024-03-15 23:21:21 +11:00
parent 1ed1c1fb24
commit e76cc71e81
5 changed files with 92 additions and 55 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import os import os
import re import re
import shutil
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
@ -46,12 +47,6 @@ class URLRegexTokenPair(BaseModel):
return v return v
class ConfigMeta(BaseModel):
"""Metadata for the config file. This is not stored in the config object."""
schema_version: int = CONFIG_SCHEMA_VERSION
class InvokeAIAppConfig(BaseSettings): class InvokeAIAppConfig(BaseSettings):
"""Invoke's global app configuration. """Invoke's global app configuration.
@ -109,6 +104,10 @@ class InvokeAIAppConfig(BaseSettings):
# fmt: off # fmt: off
# INTERNAL
schema_version: int = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.")
legacy_models_yaml_path: Optional[Path] = Field(default=None, description="Path to the legacy models.yaml file. This is not a user-configurable setting.")
# WEB # WEB
host: str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.") host: str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.")
port: int = Field(default=9090, description="Port to bind to.") port: int = Field(default=9090, description="Port to bind to.")
@ -175,11 +174,6 @@ class InvokeAIAppConfig(BaseSettings):
hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.") hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.")
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.") remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
# HIDDEN FIELDS
# v4 (MM2) doesn't use `models.yaml` files, but users were able to set paths in the v3 config. When we migrate a
# v3 config, we need to save the path to the models.yaml. This is only used during migration.
legacy_models_yaml_path: Optional[Path] = Field(default=None, description="The `conf_path` setting from a v3 `invokeai.yaml` file. Only present this app session migrated a config file, and it had `conf_test` on it.", exclude=True)
# fmt: on # fmt: on
model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True) model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True)
@ -217,8 +211,20 @@ class InvokeAIAppConfig(BaseSettings):
dest_path: Path to write the config to. dest_path: Path to write the config to.
""" """
with open(dest_path, "w") as file: with open(dest_path, "w") as file:
meta_dict = {"meta": ConfigMeta().model_dump()} # Meta fields should be written in a separate stanza
config_dict = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=True) meta_dict = self.model_dump(mode="json", include={"schema_version"})
# Only include the legacy_models_yaml_path if it's set
if self.legacy_models_yaml_path:
meta_dict.update(self.model_dump(mode="json", include={"legacy_models_yaml_path"}))
# User settings
config_dict = self.model_dump(
mode="json",
exclude_unset=True,
exclude_defaults=True,
exclude={"schema_version", "legacy_models_yaml_path"},
)
file.write("# Internal metadata - do not edit:\n") file.write("# Internal metadata - do not edit:\n")
file.write(yaml.dump(meta_dict, sort_keys=False)) file.write(yaml.dump(meta_dict, sort_keys=False))
file.write("\n") file.write("\n")
@ -370,11 +376,12 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
if "InvokeAI" in loaded_config_dict: if "InvokeAI" in loaded_config_dict:
# This is a v3 config file, attempt to migrate it # This is a v3 config file, attempt to migrate it
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try: try:
config = migrate_v3_config_dict(loaded_config_dict) config = migrate_v3_config_dict(loaded_config_dict)
except Exception as e: except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
config_path.rename(config_path.with_suffix(".yaml.bak"))
# By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set # By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set
config.write_file(config_path) config.write_file(config_path)
return config return config
@ -382,11 +389,11 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
# Attempt to load as a v4 config file # Attempt to load as a v4 config file
try: try:
# Meta is not included in the model fields, so we need to validate it separately # Meta is not included in the model fields, so we need to validate it separately
config_meta = ConfigMeta.model_validate(loaded_config_dict.pop("meta")) config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert ( assert (
config_meta.schema_version == CONFIG_SCHEMA_VERSION config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config_meta.schema_version}" ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return InvokeAIAppConfig.model_validate(loaded_config_dict) return config
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e

View File

@ -292,45 +292,46 @@ class ModelInstallService(ModelInstallServiceBase):
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml" self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
) )
if not legacy_models_yaml_path.exists(): if legacy_models_yaml_path.exists():
# No yaml to migrate legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
return
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text()) yaml_metadata = legacy_models_yaml.pop("__metadata__")
yaml_version = yaml_metadata.get("version")
yaml_metadata = legacy_models_yaml.pop("__metadata__") if yaml_version != "3.0.0":
yaml_version = yaml_metadata.get("version") raise ValueError(
f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
)
if yaml_version != "3.0.0": self._logger.info(
raise ValueError( f"Starting one-time migration of {len(legacy_models_yaml.items())} models from {str(legacy_models_yaml_path)}. This may take a few minutes."
f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
) )
self._logger.info( if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0:
f"Starting one-time migration of {len(legacy_models_yaml.items())} models from {str(legacy_models_yaml_path)}. This may take a few minutes." for model_key, stanza in legacy_models_yaml.items():
) _, _, model_name = str(model_key).split("/")
model_path = Path(stanza["path"])
if not model_path.is_absolute():
model_path = self._app_config.models_path / model_path
model_path = model_path.resolve()
if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0: config: dict[str, Any] = {}
for model_key, stanza in legacy_models_yaml.items(): config["name"] = model_name
_, _, model_name = str(model_key).split("/") config["description"] = stanza.get("description")
model_path = Path(stanza["path"]) config["config_path"] = stanza.get("config")
if not model_path.is_absolute():
model_path = self._app_config.models_path / model_path
model_path = model_path.resolve()
config: dict[str, Any] = {} try:
config["name"] = model_name id = self.register_path(model_path=model_path, config=config)
config["description"] = stanza.get("description") self._logger.info(f"Migrated {model_name} with id {id}")
config["config_path"] = stanza.get("config") except Exception as e:
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
try: # Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
id = self.register_path(model_path=model_path, config=config) legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
self._logger.info(f"Migrated {model_name} with id {id}")
except Exception as e:
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration # Remove `legacy_models_yaml_path` from the config file - we are done with it either way
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak")) self._app_config.legacy_models_yaml_path = None
self._app_config.write_file(self._app_config.init_file_path)
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()} self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}

View File

@ -34,6 +34,7 @@ from transformers import AutoFeatureExtractor
import invokeai.configs as model_configs import invokeai.configs as model_configs
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
from invokeai.backend.model_manager import ModelType from invokeai.backend.model_manager import ModelType
from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util import choose_precision, choose_torch_device
@ -63,8 +64,7 @@ def get_literal_fields(field: str) -> Tuple[Any]:
# --------------------------globals----------------------- # --------------------------globals-----------------------
# Start from a fresh config object - we will read the user's config from file later, and update it with their choices config = get_config()
config = InvokeAIAppConfig()
PRECISION_CHOICES = get_literal_fields("precision") PRECISION_CHOICES = get_literal_fields("precision")
DEVICE_CHOICES = get_literal_fields("device") DEVICE_CHOICES = get_literal_fields("device")

View File

@ -3,6 +3,8 @@ from typing import Literal, get_args, get_type_hints
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
_excluded = {"schema_version", "legacy_models_yaml_path"}
def generate_config_docstrings() -> str: def generate_config_docstrings() -> str:
"""Helper function for mkdocs. Generates a docstring for the InvokeAIAppConfig class. """Helper function for mkdocs. Generates a docstring for the InvokeAIAppConfig class.
@ -20,7 +22,7 @@ def generate_config_docstrings() -> str:
type_hints = get_type_hints(InvokeAIAppConfig) type_hints = get_type_hints(InvokeAIAppConfig)
for k, v in InvokeAIAppConfig.model_fields.items(): for k, v in InvokeAIAppConfig.model_fields.items():
if v.exclude: if v.exclude or k in _excluded:
continue continue
field_type = type_hints.get(k) field_type = type_hints.get(k)
extra = "" extra = ""

View File

@ -9,16 +9,14 @@ from pydantic import ValidationError
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config
v4_config = """ v4_config = """
meta: schema_version: 4
schema_version: 4
host: "192.168.1.1" host: "192.168.1.1"
port: 8080 port: 8080
""" """
invalid_v5_config = """ invalid_v5_config = """
meta: schema_version: 5
schema_version: 5
host: "192.168.1.1" host: "192.168.1.1"
port: 8080 port: 8080
@ -44,6 +42,12 @@ InvokeAI:
max_vram_cache_size: 50 max_vram_cache_size: 50
""" """
v3_config_with_bad_values = """
InvokeAI:
Web Server:
port: "ice cream"
"""
invalid_config = """ invalid_config = """
i like turtles i like turtles
""" """
@ -88,6 +92,29 @@ def test_migrate_v3_config_from_file(tmp_path: Path):
assert not hasattr(config, "esrgan") assert not hasattr(config, "esrgan")
def test_migrate_v3_backup(tmp_path: Path):
"""Test the backup of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config)
load_and_migrate_config(temp_config_file)
assert temp_config_file.with_suffix(".yaml.bak").exists()
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config
def test_failed_migrate_backup(tmp_path: Path):
"""Test the failed migration of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config_with_bad_values)
with pytest.raises(RuntimeError):
load_and_migrate_config(temp_config_file)
assert temp_config_file.with_suffix(".yaml.bak").exists()
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config_with_bad_values
assert temp_config_file.exists()
assert temp_config_file.read_text() == v3_config_with_bad_values
def test_bails_on_invalid_config(tmp_path: Path): def test_bails_on_invalid_config(tmp_path: Path):
"""Test reading configuration from a file.""" """Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file = tmp_path / "temp_invokeai.yaml"