feat(db): invert backup/restore logic

Do the migration on a temp copy of the db, then back up the original and move the temp into its file.
This commit is contained in:
psychedelicious 2023-12-11 00:47:53 +11:00
parent abeb1bd3b3
commit e461f9925e
6 changed files with 167 additions and 126 deletions

View File

@ -2,6 +2,7 @@
from functools import partial from functools import partial
from logging import Logger from logging import Logger
from pathlib import Path
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1 from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2 from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
@ -75,12 +76,22 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images") image_files = DiskImageFileStorage(f"{output_folder}/images")
db = SqliteDatabase(config, logger) db = SqliteDatabase(config, logger)
migrator = SQLiteMigrator(database=db.database, lock=db.lock, logger=logger)
migrator = SQLiteMigrator(
db_path=db.database if isinstance(db.database, Path) else None,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=config.log_sql,
)
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files)) migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()
if not db.is_memory:
db.reinitialize()
configuration = config configuration = config
logger = logger logger = logger

View File

@ -6,6 +6,8 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None: def _migrate(cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 1.""" """Migration callback for database version 1."""
print("migration 1!!!")
_create_board_images(cursor) _create_board_images(cursor)
_create_boards(cursor) _create_boards(cursor)
_create_images(cursor) _create_images(cursor)

View File

@ -6,6 +6,8 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None: def _migrate(cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 2.""" """Migration callback for database version 2."""
print("migration 2!!!")
_add_images_has_workflow(cursor) _add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor) _add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor) _drop_old_workflow_tables(cursor)

View File

@ -10,18 +10,21 @@ from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
class SqliteDatabase: class SqliteDatabase:
database: Path | str # Must declare this here to satisfy type checker database: Path | str # Must declare this here to satisfy type checker
def __init__(self, config: InvokeAIAppConfig, logger: Logger): def __init__(self, config: InvokeAIAppConfig, logger: Logger) -> None:
self.initialize(config, logger)
def initialize(self, config: InvokeAIAppConfig, logger: Logger) -> None:
self._logger = logger self._logger = logger
self._config = config self._config = config
self.is_memory = False self.is_memory = False
if self._config.use_memory_db: if self._config.use_memory_db:
self.database = sqlite_memory self.database = sqlite_memory
self.is_memory = True self.is_memory = True
logger.info("Using in-memory database") logger.info("Initializing in-memory database")
else: else:
self.database = self._config.db_path self.database = self._config.db_path
self.database.parent.mkdir(parents=True, exist_ok=True) self.database.parent.mkdir(parents=True, exist_ok=True)
self._logger.info(f"Using database at {self.database}") self._logger.info(f"Initializing database at {self.database}")
self.conn = sqlite3.connect(database=self.database, check_same_thread=False) self.conn = sqlite3.connect(database=self.database, check_same_thread=False)
self.lock = threading.RLock() self.lock = threading.RLock()
@ -32,6 +35,13 @@ class SqliteDatabase:
self.conn.execute("PRAGMA foreign_keys = ON;") self.conn.execute("PRAGMA foreign_keys = ON;")
def reinitialize(self) -> None:
"""Reinitializes the database. Needed after migration."""
self.initialize(self._config, self._logger)
def close(self) -> None:
self.conn.close()
def clean(self) -> None: def clean(self) -> None:
with self.lock: with self.lock:
try: try:

View File

