use packaging.version rather than version-parse

This commit is contained in:
Lincoln Stein 2024-04-18 21:54:17 -04:00
parent 6ad1948a44
commit 36495b730d
2 changed files with 11 additions and 22 deletions

View File

@ -7,8 +7,8 @@ Utility class for migrating among versions of the InvokeAI app config schema.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, List, TypeVar from typing import TYPE_CHECKING, Any, Callable, List, TypeVar
from packaging.version import Version
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, field_validator
from version_parser import Version
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass
@ -16,23 +16,13 @@ if TYPE_CHECKING:
AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any]) AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any])
class AppVersion(Version):
"""Stringlike object that sorts like a version."""
def __hash__(self) -> int: # noqa D105
return hash(str(self))
def __repr__(self) -> str: # noqa D105
return f"AppVersion('{str(self)}')"
class ConfigMigratorBase(ABC): class ConfigMigratorBase(ABC):
"""This class allows migrators to register their input and output versions.""" """This class allows migrators to register their input and output versions."""
@classmethod @classmethod
@abstractmethod @abstractmethod
def register( def register(
cls, from_version: AppVersion, to_version: AppVersion cls, from_version: Version, to_version: Version
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]: ) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
"""Define a decorator which registers the migration between two versions.""" """Define a decorator which registers the migration between two versions."""
@ -54,15 +44,15 @@ class MigrationEntry(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
from_version: AppVersion from_version: Version
to_version: AppVersion to_version: Version
function: Callable[[AppConfigDict], AppConfigDict] function: Callable[[AppConfigDict], AppConfigDict]
@field_validator("from_version", "to_version", mode="before") @field_validator("from_version", "to_version", mode="before")
@classmethod @classmethod
def _string_to_version(cls, v: str | AppVersion) -> AppVersion: # noqa D102 def _string_to_version(cls, v: str | Version) -> Version: # noqa D102
if isinstance(v, str): if isinstance(v, str):
return AppVersion(v) return Version(v)
else: else:
return v return v
@ -75,8 +65,8 @@ class ConfigMigrator(ConfigMigratorBase):
@classmethod @classmethod
def register( def register(
cls, cls,
from_version: AppVersion | str, from_version: Version | str,
to_version: AppVersion | str, to_version: Version | str,
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]: ) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
"""Define a decorator which registers the migration between two versions.""" """Define a decorator which registers the migration between two versions."""
@ -92,7 +82,7 @@ class ConfigMigrator(ConfigMigratorBase):
@staticmethod @staticmethod
def _check_for_overlaps(migrations: List[MigrationEntry]) -> None: def _check_for_overlaps(migrations: List[MigrationEntry]) -> None:
current_version = AppVersion("0.0.0") current_version = Version("0.0.0")
for m in migrations: for m in migrations:
if current_version > m.from_version: if current_version > m.from_version:
raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}") raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}")
@ -116,9 +106,9 @@ class ConfigMigrator(ConfigMigratorBase):
cls._check_for_overlaps(sorted_migrations) cls._check_for_overlaps(sorted_migrations)
if "InvokeAI" in config_dict: if "InvokeAI" in config_dict:
version = AppVersion("3.0.0") version = Version("3.0.0")
else: else:
version = AppVersion(config_dict["schema_version"]) version = Version(config_dict["schema_version"])
for migration in sorted_migrations: for migration in sorted_migrations:
if version >= migration.from_version and version < migration.to_version: if version >= migration.from_version and version < migration.to_version:

View File

@ -90,7 +90,6 @@ dependencies = [
"semver~=3.0.1", "semver~=3.0.1",
"send2trash", "send2trash",
"test-tube~=0.7.5", "test-tube~=0.7.5",
"version-parser",
"windows-curses; sys_platform=='win32'", "windows-curses; sys_platform=='win32'",
] ]