mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): move sqlite_migrator into its own module
This commit is contained in:
parent
fa7d002175
commit
290851016e
@ -4,10 +4,10 @@ from functools import partial
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
|
||||
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
|
||||
from invokeai.app.services.shared.sqlite.migrations.migration_2_post import migrate_embedded_workflows
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import SQLiteMigrator
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2_post import migrate_embedded_workflows
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
@ -1,6 +1,6 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
@ -0,0 +1,109 @@
|
||||
import sqlite3
|
||||
from typing import Callable, Optional, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
"""Raised when a migration fails."""
|
||||
|
||||
|
||||
class MigrationVersionError(ValueError):
|
||||
"""Raised when a migration version is invalid."""
|
||||
|
||||
|
||||
class Migration(BaseModel):
|
||||
"""
|
||||
Represents a migration for a SQLite database.
|
||||
|
||||
Migration callbacks will be provided an open cursor to the database. They should not commit their
|
||||
transaction; this is handled by the migrator.
|
||||
|
||||
Pre- and post-migration callback may be registered with :meth:`register_pre_callback` or
|
||||
:meth:`register_post_callback`.
|
||||
|
||||
If a migration has additional dependencies, it is recommended to use functools.partial to provide
|
||||
the dependencies and register the partial as the migration callback.
|
||||
"""
|
||||
|
||||
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
|
||||
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
|
||||
migrate: MigrateCallback = Field(description="The callback to run to perform the migration")
|
||||
pre_migrate: list[MigrateCallback] = Field(
|
||||
default=[], description="A list of callbacks to run before the migration"
|
||||
)
|
||||
post_migrate: list[MigrateCallback] = Field(
|
||||
default=[], description="A list of callbacks to run after the migration"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_to_version(self) -> "Migration":
|
||||
if self.to_version <= self.from_version:
|
||||
raise ValueError("to_version must be greater than from_version")
|
||||
return self
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||
return hash((self.from_version, self.to_version))
|
||||
|
||||
def register_pre_callback(self, callback: MigrateCallback) -> None:
|
||||
"""Registers a pre-migration callback."""
|
||||
self.pre_migrate.append(callback)
|
||||
|
||||
def register_post_callback(self, callback: MigrateCallback) -> None:
|
||||
"""Registers a post-migration callback."""
|
||||
self.post_migrate.append(callback)
|
||||
|
||||
|
||||
class MigrationSet:
|
||||
"""A set of Migrations. Performs validation during migration registration and provides utility methods."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._migrations: set[Migration] = set()
|
||||
|
||||
def register(self, migration: Migration) -> None:
|
||||
"""Registers a migration."""
|
||||
if any(m.from_version == migration.from_version for m in self._migrations):
|
||||
raise MigrationVersionError(f"Migration from {migration.from_version} already registered")
|
||||
if any(m.to_version == migration.to_version for m in self._migrations):
|
||||
raise MigrationVersionError(f"Migration to {migration.to_version} already registered")
|
||||
self._migrations.add(migration)
|
||||
|
||||
def get(self, from_version: int) -> Optional[Migration]:
|
||||
"""Gets the migration that may be run on the given database version."""
|
||||
# register() ensures that there is only one migration with a given from_version, so this is safe.
|
||||
return next((m for m in self._migrations if m.from_version == from_version), None)
|
||||
|
||||
def validate_migration_chain(self) -> None:
|
||||
"""
|
||||
Validates that the migrations form a single chain of migrations from version 0 to the latest version.
|
||||
Raises a MigrationError if there is a problem.
|
||||
"""
|
||||
if self.count == 0:
|
||||
return
|
||||
if self.latest_version == 0:
|
||||
return
|
||||
next_migration = self.get(from_version=0)
|
||||
if next_migration is None:
|
||||
raise MigrationError("Migration chain is fragmented")
|
||||
touched_count = 1
|
||||
while next_migration is not None:
|
||||
next_migration = self.get(next_migration.to_version)
|
||||
if next_migration is not None:
|
||||
touched_count += 1
|
||||
if touched_count != self.count:
|
||||
raise MigrationError("Migration chain is fragmented")
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
"""The count of registered migrations."""
|
||||
return len(self._migrations)
|
||||
|
||||
@property
|
||||
def latest_version(self) -> int:
|
||||
"""Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
|
||||
if self.count == 0:
|
||||
return 0
|
||||
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
@ -4,114 +4,9 @@ import threading
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, TypeAlias
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
"""Raised when a migration fails."""
|
||||
|
||||
|
||||
class MigrationVersionError(ValueError):
|
||||
"""Raised when a migration version is invalid."""
|
||||
|
||||
|
||||
class Migration(BaseModel):
|
||||
"""
|
||||
Represents a migration for a SQLite database.
|
||||
|
||||
Migration callbacks will be provided an open cursor to the database. They should not commit their
|
||||
transaction; this is handled by the migrator.
|
||||
|
||||
Pre- and post-migration callback may be registered with :meth:`register_pre_callback` or
|
||||
:meth:`register_post_callback`.
|
||||
|
||||
If a migration has additional dependencies, it is recommended to use functools.partial to provide
|
||||
the dependencies and register the partial as the migration callback.
|
||||
"""
|
||||
|
||||
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
|
||||
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
|
||||
migrate: MigrateCallback = Field(description="The callback to run to perform the migration")
|
||||
pre_migrate: list[MigrateCallback] = Field(
|
||||
default=[], description="A list of callbacks to run before the migration"
|
||||
)
|
||||
post_migrate: list[MigrateCallback] = Field(
|
||||
default=[], description="A list of callbacks to run after the migration"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_to_version(self) -> "Migration":
|
||||
if self.to_version <= self.from_version:
|
||||
raise ValueError("to_version must be greater than from_version")
|
||||
return self
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||
return hash((self.from_version, self.to_version))
|
||||
|
||||
def register_pre_callback(self, callback: MigrateCallback) -> None:
|
||||
"""Registers a pre-migration callback."""
|
||||
self.pre_migrate.append(callback)
|
||||
|
||||
def register_post_callback(self, callback: MigrateCallback) -> None:
|
||||
"""Registers a post-migration callback."""
|
||||
self.post_migrate.append(callback)
|
||||
|
||||
|
||||
class MigrationSet:
|
||||
"""A set of Migrations. Performs validation during migration registration and provides utility methods."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._migrations: set[Migration] = set()
|
||||
|
||||
def register(self, migration: Migration) -> None:
|
||||
"""Registers a migration."""
|
||||
if any(m.from_version == migration.from_version for m in self._migrations):
|
||||
raise MigrationVersionError(f"Migration from {migration.from_version} already registered")
|
||||
if any(m.to_version == migration.to_version for m in self._migrations):
|
||||
raise MigrationVersionError(f"Migration to {migration.to_version} already registered")
|
||||
self._migrations.add(migration)
|
||||
|
||||
def get(self, from_version: int) -> Optional[Migration]:
|
||||
"""Gets the migration that may be run on the given database version."""
|
||||
# register() ensures that there is only one migration with a given from_version, so this is safe.
|
||||
return next((m for m in self._migrations if m.from_version == from_version), None)
|
||||
|
||||
def validate_migration_chain(self) -> None:
|
||||
"""
|
||||
Validates that the migrations form a single chain of migrations from version 0 to the latest version.
|
||||
Raises a MigrationError if there is a problem.
|
||||
"""
|
||||
if self.count == 0:
|
||||
return
|
||||
if self.latest_version == 0:
|
||||
return
|
||||
next_migration = self.get(from_version=0)
|
||||
if next_migration is None:
|
||||
raise MigrationError("Migration chain is fragmented")
|
||||
touched_count = 1
|
||||
while next_migration is not None:
|
||||
next_migration = self.get(next_migration.to_version)
|
||||
if next_migration is not None:
|
||||
touched_count += 1
|
||||
if touched_count != self.count:
|
||||
raise MigrationError("Migration chain is fragmented")
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
"""The count of registered migrations."""
|
||||
return len(self._migrations)
|
||||
|
||||
@property
|
||||
def latest_version(self) -> int:
|
||||
"""Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
|
||||
if self.count == 0:
|
||||
return 0
|
||||
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet
|
||||
|
||||
|
||||
class SQLiteMigrator:
|
Loading…
Reference in New Issue
Block a user