2024-04-19 01:33:54 +00:00
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Utility class for migrating among versions of the InvokeAI app config schema .
"""
2024-04-23 07:19:54 +00:00
from dataclasses import dataclass
2024-04-23 07:12:19 +00:00
from typing import Any , Callable , List , TypeAlias
2024-04-19 01:33:54 +00:00
2024-04-19 01:54:17 +00:00
from packaging . version import Version
2024-04-19 01:33:54 +00:00
2024-04-23 07:12:19 +00:00
AppConfigDict : TypeAlias = dict [ str , Any ]
2024-04-23 07:21:05 +00:00
MigrationFunction : TypeAlias = Callable [ [ AppConfigDict ] , AppConfigDict ]
2024-04-19 01:33:54 +00:00
2024-04-23 07:19:54 +00:00
@dataclass
class MigrationEntry:
    """A single registered migration step between two config-schema versions.

    Applying ``function`` to a config at ``from_version`` is expected to yield
    a config at ``to_version``.
    """

    # Schema version this step accepts as input.
    from_version: Version
    # Schema version this step produces.
    to_version: Version
    # The callable that transforms the config dict.
    function: MigrationFunction
2024-04-19 01:33:54 +00:00
2024-04-23 07:11:13 +00:00
class ConfigMigrator:
    """Registry of app-config schema migrations and the driver that applies them.

    Migration functions register themselves with the ``register`` decorator,
    declaring the schema version they accept and the version they produce.
    ``migrate`` then chains the registered steps to lift a config dict up to
    the newest reachable schema version.
    """

    # Class-level registry shared by all callers; populated by ``register``.
    _migrations: List[MigrationEntry] = []

    @classmethod
    def register(
        cls,
        from_version: str,
        to_version: str,
    ) -> Callable[[MigrationFunction], MigrationFunction]:
        """Return a decorator that registers a migration between two versions.

        :param from_version: Schema version the decorated function accepts.
        :param to_version: Schema version the decorated function produces.
        :return: A decorator that records the function and returns it unchanged.
        :raises ValueError: (when the decorator is applied) if a migration
            starting at ``from_version`` has already been registered.
        """
        # Parse once up front. The registry stores ``Version`` objects, so the
        # duplicate check must compare ``Version`` to ``Version``. A ``Version``
        # never compares equal to a plain string, so comparing the raw string
        # silently let duplicate registrations through.
        parsed_from = Version(from_version)
        parsed_to = Version(to_version)

        def decorator(function: MigrationFunction) -> MigrationFunction:
            if any(parsed_from == m.from_version for m in cls._migrations):
                raise ValueError(
                    f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
                )
            cls._migrations.append(MigrationEntry(from_version=parsed_from, to_version=parsed_to, function=function))
            return function

        return decorator

    @staticmethod
    def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None:
        """Verify that ``migrations`` form one unbroken chain starting at 3.0.0.

        :param migrations: Migration entries, expected to be sorted by ``from_version``.
        :raises ValueError: if an entry's ``from_version`` differs from the
            previous entry's ``to_version`` (or from 3.0.0 for the first entry).
        """
        current_version = Version("3.0.0")
        for m in migrations:
            if current_version != m.from_version:
                raise ValueError(
                    f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}"
                )
            current_version = m.to_version

    @classmethod
    def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
        """Apply the registered migration steps to bring a config up to the latest version.

        A config containing a top-level ``"InvokeAI"`` key is treated as schema
        version 3.0.0. Otherwise its version is read from ``"schema_version"``.
        Each registered step whose ``from_version`` matches the config's current
        version is applied in order. The final version is then stored under
        ``"schema_version"``.

        :param config_dict: The original configuration.
        :return: The configuration lifted to the latest reachable version. The
            ``"schema_version"`` key is assigned in place on the returned dict.
            When no migration step applies, that dict is ``config_dict`` itself.
        :raises ValueError: if the registered migrations' ``from_version`` /
            ``to_version`` pairs do not form a continuous chain.
        :raises KeyError: if ``config_dict`` has neither an ``"InvokeAI"`` key
            nor a ``"schema_version"`` key.
        """
        # Sort migrations by version number and raise a ValueError if there is
        # a gap between one step's to_version and the next step's from_version.
        sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
        cls._check_for_discontinuities(sorted_migrations)

        if "InvokeAI" in config_dict:
            # Legacy layout with a top-level "InvokeAI" section.
            version = Version("3.0.0")
        else:
            version = Version(config_dict["schema_version"])

        for migration in sorted_migrations:
            # The ``<`` guard skips any step that would not move the version forward.
            if version == migration.from_version and version < migration.to_version:
                config_dict = migration.function(config_dict)
                version = migration.to_version

        config_dict["schema_version"] = str(version)
        return config_dict