mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): invert backup/restore logic
Do the migration on a temp copy of the db, then back up the original and move the temp into its file.
This commit is contained in:
parent
abeb1bd3b3
commit
e461f9925e
@ -2,6 +2,7 @@
|
||||
|
||||
from functools import partial
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
|
||||
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
|
||||
@ -75,12 +76,22 @@ class ApiDependencies:
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
db = SqliteDatabase(config, logger)
|
||||
migrator = SQLiteMigrator(database=db.database, lock=db.lock, logger=logger)
|
||||
|
||||
migrator = SQLiteMigrator(
|
||||
db_path=db.database if isinstance(db.database, Path) else None,
|
||||
conn=db.conn,
|
||||
lock=db.lock,
|
||||
logger=logger,
|
||||
log_sql=config.log_sql,
|
||||
)
|
||||
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
|
||||
migrator.register_migration(migration_1)
|
||||
migrator.register_migration(migration_2)
|
||||
migrator.run_migrations()
|
||||
|
||||
if not db.is_memory:
|
||||
db.reinitialize()
|
||||
|
||||
configuration = config
|
||||
logger = logger
|
||||
|
||||
|
@ -6,6 +6,8 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||
"""Migration callback for database version 1."""
|
||||
|
||||
print("migration 1!!!")
|
||||
|
||||
_create_board_images(cursor)
|
||||
_create_boards(cursor)
|
||||
_create_images(cursor)
|
||||
|
@ -6,6 +6,8 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||
"""Migration callback for database version 2."""
|
||||
|
||||
print("migration 2!!!")
|
||||
|
||||
_add_images_has_workflow(cursor)
|
||||
_add_session_queue_workflow(cursor)
|
||||
_drop_old_workflow_tables(cursor)
|
||||
|
@ -10,18 +10,21 @@ from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||
class SqliteDatabase:
|
||||
database: Path | str # Must declare this here to satisfy type checker
|
||||
|
||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
|
||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||
self.initialize(config, logger)
|
||||
|
||||
def initialize(self, config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||
self._logger = logger
|
||||
self._config = config
|
||||
self.is_memory = False
|
||||
if self._config.use_memory_db:
|
||||
self.database = sqlite_memory
|
||||
self.is_memory = True
|
||||
logger.info("Using in-memory database")
|
||||
logger.info("Initializing in-memory database")
|
||||
else:
|
||||
self.database = self._config.db_path
|
||||
self.database.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._logger.info(f"Using database at {self.database}")
|
||||
self._logger.info(f"Initializing database at {self.database}")
|
||||
|
||||
self.conn = sqlite3.connect(database=self.database, check_same_thread=False)
|
||||
self.lock = threading.RLock()
|
||||
@ -32,6 +35,13 @@ class SqliteDatabase:
|
||||
|
||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
def reinitialize(self) -> None:
|
||||
"""Reinitializes the database. Needed after migration."""
|
||||
self.initialize(self._config, self._logger)
|
||||
|
||||
def close(self) -> None:
|
||||
self.conn.close()
|
||||
|
||||
def clean(self) -> None:
|
||||
with self.lock:
|
||||
try:
|
||||
|
@ -8,16 +8,14 @@ from typing import Callable, Optional, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||
|
||||
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
||||
|
||||
|
||||
class MigrationError(Exception):
|
||||
class MigrationError(RuntimeError):
|
||||
"""Raised when a migration fails."""
|
||||
|
||||
|
||||
class MigrationVersionError(ValueError, MigrationError):
|
||||
class MigrationVersionError(ValueError):
|
||||
"""Raised when a migration version is invalid."""
|
||||
|
||||
|
||||
@ -25,8 +23,14 @@ class Migration(BaseModel):
|
||||
"""
|
||||
Represents a migration for a SQLite database.
|
||||
|
||||
Migration callbacks will be provided an instance of SqliteDatabase.
|
||||
Migration callbacks should not commit; the migrator will commit the transaction.
|
||||
Migration callbacks will be provided an open cursor to the database. They should not commit their
|
||||
transaction; this is handled by the migrator.
|
||||
|
||||
Pre- and post-migration callback may be registered with :meth:`register_pre_callback` or
|
||||
:meth:`register_post_callback`.
|
||||
|
||||
If a migration has additional dependencies, it is recommended to use functools.partial to provide
|
||||
the dependencies and register the partial as the migration callback.
|
||||
"""
|
||||
|
||||
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
|
||||
@ -77,6 +81,28 @@ class MigrationSet:
|
||||
# register() ensures that there is only one migration with a given from_version, so this is safe.
|
||||
return next((m for m in self._migrations if m.from_version == from_version), None)
|
||||
|
||||
def validate_migration_path(self) -> None:
|
||||
"""
|
||||
Validates that the migrations form a single path of migrations from version 0 to the latest version.
|
||||
Raises a MigrationError if there is a problem.
|
||||
"""
|
||||
if self.count == 0:
|
||||
return
|
||||
if self.latest_version == 0:
|
||||
return
|
||||
current_version = 0
|
||||
touched_count = 0
|
||||
while current_version < self.latest_version:
|
||||
migration = self.get(current_version)
|
||||
if migration is None:
|
||||
raise MigrationError(f"Missing migration from {current_version}")
|
||||
current_version = migration.to_version
|
||||
touched_count += 1
|
||||
if current_version != self.latest_version:
|
||||
raise MigrationError(f"Missing migration to {self.latest_version}")
|
||||
if touched_count != self.count:
|
||||
raise MigrationError("Migration path is not contiguous")
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
"""The count of registered migrations."""
|
||||
@ -90,87 +116,112 @@ class MigrationSet:
|
||||
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
||||
|
||||
|
||||
def get_temp_db_path(original_db_path: Path) -> Path:
|
||||
"""Gets the path to the migrated database."""
|
||||
return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp")
|
||||
|
||||
|
||||
class SQLiteMigrator:
|
||||
"""
|
||||
Manages migrations for a SQLite database.
|
||||
|
||||
:param db: The SqliteDatabase, representing the database on which to run migrations.
|
||||
:param image_files: An instance of ImageFileStorageBase. Migrations may need to access image files.
|
||||
:param db_path: The path to the database to migrate, or None if using an in-memory database.
|
||||
:param conn: The connection to the database.
|
||||
:param lock: A lock to use when running migrations.
|
||||
:param logger: A logger to use for logging.
|
||||
:param log_sql: Whether to log SQL statements. Only used when the log level is set to debug.
|
||||
|
||||
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
|
||||
order of their version number. If the database is already at the latest version, no migrations
|
||||
will be run.
|
||||
Migrations should be registered with :meth:`register_migration`.
|
||||
|
||||
During migration, a copy of the current database is made and the migrations are run on the copy. If the migration
|
||||
is successful, the original database is backed up and the migrated database is moved to the original database's
|
||||
path. If the migration fails, the original database is left untouched and the migrated database is deleted.
|
||||
|
||||
If the database is in-memory, no backup is made; the migration is run in-place.
|
||||
"""
|
||||
|
||||
backup_path: Optional[Path] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database: Path | str,
|
||||
db_path: Path | None,
|
||||
conn: sqlite3.Connection,
|
||||
lock: threading.RLock,
|
||||
logger: Logger,
|
||||
log_sql: bool = False,
|
||||
) -> None:
|
||||
self._lock = lock
|
||||
self._database = database
|
||||
self._is_memory = database == sqlite_memory
|
||||
self._db_path = db_path
|
||||
self._logger = logger
|
||||
self._conn = sqlite3.connect(database)
|
||||
self._conn = conn
|
||||
self._log_sql = log_sql
|
||||
self._cursor = self._conn.cursor()
|
||||
self._migrations = MigrationSet()
|
||||
|
||||
# Use a lock file to indicate that a migration is in progress. Should only exist in the event of catastrophic failure.
|
||||
self._migration_lock_file_path = (
|
||||
self._database.parent / ".migration_in_progress" if isinstance(self._database, Path) else None
|
||||
)
|
||||
|
||||
if self._unlink_lock_file():
|
||||
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
|
||||
if self._db_path and get_temp_db_path(self._db_path).is_file():
|
||||
self._logger.warning("Previous migration failed! Trying again...")
|
||||
get_temp_db_path(self._db_path).unlink()
|
||||
|
||||
def register_migration(self, migration: Migration) -> None:
|
||||
"""
|
||||
Registers a migration.
|
||||
Migration callbacks should not commit any changes to the database; the migrator will commit the transaction.
|
||||
"""
|
||||
"""Registers a migration."""
|
||||
self._migrations.register(migration)
|
||||
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
|
||||
|
||||
def run_migrations(self) -> None:
|
||||
"""Migrates the database to the latest version."""
|
||||
with self._lock:
|
||||
self._create_version_table()
|
||||
current_version = self._get_current_version()
|
||||
# This throws if there is a problem.
|
||||
self._migrations.validate_migration_path()
|
||||
self._create_migrations_table(cursor=self._cursor)
|
||||
|
||||
if self._migrations.count == 0:
|
||||
self._logger.debug("No migrations registered")
|
||||
return
|
||||
|
||||
latest_version = self._migrations.latest_version
|
||||
if current_version == latest_version:
|
||||
if self._get_current_version(self._cursor) == self._migrations.latest_version:
|
||||
self._logger.debug("Database is up to date, no migrations to run")
|
||||
return
|
||||
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
# Only make a backup if using a file database (not memory)
|
||||
self._backup_db()
|
||||
if self._db_path:
|
||||
# We are using a file database. Create a copy of the database to run the migrations on.
|
||||
temp_db_path = self._create_temp_db(self._db_path)
|
||||
temp_db_conn = sqlite3.connect(temp_db_path)
|
||||
# We have to re-set this because we just created a new connection.
|
||||
if self._log_sql:
|
||||
temp_db_conn.set_trace_callback(self._logger.debug)
|
||||
temp_db_cursor = temp_db_conn.cursor()
|
||||
self._run_migrations(temp_db_cursor)
|
||||
# Close the connections, copy the original database as a backup, and move the temp database to the
|
||||
# original database's path.
|
||||
self._finalize_migration(
|
||||
temp_db_conn=temp_db_conn,
|
||||
temp_db_path=temp_db_path,
|
||||
original_db_path=self._db_path,
|
||||
)
|
||||
else:
|
||||
# We are using a memory database. No special backup or special handling needed.
|
||||
self._run_migrations(self._cursor)
|
||||
return
|
||||
|
||||
next_migration = self._migrations.get(from_version=current_version)
|
||||
while next_migration is not None:
|
||||
try:
|
||||
self._run_migration(next_migration)
|
||||
next_migration = self._migrations.get(self._get_current_version())
|
||||
except MigrationError:
|
||||
self._restore_db()
|
||||
raise
|
||||
self._logger.info("Database updated successfully")
|
||||
return
|
||||
|
||||
def _run_migration(self, migration: Migration) -> None:
|
||||
def _run_migrations(self, temp_db_cursor: sqlite3.Cursor) -> None:
|
||||
next_migration = self._migrations.get(from_version=self._get_current_version(temp_db_cursor))
|
||||
while next_migration is not None:
|
||||
self._run_migration(next_migration, temp_db_cursor)
|
||||
next_migration = self._migrations.get(self._get_current_version(temp_db_cursor))
|
||||
|
||||
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
|
||||
"""Runs a single migration."""
|
||||
with self._lock:
|
||||
try:
|
||||
if self._get_current_version() != migration.from_version:
|
||||
if self._get_current_version(temp_db_cursor) != migration.from_version:
|
||||
raise MigrationError(
|
||||
f"Database is at version {self._get_current_version()}, expected {migration.from_version}"
|
||||
f"Database is at version {self._get_current_version(temp_db_cursor)}, expected {migration.from_version}"
|
||||
)
|
||||
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
||||
|
||||
@ -178,37 +229,37 @@ class SQLiteMigrator:
|
||||
if migration.pre_migrate:
|
||||
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
||||
for callback in migration.pre_migrate:
|
||||
callback(self._cursor)
|
||||
callback(temp_db_cursor)
|
||||
|
||||
# Run the actual migration
|
||||
migration.migrate(self._cursor)
|
||||
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
||||
migration.migrate(temp_db_cursor)
|
||||
temp_db_cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
||||
|
||||
# Run post-migration callbacks
|
||||
if migration.post_migrate:
|
||||
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
||||
for callback in migration.post_migrate:
|
||||
callback(self._cursor)
|
||||
callback(temp_db_cursor)
|
||||
|
||||
# Migration callbacks only get a cursor. Commit this migration.
|
||||
self._conn.commit()
|
||||
temp_db_cursor.connection.commit()
|
||||
self._logger.debug(
|
||||
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
||||
)
|
||||
except Exception as e:
|
||||
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
|
||||
self._conn.rollback()
|
||||
temp_db_cursor.connection.rollback()
|
||||
self._logger.error(msg)
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
def _create_version_table(self) -> None:
|
||||
"""Creates a version table for the database, if one does not already exist."""
|
||||
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the migrations table for the database, if one does not already exist."""
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
if self._cursor.fetchone() is not None:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
if cursor.fetchone() is not None:
|
||||
return
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
@ -216,21 +267,21 @@ class SQLiteMigrator:
|
||||
);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||
self._conn.commit()
|
||||
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||
cursor.connection.commit()
|
||||
self._logger.debug("Created migrations table")
|
||||
except sqlite3.Error as e:
|
||||
msg = f"Problem creating migrations table: {e}"
|
||||
self._logger.error(msg)
|
||||
self._conn.rollback()
|
||||
cursor.connection.rollback()
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
def _get_current_version(self) -> int:
|
||||
def _get_current_version(self, cursor: sqlite3.Cursor) -> int:
|
||||
"""Gets the current version of the database, or 0 if the version table does not exist."""
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute("SELECT MAX(version) FROM migrations;")
|
||||
version = self._cursor.fetchone()[0]
|
||||
cursor.execute("SELECT MAX(version) FROM migrations;")
|
||||
version = cursor.fetchone()[0]
|
||||
if version is None:
|
||||
return 0
|
||||
return version
|
||||
@ -239,54 +290,19 @@ class SQLiteMigrator:
|
||||
return 0
|
||||
raise
|
||||
|
||||
def _backup_db(self) -> None:
|
||||
"""Backs up the databse, returning the path to the backup file."""
|
||||
if self._is_memory:
|
||||
self._logger.debug("Using memory database, skipping backup")
|
||||
# Sanity check!
|
||||
assert isinstance(self._database, Path)
|
||||
with self._lock:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = self._database.parent / f"{self._database.stem}_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {backup_path}")
|
||||
def _create_temp_db(self, current_db_path: Path) -> Path:
|
||||
"""Copies the current database to a new file for migration."""
|
||||
temp_db_path = get_temp_db_path(current_db_path)
|
||||
shutil.copy2(current_db_path, temp_db_path)
|
||||
self._logger.info(f"Copied database to {temp_db_path} for migration")
|
||||
return temp_db_path
|
||||
|
||||
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
|
||||
backup_conn = sqlite3.connect(backup_path)
|
||||
with backup_conn:
|
||||
self._conn.backup(backup_conn)
|
||||
backup_conn.close()
|
||||
|
||||
# Sanity check!
|
||||
if not backup_path.is_file():
|
||||
raise MigrationError("Unable to back up database")
|
||||
self.backup_path = backup_path
|
||||
|
||||
def _restore_db(
|
||||
self,
|
||||
) -> None:
|
||||
"""Restores the database from a backup file, unless the database is a memory database."""
|
||||
if self._is_memory:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
self._logger.info(f"Restoring database from {self.backup_path}")
|
||||
self._conn.close()
|
||||
assert isinstance(self.backup_path, Path)
|
||||
shutil.copy2(self.backup_path, self._database)
|
||||
|
||||
def _unlink_lock_file(self) -> bool:
|
||||
"""Unlinks the migration lock file, returning True if it existed."""
|
||||
if self._is_memory or self._migration_lock_file_path is None:
|
||||
return False
|
||||
if self._migration_lock_file_path.is_file():
|
||||
self._migration_lock_file_path.unlink()
|
||||
return True
|
||||
return False
|
||||
|
||||
def _write_migration_lock_file(self) -> None:
|
||||
"""Writes a file to indicate that a migration is in progress."""
|
||||
if self._is_memory or self._migration_lock_file_path is None:
|
||||
return
|
||||
assert isinstance(self._database, Path)
|
||||
with open(self._migration_lock_file_path, "w") as f:
|
||||
f.write("1")
|
||||
def _finalize_migration(self, temp_db_conn: sqlite3.Connection, temp_db_path: Path, original_db_path: Path) -> None:
|
||||
"""Closes connections, renames the original database as a backup and renames the migrated database to the original db path."""
|
||||
self._conn.close()
|
||||
temp_db_conn.close()
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_db_path = original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"
|
||||
original_db_path.rename(backup_db_path)
|
||||
temp_db_path.rename(original_db_path)
|
||||
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
|
||||
|
@ -21,7 +21,7 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import (
|
||||
def migrator() -> SQLiteMigrator:
|
||||
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
return SQLiteMigrator(
|
||||
conn=conn, database=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
|
||||
conn=conn, db_path=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
|
||||
)
|
||||
|
||||
|
||||
@ -50,19 +50,19 @@ def test_register_invalid_migration_version(migrator: SQLiteMigrator):
|
||||
|
||||
|
||||
def test_create_version_table(migrator: SQLiteMigrator):
|
||||
migrator._create_version_table()
|
||||
migrator._create_migrations_table()
|
||||
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
|
||||
assert migrator._cursor.fetchone() is not None
|
||||
|
||||
|
||||
def test_get_current_version(migrator: SQLiteMigrator):
|
||||
migrator._create_version_table()
|
||||
migrator._create_migrations_table()
|
||||
migrator._conn.commit()
|
||||
assert migrator._get_current_version() == 0 # initial version
|
||||
|
||||
|
||||
def test_set_version(migrator: SQLiteMigrator):
|
||||
migrator._create_version_table()
|
||||
migrator._create_migrations_table()
|
||||
migrator._set_version(db_version=1, app_version="1.0.0")
|
||||
migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
|
||||
assert migrator._cursor.fetchone()[0] == 1
|
||||
@ -71,7 +71,7 @@ def test_set_version(migrator: SQLiteMigrator):
|
||||
|
||||
|
||||
def test_run_migration(migrator: SQLiteMigrator):
|
||||
migrator._create_version_table()
|
||||
migrator._create_migrations_table()
|
||||
|
||||
def migration_callback(cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
@ -84,7 +84,7 @@ def test_run_migration(migrator: SQLiteMigrator):
|
||||
|
||||
|
||||
def test_run_migrations(migrator: SQLiteMigrator):
|
||||
migrator._create_version_table()
|
||||
migrator._create_migrations_table()
|
||||
|
||||
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
|
||||
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||
@ -109,8 +109,8 @@ def test_backup_and_restore_db():
|
||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
conn.commit()
|
||||
|
||||
migrator = SQLiteMigrator(conn=conn, database=database, lock=threading.RLock(), logger=Logger("test"))
|
||||
backup_path = migrator._backup_db(migrator._database)
|
||||
migrator = SQLiteMigrator(conn=conn, db_path=database, lock=threading.RLock(), logger=Logger("test"))
|
||||
backup_path = migrator._backup_db(migrator._db_path)
|
||||
|
||||
# mangle the db
|
||||
migrator._cursor.execute("DROP TABLE test;")
|
||||
@ -135,21 +135,21 @@ def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator):
|
||||
|
||||
|
||||
def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
|
||||
migrator._create_version_table()
|
||||
migrator._create_migrations_table()
|
||||
with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
|
||||
migrator._run_migration(failing_migration)
|
||||
assert migrator._get_current_version() == 0
|
||||
|
||||
|
||||
def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
|
||||
migrator._create_version_table()
|
||||
migrator._create_migrations_table()
|
||||
migrator.register_migration(good_migration)
|
||||
with pytest.raises(MigrationVersionError, match="already registered"):
|
||||
migrator.register_migration(deepcopy(good_migration))
|
||||
|
||||
|
||||
def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
|
||||
migrator._create_version_table()
|
||||
migrator._create_migrations_table()
|
||||
|
||||
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
|
||||
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||
@ -167,7 +167,7 @@ def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
|
||||
|
||||
|
||||
def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
|
||||
migrator._create_version_table()
|
||||
migrator._create_migrations_table()
|
||||
migrator.register_migration(good_migration)
|
||||
migrator._set_version(db_version=2, app_version="2.0.0")
|
||||
with pytest.raises(MigrationError, match="greater than the latest migration version"):
|
||||
@ -176,7 +176,7 @@ def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration:
|
||||
|
||||
|
||||
def test_idempotent_migrations(migrator: SQLiteMigrator):
|
||||
migrator._create_version_table()
|
||||
migrator._create_migrations_table()
|
||||
|
||||
def create_test_table(cursor: sqlite3.Cursor) -> None:
|
||||
# This SQL throws if run twice
|
||||
|
Loading…
Reference in New Issue
Block a user