mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use packaging.version rather than version-parse
This commit is contained in:
parent
6ad1948a44
commit
36495b730d
@ -7,8 +7,8 @@ Utility class for migrating among versions of the InvokeAI app config schema.
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, Callable, List, TypeVar
|
from typing import TYPE_CHECKING, Any, Callable, List, TypeVar
|
||||||
|
|
||||||
|
from packaging.version import Version
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
from version_parser import Version
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
pass
|
pass
|
||||||
@ -16,23 +16,13 @@ if TYPE_CHECKING:
|
|||||||
AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any])
|
AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any])
|
||||||
|
|
||||||
|
|
||||||
class AppVersion(Version):
|
|
||||||
"""Stringlike object that sorts like a version."""
|
|
||||||
|
|
||||||
def __hash__(self) -> int: # noqa D105
|
|
||||||
return hash(str(self))
|
|
||||||
|
|
||||||
def __repr__(self) -> str: # noqa D105
|
|
||||||
return f"AppVersion('{str(self)}')"
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigMigratorBase(ABC):
|
class ConfigMigratorBase(ABC):
|
||||||
"""This class allows migrators to register their input and output versions."""
|
"""This class allows migrators to register their input and output versions."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def register(
|
def register(
|
||||||
cls, from_version: AppVersion, to_version: AppVersion
|
cls, from_version: Version, to_version: Version
|
||||||
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
|
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
|
||||||
"""Define a decorator which registers the migration between two versions."""
|
"""Define a decorator which registers the migration between two versions."""
|
||||||
|
|
||||||
@ -54,15 +44,15 @@ class MigrationEntry(BaseModel):
|
|||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
from_version: AppVersion
|
from_version: Version
|
||||||
to_version: AppVersion
|
to_version: Version
|
||||||
function: Callable[[AppConfigDict], AppConfigDict]
|
function: Callable[[AppConfigDict], AppConfigDict]
|
||||||
|
|
||||||
@field_validator("from_version", "to_version", mode="before")
|
@field_validator("from_version", "to_version", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _string_to_version(cls, v: str | AppVersion) -> AppVersion: # noqa D102
|
def _string_to_version(cls, v: str | Version) -> Version: # noqa D102
|
||||||
if isinstance(v, str):
|
if isinstance(v, str):
|
||||||
return AppVersion(v)
|
return Version(v)
|
||||||
else:
|
else:
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@ -75,8 +65,8 @@ class ConfigMigrator(ConfigMigratorBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def register(
|
def register(
|
||||||
cls,
|
cls,
|
||||||
from_version: AppVersion | str,
|
from_version: Version | str,
|
||||||
to_version: AppVersion | str,
|
to_version: Version | str,
|
||||||
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
|
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
|
||||||
"""Define a decorator which registers the migration between two versions."""
|
"""Define a decorator which registers the migration between two versions."""
|
||||||
|
|
||||||
@ -92,7 +82,7 @@ class ConfigMigrator(ConfigMigratorBase):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_for_overlaps(migrations: List[MigrationEntry]) -> None:
|
def _check_for_overlaps(migrations: List[MigrationEntry]) -> None:
|
||||||
current_version = AppVersion("0.0.0")
|
current_version = Version("0.0.0")
|
||||||
for m in migrations:
|
for m in migrations:
|
||||||
if current_version > m.from_version:
|
if current_version > m.from_version:
|
||||||
raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}")
|
raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}")
|
||||||
@ -116,9 +106,9 @@ class ConfigMigrator(ConfigMigratorBase):
|
|||||||
cls._check_for_overlaps(sorted_migrations)
|
cls._check_for_overlaps(sorted_migrations)
|
||||||
|
|
||||||
if "InvokeAI" in config_dict:
|
if "InvokeAI" in config_dict:
|
||||||
version = AppVersion("3.0.0")
|
version = Version("3.0.0")
|
||||||
else:
|
else:
|
||||||
version = AppVersion(config_dict["schema_version"])
|
version = Version(config_dict["schema_version"])
|
||||||
|
|
||||||
for migration in sorted_migrations:
|
for migration in sorted_migrations:
|
||||||
if version >= migration.from_version and version < migration.to_version:
|
if version >= migration.from_version and version < migration.to_version:
|
||||||
|
@ -90,7 +90,6 @@ dependencies = [
|
|||||||
"semver~=3.0.1",
|
"semver~=3.0.1",
|
||||||
"send2trash",
|
"send2trash",
|
||||||
"test-tube~=0.7.5",
|
"test-tube~=0.7.5",
|
||||||
"version-parser",
|
|
||||||
"windows-curses; sys_platform=='win32'",
|
"windows-curses; sys_platform=='win32'",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user