make config migrator into an instance; refactor location of get_config()

This commit is contained in:
Lincoln Stein 2024-05-02 23:45:34 -04:00
parent d5aee87684
commit a48abfacf4
21 changed files with 314 additions and 219 deletions

View File

@ -26,7 +26,7 @@ import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice

View File

@ -3,7 +3,7 @@ import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
custom_nodes_path = Path(get_config().custom_nodes_path)
custom_nodes_path.mkdir(parents=True, exist_ok=True)

View File

@ -33,7 +33,7 @@ from invokeai.app.invocations.fields import (
FieldKind,
Input,
)
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string

View File

@ -2,6 +2,7 @@
from invokeai.app.services.config.config_common import PagingArgumentParser
from .config_default import InvokeAIAppConfig, get_config
from .config_default import InvokeAIAppConfig
from .config_migrate import get_config
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"]

View File

@ -3,11 +3,8 @@
from __future__ import annotations
import locale
import os
import re
import shutil
from functools import lru_cache
from pathlib import Path
from typing import Any, Literal, Optional
@ -16,11 +13,7 @@ import yaml
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
import invokeai.configs as model_configs
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_migrate import ConfigMigrator
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
@ -348,162 +341,3 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (init_settings,)
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
Args:
config_path: Path to the config file.
Returns:
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
# Attempt to load as a v4 config file
try:
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
"""Get the global singleton app config.
When first called, this function:
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
- Sets the root dir, if provided via CLI args.
- Logs in to HF if there is no valid token already.
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
On subsequent calls, the object is returned from the cache.
"""
# This object includes environment variables, as parsed by pydantic-settings
config = InvokeAIAppConfig()
args = InvokeAIArgs.args
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
if not InvokeAIArgs.did_parse:
return config
# Set CLI args
if root := getattr(args, "root", None):
config._root = Path(root)
if config_file := getattr(args, "config_file", None):
config._config_file = Path(config_file)
# Create the example config file, with some extra example values provided
example_config = DefaultInvokeAIAppConfig()
example_config.remote_api_tokens = [
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
]
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
# Copy all legacy configs - We know `__path__[0]` is correct here
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
if config.config_file_path.exists():
config_from_file = load_and_migrate_config(config.config_file_path)
config_from_file.write_file(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(config_from_file, clobber=False)
else:
# We should never write env vars to the config file
default_config = DefaultInvokeAIAppConfig()
default_config.write_file(config.config_file_path, as_example=False)
return config
####################################################
# VERSION MIGRATIONS
####################################################
@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0")
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
A dictionary of settings from a 4.0.0 config file.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
return parsed_config_dict
@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1")
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to v4.0.1.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
A dictionary of settings from a v4.0.1 config file
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
return parsed_config_dict

View File

@ -4,12 +4,22 @@
Utility class for migrating among versions of the InvokeAI app config schema.
"""
import locale
import shutil
from dataclasses import dataclass
from typing import Any, Callable, List, TypeAlias
from functools import lru_cache
from pathlib import Path
from typing import Callable, List, TypeAlias
import yaml
from packaging.version import Version
AppConfigDict: TypeAlias = dict[str, Any]
import invokeai.configs as model_configs
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
from .migrations import AppConfigDict, Migrations, MigrationsBase
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@ -25,22 +35,23 @@ class MigrationEntry:
class ConfigMigrator:
"""This class allows migrators to register their input and output versions."""
_migrations: List[MigrationEntry] = []
def __init__(self, migrations: type[MigrationsBase]) -> None:
self._migrations: List[MigrationEntry] = []
migrations.load(self)
@classmethod
def register(
cls,
self,
from_version: str,
to_version: str,
) -> Callable[[MigrationFunction], MigrationFunction]:
"""Define a decorator which registers the migration between two versions."""
def decorator(function: MigrationFunction) -> MigrationFunction:
if any(from_version == m.from_version for m in cls._migrations):
if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations):
raise ValueError(
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
)
cls._migrations.append(
self._migrations.append(
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
)
return function
@ -57,8 +68,7 @@ class ConfigMigrator:
)
current_version = m.to_version
@classmethod
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict:
"""
Use the registered migration steps to bring config up to latest version.
@ -72,8 +82,8 @@ class ConfigMigrator:
"""
# Sort migrations by version number and raise a ValueError if
# any version range overlaps are detected.
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
cls._check_for_discontinuities(sorted_migrations)
sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
self._check_for_discontinuities(sorted_migrations)
if "InvokeAI" in config_dict:
version = Version("3.0.0")
@ -87,3 +97,92 @@ class ConfigMigrator:
config_dict["schema_version"] = str(version)
return config_dict
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
"""Get the global singleton app config.
When first called, this function:
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
- Sets the root dir, if provided via CLI args.
- Logs in to HF if there is no valid token already.
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
On subsequent calls, the object is returned from the cache.
"""
# This object includes environment variables, as parsed by pydantic-settings
config = InvokeAIAppConfig()
args = InvokeAIArgs.args
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
if not InvokeAIArgs.did_parse:
return config
# Set CLI args
if root := getattr(args, "root", None):
config._root = Path(root)
if config_file := getattr(args, "config_file", None):
config._config_file = Path(config_file)
# Create the example config file, with some extra example values provided
example_config = DefaultInvokeAIAppConfig()
example_config.remote_api_tokens = [
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
]
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
# Copy all legacy configs - We know `__path__[0]` is correct here
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
if config.config_file_path.exists():
config_from_file = load_and_migrate_config(config.config_file_path)
config_from_file.write_file(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(config_from_file, clobber=False)
else:
# We should never write env vars to the config file
default_config = DefaultInvokeAIAppConfig()
default_config.write_file(config.config_file_path, as_example=False)
return config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
Args:
config_path: Path to the config file.
Returns:
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
migrator = ConfigMigrator(Migrations)
migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
# Attempt to load as a v4 config file
try:
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e

View File

@ -0,0 +1,108 @@
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Schema migrations to perform on an InvokeAIAppConfig object.
The Migrations class defined in this module defines a series of
schema version migration steps for the InvokeAIConfig object.
To define a new migration, add a migration function to
Migrations.load_migrations() following the existing examples.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias
from .config_default import InvokeAIAppConfig
if TYPE_CHECKING:
from .config_migrate import ConfigMigrator
AppConfigDict: TypeAlias = dict[str, Any]
class MigrationsBase(ABC):
"""Define the config file migration steps to apply, abstract base class."""
@classmethod
@abstractmethod
def load(cls, migrator: "ConfigMigrator") -> None:
"""Use the provided migrator to register the configuration migrations to be run."""
class Migrations(MigrationsBase):
"""Configuration migration steps to apply."""
@classmethod
def load(cls, migrator: "ConfigMigrator") -> None:
"""Define migrations to perform."""
##################
# 3.0.0 -> 4.0.0 #
##################
@migrator.register(from_version="3.0.0", to_version="4.0.0")
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
A dictionary of settings from a 4.0.0 config file.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
return parsed_config_dict
##################
# 4.0.0 -> 4.0.1 #
##################
@migrator.register(from_version="4.0.0", to_version="4.0.1")
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to v4.0.1.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
A dictionary of settings from a v4.0.1 config file
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
return parsed_config_dict

View File

@ -9,7 +9,7 @@ from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices

View File

@ -1,6 +1,6 @@
from logging import Logger
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1

View File

@ -1,7 +1,7 @@
import sqlite3
from pathlib import Path
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration

View File

@ -9,7 +9,7 @@ from einops import repeat
from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize

View File

@ -5,7 +5,7 @@
import numpy as np
import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice

View File

@ -6,7 +6,7 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice

View File

@ -9,7 +9,7 @@ import numpy as np
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
class PatchMatch:

View File

@ -10,7 +10,7 @@ from imwatermark import WatermarkEncoder
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
config = get_config()

View File

@ -12,7 +12,7 @@ from PIL import Image
from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings

View File

@ -20,7 +20,7 @@ from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData

View File

@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, Union
import torch
from typing_extensions import TypeAlias
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
IPAdapterData,
Range,

View File

@ -3,7 +3,7 @@ from typing import Dict, Literal, Optional, Union
import torch
from deprecated import deprecated
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]

View File

@ -180,8 +180,7 @@ import urllib.parse
from pathlib import Path
from typing import Any, Dict, Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import InvokeAIAppConfig, get_config
try:
import syslog

View File

@ -12,10 +12,9 @@ from invokeai.app.services.config.config_default import (
CONFIG_SCHEMA_VERSION,
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
get_config,
load_and_migrate_config,
)
from invokeai.app.services.config.config_migrate import ConfigMigrator
from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
from invokeai.app.services.config.migrations import Migrations, MigrationsBase
from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -72,6 +71,68 @@ i like turtles
"""
class GoodMigrations(MigrationsBase):
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.1", to_version="10.0.2")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
class BadMigrations1(MigrationsBase):
"""This one fails because there is no path from 10.0.1 to 10.0.2"""
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.2", to_version="10.0.3")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
class BadMigrations2(MigrationsBase):
"""This one fails because the path to 10.0.2 is registered twice"""
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.1", to_version="10.0.2")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.2")
def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@pytest.fixture
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
"""This may be overkill since the current tests don't need the root dir to exist"""
@ -291,33 +352,26 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
def test_migration_check() -> None:
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
# Test the default set of migrations
migrator = ConfigMigrator(Migrations)
new_config = migrator.run_migrations({"schema_version": "4.0.0"})
assert new_config is not None
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
# Does this execute at compile time or run time?
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1")
def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
# Test a custom set of migrations
migrator = ConfigMigrator(GoodMigrations)
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
assert new_config["schema_version"] == "10.0.2"
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1"
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3")
def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
# Because there is no version for "*.1" => "*.2", this should fail.
# Test a migration that should fail validation
migrator = ConfigMigrator(BadMigrations1)
with pytest.raises(ValueError):
ConfigMigrator.migrate({"schema_version": "4.0.0"})
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2")
def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
# should work now, because there is a continuous path to *.3
new_config = ConfigMigrator.migrate(new_config)
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3"
# Test another bad migration
migrator = ConfigMigrator(BadMigrations2)
with pytest.raises(ValueError):
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
@contextmanager