mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): invert backup/restore logic
Do the migration on a temp copy of the db, then back up the original and move the temp into its file.
This commit is contained in:
parent
abeb1bd3b3
commit
e461f9925e
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
|
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
|
||||||
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
|
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
|
||||||
@ -75,12 +76,22 @@ class ApiDependencies:
|
|||||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
|
||||||
db = SqliteDatabase(config, logger)
|
db = SqliteDatabase(config, logger)
|
||||||
migrator = SQLiteMigrator(database=db.database, lock=db.lock, logger=logger)
|
|
||||||
|
migrator = SQLiteMigrator(
|
||||||
|
db_path=db.database if isinstance(db.database, Path) else None,
|
||||||
|
conn=db.conn,
|
||||||
|
lock=db.lock,
|
||||||
|
logger=logger,
|
||||||
|
log_sql=config.log_sql,
|
||||||
|
)
|
||||||
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
|
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
|
||||||
migrator.register_migration(migration_1)
|
migrator.register_migration(migration_1)
|
||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
|
||||||
|
if not db.is_memory:
|
||||||
|
db.reinitialize()
|
||||||
|
|
||||||
configuration = config
|
configuration = config
|
||||||
logger = logger
|
logger = logger
|
||||||
|
|
||||||
|
@ -6,6 +6,8 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
|||||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||||
"""Migration callback for database version 1."""
|
"""Migration callback for database version 1."""
|
||||||
|
|
||||||
|
print("migration 1!!!")
|
||||||
|
|
||||||
_create_board_images(cursor)
|
_create_board_images(cursor)
|
||||||
_create_boards(cursor)
|
_create_boards(cursor)
|
||||||
_create_images(cursor)
|
_create_images(cursor)
|
||||||
|
@ -6,6 +6,8 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
|||||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||||
"""Migration callback for database version 2."""
|
"""Migration callback for database version 2."""
|
||||||
|
|
||||||
|
print("migration 2!!!")
|
||||||
|
|
||||||
_add_images_has_workflow(cursor)
|
_add_images_has_workflow(cursor)
|
||||||
_add_session_queue_workflow(cursor)
|
_add_session_queue_workflow(cursor)
|
||||||
_drop_old_workflow_tables(cursor)
|
_drop_old_workflow_tables(cursor)
|
||||||
|
@ -10,18 +10,21 @@ from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
|||||||
class SqliteDatabase:
|
class SqliteDatabase:
|
||||||
database: Path | str # Must declare this here to satisfy type checker
|
database: Path | str # Must declare this here to satisfy type checker
|
||||||
|
|
||||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
|
def __init__(self, config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||||
|
self.initialize(config, logger)
|
||||||
|
|
||||||
|
def initialize(self, config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._config = config
|
self._config = config
|
||||||
self.is_memory = False
|
self.is_memory = False
|
||||||
if self._config.use_memory_db:
|
if self._config.use_memory_db:
|
||||||
self.database = sqlite_memory
|
self.database = sqlite_memory
|
||||||
self.is_memory = True
|
self.is_memory = True
|
||||||
logger.info("Using in-memory database")
|
logger.info("Initializing in-memory database")
|
||||||
else:
|
else:
|
||||||
self.database = self._config.db_path
|
self.database = self._config.db_path
|
||||||
self.database.parent.mkdir(parents=True, exist_ok=True)
|
self.database.parent.mkdir(parents=True, exist_ok=True)
|
||||||
self._logger.info(f"Using database at {self.database}")
|
self._logger.info(f"Initializing database at {self.database}")
|
||||||
|
|
||||||
self.conn = sqlite3.connect(database=self.database, check_same_thread=False)
|
self.conn = sqlite3.connect(database=self.database, check_same_thread=False)
|
||||||
self.lock = threading.RLock()
|
self.lock = threading.RLock()
|
||||||
@ -32,6 +35,13 @@ class SqliteDatabase:
|
|||||||
|
|
||||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
|
||||||
|
def reinitialize(self) -> None:
|
||||||
|
"""Reinitializes the database. Needed after migration."""
|
||||||
|
self.initialize(self._config, self._logger)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self.conn.close()
|
||||||
|
|
||||||
def clean(self) -> None:
|
def clean(self) -> None:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
try:
|
try:
|
||||||
|
@ -8,16 +8,14 @@ from typing import Callable, Optional, TypeAlias
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
|
||||||
|
|
||||||
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
||||||
|
|
||||||
|
|
||||||
class MigrationError(Exception):
|
class MigrationError(RuntimeError):
|
||||||
"""Raised when a migration fails."""
|
"""Raised when a migration fails."""
|
||||||
|
|
||||||
|
|
||||||
class MigrationVersionError(ValueError, MigrationError):
|
class MigrationVersionError(ValueError):
|
||||||
"""Raised when a migration version is invalid."""
|
"""Raised when a migration version is invalid."""
|
||||||
|
|
||||||
|
|
||||||
@ -25,8 +23,14 @@ class Migration(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Represents a migration for a SQLite database.
|
Represents a migration for a SQLite database.
|
||||||
|
|
||||||
Migration callbacks will be provided an instance of SqliteDatabase.
|
Migration callbacks will be provided an open cursor to the database. They should not commit their
|
||||||
Migration callbacks should not commit; the migrator will commit the transaction.
|
transaction; this is handled by the migrator.
|
||||||
|
|
||||||
|
Pre- and post-migration callback may be registered with :meth:`register_pre_callback` or
|
||||||
|
:meth:`register_post_callback`.
|
||||||
|
|
||||||
|
If a migration has additional dependencies, it is recommended to use functools.partial to provide
|
||||||
|
the dependencies and register the partial as the migration callback.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
|
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
|
||||||
@ -77,6 +81,28 @@ class MigrationSet:
|
|||||||
# register() ensures that there is only one migration with a given from_version, so this is safe.
|
# register() ensures that there is only one migration with a given from_version, so this is safe.
|
||||||
return next((m for m in self._migrations if m.from_version == from_version), None)
|
return next((m for m in self._migrations if m.from_version == from_version), None)
|
||||||
|
|
||||||
|
def validate_migration_path(self) -> None:
|
||||||
|
"""
|
||||||
|
Validates that the migrations form a single path of migrations from version 0 to the latest version.
|
||||||
|
Raises a MigrationError if there is a problem.
|
||||||
|
"""
|
||||||
|
if self.count == 0:
|
||||||
|
return
|
||||||
|
if self.latest_version == 0:
|
||||||
|
return
|
||||||
|
current_version = 0
|
||||||
|
touched_count = 0
|
||||||
|
while current_version < self.latest_version:
|
||||||
|
migration = self.get(current_version)
|
||||||
|
if migration is None:
|
||||||
|
raise MigrationError(f"Missing migration from {current_version}")
|
||||||
|
current_version = migration.to_version
|
||||||
|
touched_count += 1
|
||||||
|
if current_version != self.latest_version:
|
||||||
|
raise MigrationError(f"Missing migration to {self.latest_version}")
|
||||||
|
if touched_count != self.count:
|
||||||
|
raise MigrationError("Migration path is not contiguous")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
"""The count of registered migrations."""
|
"""The count of registered migrations."""
|
||||||
@ -90,87 +116,112 @@ class MigrationSet:
|
|||||||
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
||||||
|
|
||||||
|
|
||||||
|
def get_temp_db_path(original_db_path: Path) -> Path:
|
||||||
|
"""Gets the path to the migrated database."""
|
||||||
|
return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp")
|
||||||
|
|
||||||
|
|
||||||
class SQLiteMigrator:
|
class SQLiteMigrator:
|
||||||
"""
|
"""
|
||||||
Manages migrations for a SQLite database.
|
Manages migrations for a SQLite database.
|
||||||
|
|
||||||
:param db: The SqliteDatabase, representing the database on which to run migrations.
|
:param db_path: The path to the database to migrate, or None if using an in-memory database.
|
||||||
:param image_files: An instance of ImageFileStorageBase. Migrations may need to access image files.
|
:param conn: The connection to the database.
|
||||||
|
:param lock: A lock to use when running migrations.
|
||||||
|
:param logger: A logger to use for logging.
|
||||||
|
:param log_sql: Whether to log SQL statements. Only used when the log level is set to debug.
|
||||||
|
|
||||||
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
|
Migrations should be registered with :meth:`register_migration`.
|
||||||
order of their version number. If the database is already at the latest version, no migrations
|
|
||||||
will be run.
|
During migration, a copy of the current database is made and the migrations are run on the copy. If the migration
|
||||||
|
is successful, the original database is backed up and the migrated database is moved to the original database's
|
||||||
|
path. If the migration fails, the original database is left untouched and the migrated database is deleted.
|
||||||
|
|
||||||
|
If the database is in-memory, no backup is made; the migration is run in-place.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
backup_path: Optional[Path] = None
|
backup_path: Optional[Path] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database: Path | str,
|
db_path: Path | None,
|
||||||
|
conn: sqlite3.Connection,
|
||||||
lock: threading.RLock,
|
lock: threading.RLock,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
|
log_sql: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._lock = lock
|
self._lock = lock
|
||||||
self._database = database
|
self._db_path = db_path
|
||||||
self._is_memory = database == sqlite_memory
|
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._conn = sqlite3.connect(database)
|
self._conn = conn
|
||||||
|
self._log_sql = log_sql
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._migrations = MigrationSet()
|
self._migrations = MigrationSet()
|
||||||
|
|
||||||
# Use a lock file to indicate that a migration is in progress. Should only exist in the event of catastrophic failure.
|
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
|
||||||
self._migration_lock_file_path = (
|
if self._db_path and get_temp_db_path(self._db_path).is_file():
|
||||||
self._database.parent / ".migration_in_progress" if isinstance(self._database, Path) else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._unlink_lock_file():
|
|
||||||
self._logger.warning("Previous migration failed! Trying again...")
|
self._logger.warning("Previous migration failed! Trying again...")
|
||||||
|
get_temp_db_path(self._db_path).unlink()
|
||||||
|
|
||||||
def register_migration(self, migration: Migration) -> None:
|
def register_migration(self, migration: Migration) -> None:
|
||||||
"""
|
"""Registers a migration."""
|
||||||
Registers a migration.
|
|
||||||
Migration callbacks should not commit any changes to the database; the migrator will commit the transaction.
|
|
||||||
"""
|
|
||||||
self._migrations.register(migration)
|
self._migrations.register(migration)
|
||||||
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
|
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
|
||||||
|
|
||||||
def run_migrations(self) -> None:
|
def run_migrations(self) -> None:
|
||||||
"""Migrates the database to the latest version."""
|
"""Migrates the database to the latest version."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._create_version_table()
|
# This throws if there is a problem.
|
||||||
current_version = self._get_current_version()
|
self._migrations.validate_migration_path()
|
||||||
|
self._create_migrations_table(cursor=self._cursor)
|
||||||
|
|
||||||
if self._migrations.count == 0:
|
if self._migrations.count == 0:
|
||||||
self._logger.debug("No migrations registered")
|
self._logger.debug("No migrations registered")
|
||||||
return
|
return
|
||||||
|
|
||||||
latest_version = self._migrations.latest_version
|
if self._get_current_version(self._cursor) == self._migrations.latest_version:
|
||||||
if current_version == latest_version:
|
|
||||||
self._logger.debug("Database is up to date, no migrations to run")
|
self._logger.debug("Database is up to date, no migrations to run")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._logger.info("Database update needed")
|
self._logger.info("Database update needed")
|
||||||
|
|
||||||
# Only make a backup if using a file database (not memory)
|
if self._db_path:
|
||||||
self._backup_db()
|
# We are using a file database. Create a copy of the database to run the migrations on.
|
||||||
|
temp_db_path = self._create_temp_db(self._db_path)
|
||||||
|
temp_db_conn = sqlite3.connect(temp_db_path)
|
||||||
|
# We have to re-set this because we just created a new connection.
|
||||||
|
if self._log_sql:
|
||||||
|
temp_db_conn.set_trace_callback(self._logger.debug)
|
||||||
|
temp_db_cursor = temp_db_conn.cursor()
|
||||||
|
self._run_migrations(temp_db_cursor)
|
||||||
|
# Close the connections, copy the original database as a backup, and move the temp database to the
|
||||||
|
# original database's path.
|
||||||
|
self._finalize_migration(
|
||||||
|
temp_db_conn=temp_db_conn,
|
||||||
|
temp_db_path=temp_db_path,
|
||||||
|
original_db_path=self._db_path,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# We are using a memory database. No special backup or special handling needed.
|
||||||
|
self._run_migrations(self._cursor)
|
||||||
|
return
|
||||||
|
|
||||||
next_migration = self._migrations.get(from_version=current_version)
|
|
||||||
while next_migration is not None:
|
|
||||||
try:
|
|
||||||
self._run_migration(next_migration)
|
|
||||||
next_migration = self._migrations.get(self._get_current_version())
|
|
||||||
except MigrationError:
|
|
||||||
self._restore_db()
|
|
||||||
raise
|
|
||||||
self._logger.info("Database updated successfully")
|
self._logger.info("Database updated successfully")
|
||||||
|
return
|
||||||
|
|
||||||
def _run_migration(self, migration: Migration) -> None:
|
def _run_migrations(self, temp_db_cursor: sqlite3.Cursor) -> None:
|
||||||
|
next_migration = self._migrations.get(from_version=self._get_current_version(temp_db_cursor))
|
||||||
|
while next_migration is not None:
|
||||||
|
self._run_migration(next_migration, temp_db_cursor)
|
||||||
|
next_migration = self._migrations.get(self._get_current_version(temp_db_cursor))
|
||||||
|
|
||||||
|
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
|
||||||
"""Runs a single migration."""
|
"""Runs a single migration."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
try:
|
try:
|
||||||
if self._get_current_version() != migration.from_version:
|
if self._get_current_version(temp_db_cursor) != migration.from_version:
|
||||||
raise MigrationError(
|
raise MigrationError(
|
||||||
f"Database is at version {self._get_current_version()}, expected {migration.from_version}"
|
f"Database is at version {self._get_current_version(temp_db_cursor)}, expected {migration.from_version}"
|
||||||
)
|
)
|
||||||
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
||||||
|
|
||||||
@ -178,37 +229,37 @@ class SQLiteMigrator:
|
|||||||
if migration.pre_migrate:
|
if migration.pre_migrate:
|
||||||
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
||||||
for callback in migration.pre_migrate:
|
for callback in migration.pre_migrate:
|
||||||
callback(self._cursor)
|
callback(temp_db_cursor)
|
||||||
|
|
||||||
# Run the actual migration
|
# Run the actual migration
|
||||||
migration.migrate(self._cursor)
|
migration.migrate(temp_db_cursor)
|
||||||
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
temp_db_cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
||||||
|
|
||||||
# Run post-migration callbacks
|
# Run post-migration callbacks
|
||||||
if migration.post_migrate:
|
if migration.post_migrate:
|
||||||
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
||||||
for callback in migration.post_migrate:
|
for callback in migration.post_migrate:
|
||||||
callback(self._cursor)
|
callback(temp_db_cursor)
|
||||||
|
|
||||||
# Migration callbacks only get a cursor. Commit this migration.
|
# Migration callbacks only get a cursor. Commit this migration.
|
||||||
self._conn.commit()
|
temp_db_cursor.connection.commit()
|
||||||
self._logger.debug(
|
self._logger.debug(
|
||||||
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
|
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
|
||||||
self._conn.rollback()
|
temp_db_cursor.connection.rollback()
|
||||||
self._logger.error(msg)
|
self._logger.error(msg)
|
||||||
raise MigrationError(msg) from e
|
raise MigrationError(msg) from e
|
||||||
|
|
||||||
def _create_version_table(self) -> None:
|
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
||||||
"""Creates a version table for the database, if one does not already exist."""
|
"""Creates the migrations table for the database, if one does not already exist."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
try:
|
try:
|
||||||
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||||
if self._cursor.fetchone() is not None:
|
if cursor.fetchone() is not None:
|
||||||
return
|
return
|
||||||
self._cursor.execute(
|
cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
CREATE TABLE migrations (
|
CREATE TABLE migrations (
|
||||||
version INTEGER PRIMARY KEY,
|
version INTEGER PRIMARY KEY,
|
||||||
@ -216,21 +267,21 @@ class SQLiteMigrator:
|
|||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||||
self._conn.commit()
|
cursor.connection.commit()
|
||||||
self._logger.debug("Created migrations table")
|
self._logger.debug("Created migrations table")
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
msg = f"Problem creating migrations table: {e}"
|
msg = f"Problem creating migrations table: {e}"
|
||||||
self._logger.error(msg)
|
self._logger.error(msg)
|
||||||
self._conn.rollback()
|
cursor.connection.rollback()
|
||||||
raise MigrationError(msg) from e
|
raise MigrationError(msg) from e
|
||||||
|
|
||||||
def _get_current_version(self) -> int:
|
def _get_current_version(self, cursor: sqlite3.Cursor) -> int:
|
||||||
"""Gets the current version of the database, or 0 if the version table does not exist."""
|
"""Gets the current version of the database, or 0 if the version table does not exist."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
try:
|
try:
|
||||||
self._cursor.execute("SELECT MAX(version) FROM migrations;")
|
cursor.execute("SELECT MAX(version) FROM migrations;")
|
||||||
version = self._cursor.fetchone()[0]
|
version = cursor.fetchone()[0]
|
||||||
if version is None:
|
if version is None:
|
||||||
return 0
|
return 0
|
||||||
return version
|
return version
|
||||||
@ -239,54 +290,19 @@ class SQLiteMigrator:
|
|||||||
return 0
|
return 0
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _backup_db(self) -> None:
|
def _create_temp_db(self, current_db_path: Path) -> Path:
|
||||||
"""Backs up the databse, returning the path to the backup file."""
|
"""Copies the current database to a new file for migration."""
|
||||||
if self._is_memory:
|
temp_db_path = get_temp_db_path(current_db_path)
|
||||||
self._logger.debug("Using memory database, skipping backup")
|
shutil.copy2(current_db_path, temp_db_path)
|
||||||
# Sanity check!
|
self._logger.info(f"Copied database to {temp_db_path} for migration")
|
||||||
assert isinstance(self._database, Path)
|
return temp_db_path
|
||||||
with self._lock:
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
backup_path = self._database.parent / f"{self._database.stem}_{timestamp}.db"
|
|
||||||
self._logger.info(f"Backing up database to {backup_path}")
|
|
||||||
|
|
||||||
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
|
def _finalize_migration(self, temp_db_conn: sqlite3.Connection, temp_db_path: Path, original_db_path: Path) -> None:
|
||||||
backup_conn = sqlite3.connect(backup_path)
|
"""Closes connections, renames the original database as a backup and renames the migrated database to the original db path."""
|
||||||
with backup_conn:
|
self._conn.close()
|
||||||
self._conn.backup(backup_conn)
|
temp_db_conn.close()
|
||||||
backup_conn.close()
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
backup_db_path = original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"
|
||||||
# Sanity check!
|
original_db_path.rename(backup_db_path)
|
||||||
if not backup_path.is_file():
|
temp_db_path.rename(original_db_path)
|
||||||
raise MigrationError("Unable to back up database")
|
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
|
||||||
self.backup_path = backup_path
|
|
||||||
|
|
||||||
def _restore_db(
|
|
||||||
self,
|
|
||||||
) -> None:
|
|
||||||
"""Restores the database from a backup file, unless the database is a memory database."""
|
|
||||||
if self._is_memory:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
self._logger.info(f"Restoring database from {self.backup_path}")
|
|
||||||
self._conn.close()
|
|
||||||
assert isinstance(self.backup_path, Path)
|
|
||||||
shutil.copy2(self.backup_path, self._database)
|
|
||||||
|
|
||||||
def _unlink_lock_file(self) -> bool:
|
|
||||||
"""Unlinks the migration lock file, returning True if it existed."""
|
|
||||||
if self._is_memory or self._migration_lock_file_path is None:
|
|
||||||
return False
|
|
||||||
if self._migration_lock_file_path.is_file():
|
|
||||||
self._migration_lock_file_path.unlink()
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _write_migration_lock_file(self) -> None:
|
|
||||||
"""Writes a file to indicate that a migration is in progress."""
|
|
||||||
if self._is_memory or self._migration_lock_file_path is None:
|
|
||||||
return
|
|
||||||
assert isinstance(self._database, Path)
|
|
||||||
with open(self._migration_lock_file_path, "w") as f:
|
|
||||||
f.write("1")
|
|
||||||
|
@ -21,7 +21,7 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import (
|
|||||||
def migrator() -> SQLiteMigrator:
|
def migrator() -> SQLiteMigrator:
|
||||||
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
return SQLiteMigrator(
|
return SQLiteMigrator(
|
||||||
conn=conn, database=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
|
conn=conn, db_path=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -50,19 +50,19 @@ def test_register_invalid_migration_version(migrator: SQLiteMigrator):
|
|||||||
|
|
||||||
|
|
||||||
def test_create_version_table(migrator: SQLiteMigrator):
|
def test_create_version_table(migrator: SQLiteMigrator):
|
||||||
migrator._create_version_table()
|
migrator._create_migrations_table()
|
||||||
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
|
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
|
||||||
assert migrator._cursor.fetchone() is not None
|
assert migrator._cursor.fetchone() is not None
|
||||||
|
|
||||||
|
|
||||||
def test_get_current_version(migrator: SQLiteMigrator):
|
def test_get_current_version(migrator: SQLiteMigrator):
|
||||||
migrator._create_version_table()
|
migrator._create_migrations_table()
|
||||||
migrator._conn.commit()
|
migrator._conn.commit()
|
||||||
assert migrator._get_current_version() == 0 # initial version
|
assert migrator._get_current_version() == 0 # initial version
|
||||||
|
|
||||||
|
|
||||||
def test_set_version(migrator: SQLiteMigrator):
|
def test_set_version(migrator: SQLiteMigrator):
|
||||||
migrator._create_version_table()
|
migrator._create_migrations_table()
|
||||||
migrator._set_version(db_version=1, app_version="1.0.0")
|
migrator._set_version(db_version=1, app_version="1.0.0")
|
||||||
migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
|
migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
|
||||||
assert migrator._cursor.fetchone()[0] == 1
|
assert migrator._cursor.fetchone()[0] == 1
|
||||||
@ -71,7 +71,7 @@ def test_set_version(migrator: SQLiteMigrator):
|
|||||||
|
|
||||||
|
|
||||||
def test_run_migration(migrator: SQLiteMigrator):
|
def test_run_migration(migrator: SQLiteMigrator):
|
||||||
migrator._create_version_table()
|
migrator._create_migrations_table()
|
||||||
|
|
||||||
def migration_callback(cursor: sqlite3.Cursor) -> None:
|
def migration_callback(cursor: sqlite3.Cursor) -> None:
|
||||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||||
@ -84,7 +84,7 @@ def test_run_migration(migrator: SQLiteMigrator):
|
|||||||
|
|
||||||
|
|
||||||
def test_run_migrations(migrator: SQLiteMigrator):
|
def test_run_migrations(migrator: SQLiteMigrator):
|
||||||
migrator._create_version_table()
|
migrator._create_migrations_table()
|
||||||
|
|
||||||
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
|
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
|
||||||
def migrate(cursor: sqlite3.Cursor) -> None:
|
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||||
@ -109,8 +109,8 @@ def test_backup_and_restore_db():
|
|||||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
migrator = SQLiteMigrator(conn=conn, database=database, lock=threading.RLock(), logger=Logger("test"))
|
migrator = SQLiteMigrator(conn=conn, db_path=database, lock=threading.RLock(), logger=Logger("test"))
|
||||||
backup_path = migrator._backup_db(migrator._database)
|
backup_path = migrator._backup_db(migrator._db_path)
|
||||||
|
|
||||||
# mangle the db
|
# mangle the db
|
||||||
migrator._cursor.execute("DROP TABLE test;")
|
migrator._cursor.execute("DROP TABLE test;")
|
||||||
@ -135,21 +135,21 @@ def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator):
|
|||||||
|
|
||||||
|
|
||||||
def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
|
def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
|
||||||
migrator._create_version_table()
|
migrator._create_migrations_table()
|
||||||
with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
|
with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
|
||||||
migrator._run_migration(failing_migration)
|
migrator._run_migration(failing_migration)
|
||||||
assert migrator._get_current_version() == 0
|
assert migrator._get_current_version() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
|
def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
|
||||||
migrator._create_version_table()
|
migrator._create_migrations_table()
|
||||||
migrator.register_migration(good_migration)
|
migrator.register_migration(good_migration)
|
||||||
with pytest.raises(MigrationVersionError, match="already registered"):
|
with pytest.raises(MigrationVersionError, match="already registered"):
|
||||||
migrator.register_migration(deepcopy(good_migration))
|
migrator.register_migration(deepcopy(good_migration))
|
||||||
|
|
||||||
|
|
||||||
def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
|
def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
|
||||||
migrator._create_version_table()
|
migrator._create_migrations_table()
|
||||||
|
|
||||||
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
|
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
|
||||||
def migrate(cursor: sqlite3.Cursor) -> None:
|
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||||
@ -167,7 +167,7 @@ def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
|
|||||||
|
|
||||||
|
|
||||||
def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
|
def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
|
||||||
migrator._create_version_table()
|
migrator._create_migrations_table()
|
||||||
migrator.register_migration(good_migration)
|
migrator.register_migration(good_migration)
|
||||||
migrator._set_version(db_version=2, app_version="2.0.0")
|
migrator._set_version(db_version=2, app_version="2.0.0")
|
||||||
with pytest.raises(MigrationError, match="greater than the latest migration version"):
|
with pytest.raises(MigrationError, match="greater than the latest migration version"):
|
||||||
@ -176,7 +176,7 @@ def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration:
|
|||||||
|
|
||||||
|
|
||||||
def test_idempotent_migrations(migrator: SQLiteMigrator):
|
def test_idempotent_migrations(migrator: SQLiteMigrator):
|
||||||
migrator._create_version_table()
|
migrator._create_migrations_table()
|
||||||
|
|
||||||
def create_test_table(cursor: sqlite3.Cursor) -> None:
|
def create_test_table(cursor: sqlite3.Cursor) -> None:
|
||||||
# This SQL throws if run twice
|
# This SQL throws if run twice
|
||||||
|
Loading…
x
Reference in New Issue
Block a user