@ -8,16 +8,14 @@ from typing import Callable, Optional, TypeAlias
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None] MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
class MigrationError(Exception): class MigrationError(RuntimeError):
"""Raised when a migration fails.""" """Raised when a migration fails."""
class MigrationVersionError(ValueError, MigrationError): class MigrationVersionError(ValueError):
"""Raised when a migration version is invalid.""" """Raised when a migration version is invalid."""
@ -25,8 +23,14 @@ class Migration(BaseModel):
""" """
Represents a migration for a SQLite database. Represents a migration for a SQLite database.
Migration callbacks will be provided an instance of SqliteDatabase. Migration callbacks will be provided an open cursor to the database. They should not commit their
Migration callbacks should not commit; the migrator will commit the transaction. transaction; this is handled by the migrator.
Pre- and post-migration callback may be registered with :meth:`register_pre_callback` or
:meth:`register_post_callback`.
If a migration has additional dependencies, it is recommended to use functools.partial to provide
the dependencies and register the partial as the migration callback.
""" """
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run") from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
@ -77,6 +81,28 @@ class MigrationSet:
# register() ensures that there is only one migration with a given from_version, so this is safe. # register() ensures that there is only one migration with a given from_version, so this is safe.
return next((m for m in self._migrations if m.from_version == from_version), None) return next((m for m in self._migrations if m.from_version == from_version), None)
def validate_migration_path(self) -> None:
"""
Validates that the migrations form a single path of migrations from version 0 to the latest version.
Raises a MigrationError if there is a problem.
"""
if self.count == 0:
return
if self.latest_version == 0:
return
current_version = 0
touched_count = 0
while current_version < self.latest_version:
migration = self.get(current_version)
if migration is None:
raise MigrationError(f"Missing migration from {current_version}")
current_version = migration.to_version
touched_count += 1
if current_version != self.latest_version:
raise MigrationError(f"Missing migration to {self.latest_version}")
if touched_count != self.count:
raise MigrationError("Migration path is not contiguous")
@property @property
def count(self) -> int: def count(self) -> int:
"""The count of registered migrations.""" """The count of registered migrations."""
@ -90,87 +116,112 @@ class MigrationSet:
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
def get_temp_db_path(original_db_path: Path) -> Path:
"""Gets the path to the migrated database."""
return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp")
class SQLiteMigrator: class SQLiteMigrator:
""" """
Manages migrations for a SQLite database. Manages migrations for a SQLite database.
:param db: The SqliteDatabase, representing the database on which to run migrations. :param db_path: The path to the database to migrate, or None if using an in-memory database.
:param image_files: An instance of ImageFileStorageBase. Migrations may need to access image files. :param conn: The connection to the database.
:param lock: A lock to use when running migrations.
:param logger: A logger to use for logging.
:param log_sql: Whether to log SQL statements. Only used when the log level is set to debug.
Migrations should be registered with :meth:`register_migration`. Migrations will be run in Migrations should be registered with :meth:`register_migration`.
order of their version number. If the database is already at the latest version, no migrations
will be run. During migration, a copy of the current database is made and the migrations are run on the copy. If the migration
is successful, the original database is backed up and the migrated database is moved to the original database's
path. If the migration fails, the original database is left untouched and the migrated database is deleted.
If the database is in-memory, no backup is made; the migration is run in-place.
""" """
backup_path: Optional[Path] = None backup_path: Optional[Path] = None
def __init__( def __init__(
self, self,
database: Path | str, db_path: Path | None,
conn: sqlite3.Connection,
lock: threading.RLock, lock: threading.RLock,
logger: Logger, logger: Logger,
log_sql: bool = False,
) -> None: ) -> None:
self._lock = lock self._lock = lock
self._database = database self._db_path = db_path
self._is_memory = database == sqlite_memory
self._logger = logger self._logger = logger
self._conn = sqlite3.connect(database) self._conn = conn
self._log_sql = log_sql
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._migrations = MigrationSet() self._migrations = MigrationSet()
# Use a lock file to indicate that a migration is in progress. Should only exist in the event of catastrophic failure. # The presence of an temp database file indicates a catastrophic failure of a previous migration.
self._migration_lock_file_path = ( if self._db_path and get_temp_db_path(self._db_path).is_file():
self._database.parent / ".migration_in_progress" if isinstance(self._database, Path) else None
)
if self._unlink_lock_file():
self._logger.warning("Previous migration failed! Trying again...") self._logger.warning("Previous migration failed! Trying again...")
get_temp_db_path(self._db_path).unlink()
def register_migration(self, migration: Migration) -> None: def register_migration(self, migration: Migration) -> None:
""" """Registers a migration."""
Registers a migration.
Migration callbacks should not commit any changes to the database; the migrator will commit the transaction.
"""
self._migrations.register(migration) self._migrations.register(migration)
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}") self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
def run_migrations(self) -> None: def run_migrations(self) -> None:
"""Migrates the database to the latest version.""" """Migrates the database to the latest version."""
with self._lock: with self._lock:
self._create_version_table() # This throws if there is a problem.
current_version = self._get_current_version() self._migrations.validate_migration_path()
self._create_migrations_table(cursor=self._cursor)
if self._migrations.count == 0: if self._migrations.count == 0:
self._logger.debug("No migrations registered") self._logger.debug("No migrations registered")
return return
latest_version = self._migrations.latest_version if self._get_current_version(self._cursor) == self._migrations.latest_version:
if current_version == latest_version:
self._logger.debug("Database is up to date, no migrations to run") self._logger.debug("Database is up to date, no migrations to run")
return return
self._logger.info("Database update needed") self._logger.info("Database update needed")
# Only make a backup if using a file database (not memory) if self._db_path:
self._backup_db() # We are using a file database. Create a copy of the database to run the migrations on.
temp_db_path = self._create_temp_db(self._db_path)
temp_db_conn = sqlite3.connect(temp_db_path)
# We have to re-set this because we just created a new connection.
if self._log_sql:
temp_db_conn.set_trace_callback(self._logger.debug)
temp_db_cursor = temp_db_conn.cursor()
self._run_migrations(temp_db_cursor)
# Close the connections, copy the original database as a backup, and move the temp database to the
# original database's path.
self._finalize_migration(
temp_db_conn=temp_db_conn,
temp_db_path=temp_db_path,
original_db_path=self._db_path,
)
else:
# We are using a memory database. No special backup or special handling needed.
self._run_migrations(self._cursor)
return
next_migration = self._migrations.get(from_version=current_version)
while next_migration is not None:
try:
self._run_migration(next_migration)
next_migration = self._migrations.get(self._get_current_version())
except MigrationError:
self._restore_db()
raise
self._logger.info("Database updated successfully") self._logger.info("Database updated successfully")
return
def _run_migration(self, migration: Migration) -> None: def _run_migrations(self, temp_db_cursor: sqlite3.Cursor) -> None:
next_migration = self._migrations.get(from_version=self._get_current_version(temp_db_cursor))
while next_migration is not None:
self._run_migration(next_migration, temp_db_cursor)
next_migration = self._migrations.get(self._get_current_version(temp_db_cursor))
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
"""Runs a single migration.""" """Runs a single migration."""
with self._lock: with self._lock:
try: try:
if self._get_current_version() != migration.from_version: if self._get_current_version(temp_db_cursor) != migration.from_version:
raise MigrationError( raise MigrationError(
f"Database is at version {self._get_current_version()}, expected {migration.from_version}" f"Database is at version {self._get_current_version(temp_db_cursor)}, expected {migration.from_version}"
) )
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}") self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
@ -178,37 +229,37 @@ class SQLiteMigrator:
if migration.pre_migrate: if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks") self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate: for callback in migration.pre_migrate:
callback(self._cursor) callback(temp_db_cursor)
# Run the actual migration # Run the actual migration
migration.migrate(self._cursor) migration.migrate(temp_db_cursor)
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,)) temp_db_cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
# Run post-migration callbacks # Run post-migration callbacks
if migration.post_migrate: if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks") self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate: for callback in migration.post_migrate:
callback(self._cursor) callback(temp_db_cursor)
# Migration callbacks only get a cursor. Commit this migration. # Migration callbacks only get a cursor. Commit this migration.
self._conn.commit() temp_db_cursor.connection.commit()
self._logger.debug( self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}" f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
) )
except Exception as e: except Exception as e:
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}" msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
self._conn.rollback() temp_db_cursor.connection.rollback()
self._logger.error(msg) self._logger.error(msg)
raise MigrationError(msg) from e raise MigrationError(msg) from e
def _create_version_table(self) -> None: def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates a version table for the database, if one does not already exist.""" """Creates the migrations table for the database, if one does not already exist."""
with self._lock: with self._lock:
try: try:
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';") cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if self._cursor.fetchone() is not None: if cursor.fetchone() is not None:
return return
self._cursor.execute( cursor.execute(
"""--sql """--sql
CREATE TABLE migrations ( CREATE TABLE migrations (
version INTEGER PRIMARY KEY, version INTEGER PRIMARY KEY,
@ -216,21 +267,21 @@ class SQLiteMigrator:
); );
""" """
) )
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);") cursor.execute("INSERT INTO migrations (version) VALUES (0);")
self._conn.commit() cursor.connection.commit()
self._logger.debug("Created migrations table") self._logger.debug("Created migrations table")
except sqlite3.Error as e: except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}" msg = f"Problem creating migrations table: {e}"
self._logger.error(msg) self._logger.error(msg)
self._conn.rollback() cursor.connection.rollback()
raise MigrationError(msg) from e raise MigrationError(msg) from e
def _get_current_version(self) -> int: def _get_current_version(self, cursor: sqlite3.Cursor) -> int:
"""Gets the current version of the database, or 0 if the version table does not exist.""" """Gets the current version of the database, or 0 if the version table does not exist."""
with self._lock: with self._lock:
try: try:
self._cursor.execute("SELECT MAX(version) FROM migrations;") cursor.execute("SELECT MAX(version) FROM migrations;")
version = self._cursor.fetchone()[0] version = cursor.fetchone()[0]
if version is None: if version is None:
return 0 return 0
return version return version
@ -239,54 +290,19 @@ class SQLiteMigrator:
return 0 return 0
raise raise
def _backup_db(self) -> None: def _create_temp_db(self, current_db_path: Path) -> Path:
"""Backs up the databse, returning the path to the backup file.""" """Copies the current database to a new file for migration."""
if self._is_memory: temp_db_path = get_temp_db_path(current_db_path)
self._logger.debug("Using memory database, skipping backup") shutil.copy2(current_db_path, temp_db_path)
# Sanity check! self._logger.info(f"Copied database to {temp_db_path} for migration")
assert isinstance(self._database, Path) return temp_db_path
with self._lock:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = self._database.parent / f"{self._database.stem}_{timestamp}.db"
self._logger.info(f"Backing up database to {backup_path}")
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such. def _finalize_migration(self, temp_db_conn: sqlite3.Connection, temp_db_path: Path, original_db_path: Path) -> None:
backup_conn = sqlite3.connect(backup_path) """Closes connections, renames the original database as a backup and renames the migrated database to the original db path."""
with backup_conn: self._conn.close()
self._conn.backup(backup_conn) temp_db_conn.close()
backup_conn.close() timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_db_path = original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"
# Sanity check! original_db_path.rename(backup_db_path)
if not backup_path.is_file(): temp_db_path.rename(original_db_path)
raise MigrationError("Unable to back up database") self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
self.backup_path = backup_path
def _restore_db(
self,
) -> None:
"""Restores the database from a backup file, unless the database is a memory database."""
if self._is_memory:
return
with self._lock:
self._logger.info(f"Restoring database from {self.backup_path}")
self._conn.close()
assert isinstance(self.backup_path, Path)
shutil.copy2(self.backup_path, self._database)
def _unlink_lock_file(self) -> bool:
"""Unlinks the migration lock file, returning True if it existed."""
if self._is_memory or self._migration_lock_file_path is None:
return False
if self._migration_lock_file_path.is_file():
self._migration_lock_file_path.unlink()
return True
return False
def _write_migration_lock_file(self) -> None:
"""Writes a file to indicate that a migration is in progress."""
if self._is_memory or self._migration_lock_file_path is None:
return
assert isinstance(self._database, Path)
with open(self._migration_lock_file_path, "w") as f:
f.write("1")

View File

@ -21,7 +21,7 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import (
def migrator() -> SQLiteMigrator: def migrator() -> SQLiteMigrator:
conn = sqlite3.connect(sqlite_memory, check_same_thread=False) conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
return SQLiteMigrator( return SQLiteMigrator(
conn=conn, database=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator") conn=conn, db_path=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
) )
@ -50,19 +50,19 @@ def test_register_invalid_migration_version(migrator: SQLiteMigrator):
def test_create_version_table(migrator: SQLiteMigrator): def test_create_version_table(migrator: SQLiteMigrator):
migrator._create_version_table() migrator._create_migrations_table()
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';") migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
assert migrator._cursor.fetchone() is not None assert migrator._cursor.fetchone() is not None
def test_get_current_version(migrator: SQLiteMigrator): def test_get_current_version(migrator: SQLiteMigrator):
migrator._create_version_table() migrator._create_migrations_table()
migrator._conn.commit() migrator._conn.commit()
assert migrator._get_current_version() == 0 # initial version assert migrator._get_current_version() == 0 # initial version
def test_set_version(migrator: SQLiteMigrator): def test_set_version(migrator: SQLiteMigrator):
migrator._create_version_table() migrator._create_migrations_table()
migrator._set_version(db_version=1, app_version="1.0.0") migrator._set_version(db_version=1, app_version="1.0.0")
migrator._cursor.execute("SELECT MAX(db_version) FROM version;") migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
assert migrator._cursor.fetchone()[0] == 1 assert migrator._cursor.fetchone()[0] == 1
@ -71,7 +71,7 @@ def test_set_version(migrator: SQLiteMigrator):
def test_run_migration(migrator: SQLiteMigrator): def test_run_migration(migrator: SQLiteMigrator):
migrator._create_version_table() migrator._create_migrations_table()
def migration_callback(cursor: sqlite3.Cursor) -> None: def migration_callback(cursor: sqlite3.Cursor) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
@ -84,7 +84,7 @@ def test_run_migration(migrator: SQLiteMigrator):
def test_run_migrations(migrator: SQLiteMigrator): def test_run_migrations(migrator: SQLiteMigrator):
migrator._create_version_table() migrator._create_migrations_table()
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]: def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
def migrate(cursor: sqlite3.Cursor) -> None: def migrate(cursor: sqlite3.Cursor) -> None:
@ -109,8 +109,8 @@ def test_backup_and_restore_db():
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
conn.commit() conn.commit()
migrator = SQLiteMigrator(conn=conn, database=database, lock=threading.RLock(), logger=Logger("test")) migrator = SQLiteMigrator(conn=conn, db_path=database, lock=threading.RLock(), logger=Logger("test"))
backup_path = migrator._backup_db(migrator._database) backup_path = migrator._backup_db(migrator._db_path)
# mangle the db # mangle the db
migrator._cursor.execute("DROP TABLE test;") migrator._cursor.execute("DROP TABLE test;")
@ -135,21 +135,21 @@ def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator):
def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration): def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
migrator._create_version_table() migrator._create_migrations_table()
with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"): with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
migrator._run_migration(failing_migration) migrator._run_migration(failing_migration)
assert migrator._get_current_version() == 0 assert migrator._get_current_version() == 0
def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration): def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_version_table() migrator._create_migrations_table()
migrator.register_migration(good_migration) migrator.register_migration(good_migration)
with pytest.raises(MigrationVersionError, match="already registered"): with pytest.raises(MigrationVersionError, match="already registered"):
migrator.register_migration(deepcopy(good_migration)) migrator.register_migration(deepcopy(good_migration))
def test_non_sequential_migration_registration(migrator: SQLiteMigrator): def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
migrator._create_version_table() migrator._create_migrations_table()
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]: def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
def migrate(cursor: sqlite3.Cursor) -> None: def migrate(cursor: sqlite3.Cursor) -> None:
@ -167,7 +167,7 @@ def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration): def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_version_table() migrator._create_migrations_table()
migrator.register_migration(good_migration) migrator.register_migration(good_migration)
migrator._set_version(db_version=2, app_version="2.0.0") migrator._set_version(db_version=2, app_version="2.0.0")
with pytest.raises(MigrationError, match="greater than the latest migration version"): with pytest.raises(MigrationError, match="greater than the latest migration version"):
@ -176,7 +176,7 @@ def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration:
def test_idempotent_migrations(migrator: SQLiteMigrator): def test_idempotent_migrations(migrator: SQLiteMigrator):
migrator._create_version_table() migrator._create_migrations_table()
def create_test_table(cursor: sqlite3.Cursor) -> None: def create_test_table(cursor: sqlite3.Cursor) -> None:
# This SQL throws if run twice # This SQL throws if run twice