feat(db): instantiate SqliteMigrator with a SqliteDatabase

Simplifies a couple things:
- Init is more straightforward
- It's clear in the migrator that the connection we are working with is related to the SqliteDatabase
This commit is contained in:
psychedelicious 2023-12-12 10:46:08 +11:00
parent 417db71471
commit 3414437eea
7 changed files with 58 additions and 101 deletions

View File

@ -77,13 +77,7 @@ class ApiDependencies:
db_path = None if config.use_memory_db else config.db_path db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql) db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SQLiteMigrator( migrator = SQLiteMigrator(db=db)
db_path=db.db_path,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=config.log_sql,
)
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files)) migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)

View File

@ -1,11 +1,10 @@
import shutil import shutil
import sqlite3 import sqlite3
import threading
from datetime import datetime from datetime import datetime
from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet
@ -30,26 +29,15 @@ class SQLiteMigrator:
backup_path: Optional[Path] = None backup_path: Optional[Path] = None
def __init__( def __init__(self, db: SqliteDatabase) -> None:
self, self._db = db
db_path: Path | None, self._logger = db.logger
conn: sqlite3.Connection,
lock: threading.RLock,
logger: Logger,
log_sql: bool = False,
) -> None:
self._lock = lock
self._db_path = db_path
self._logger = logger
self._conn = conn
self._log_sql = log_sql
self._cursor = self._conn.cursor()
self._migration_set = MigrationSet() self._migration_set = MigrationSet()
# The presence of an temp database file indicates a catastrophic failure of a previous migration. # The presence of an temp database file indicates a catastrophic failure of a previous migration.
if self._db_path and self._get_temp_db_path(self._db_path).is_file(): if self._db.db_path and self._get_temp_db_path(self._db.db_path).is_file():
self._logger.warning("Previous migration failed! Trying again...") self._logger.warning("Previous migration failed! Trying again...")
self._get_temp_db_path(self._db_path).unlink() self._get_temp_db_path(self._db.db_path).unlink()
def register_migration(self, migration: Migration) -> None: def register_migration(self, migration: Migration) -> None:
"""Registers a migration.""" """Registers a migration."""
@ -58,43 +46,41 @@ class SQLiteMigrator:
def run_migrations(self) -> bool: def run_migrations(self) -> bool:
"""Migrates the database to the latest version.""" """Migrates the database to the latest version."""
with self._lock: with self._db.lock:
# This throws if there is a problem. # This throws if there is a problem.
self._migration_set.validate_migration_chain() self._migration_set.validate_migration_chain()
self._create_migrations_table(cursor=self._cursor) cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0: if self._migration_set.count == 0:
self._logger.debug("No migrations registered") self._logger.debug("No migrations registered")
return False return False
if self._get_current_version(self._cursor) == self._migration_set.latest_version: if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
self._logger.debug("Database is up to date, no migrations to run") self._logger.debug("Database is up to date, no migrations to run")
return False return False
self._logger.info("Database update needed") self._logger.info("Database update needed")
if self._db_path: if self._db.db_path:
# We are using a file database. Create a copy of the database to run the migrations on. # We are using a file database. Create a copy of the database to run the migrations on.
temp_db_path = self._create_temp_db(self._db_path) temp_db_path = self._create_temp_db(self._db.db_path)
self._logger.info(f"Copied database to {temp_db_path} for migration") self._logger.info(f"Copied database to {temp_db_path} for migration")
temp_db_conn = sqlite3.connect(temp_db_path) temp_db = SqliteDatabase(db_path=temp_db_path, logger=self._logger, verbose=self._db.verbose)
# We have to re-set this because we just created a new connection. temp_db_cursor = temp_db.conn.cursor()
if self._log_sql:
temp_db_conn.set_trace_callback(self._logger.debug)
temp_db_cursor = temp_db_conn.cursor()
self._run_migrations(temp_db_cursor) self._run_migrations(temp_db_cursor)
# Close the connections, copy the original database as a backup, and move the temp database to the # Close the connections, copy the original database as a backup, and move the temp database to the
# original database's path. # original database's path.
temp_db_conn.close() temp_db.close()
self._conn.close() self._db.close()
backup_db_path = self._finalize_migration( backup_db_path = self._finalize_migration(
temp_db_path=temp_db_path, temp_db_path=temp_db_path,
original_db_path=self._db_path, original_db_path=self._db.db_path,
) )
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}") self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
else: else:
# We are using a memory database. No special backup or special handling needed. # We are using a memory database. No special backup or special handling needed.
self._run_migrations(self._cursor) self._run_migrations(cursor)
self._logger.info("Database updated successfully") self._logger.info("Database updated successfully")
return True return True
@ -108,7 +94,7 @@ class SQLiteMigrator:
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None: def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
"""Runs a single migration.""" """Runs a single migration."""
with self._lock: with self._db.lock:
try: try:
if self._get_current_version(temp_db_cursor) != migration.from_version: if self._get_current_version(temp_db_cursor) != migration.from_version:
raise MigrationError( raise MigrationError(
@ -145,7 +131,7 @@ class SQLiteMigrator:
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None: def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the migrations table for the database, if one does not already exist.""" """Creates the migrations table for the database, if one does not already exist."""
with self._lock: with self._db.lock:
try: try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';") cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None: if cursor.fetchone() is not None:

View File

@ -55,13 +55,7 @@ def mock_services() -> InvocationServices:
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db_path = None if configuration.use_memory_db else configuration.db_path db_path = None if configuration.use_memory_db else configuration.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql) db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
migrator = SQLiteMigrator( migrator = SQLiteMigrator(db=db)
db_path=db.db_path,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=configuration.log_sql,
)
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()

View File

@ -59,13 +59,7 @@ def mock_services() -> InvocationServices:
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db_path = None if configuration.use_memory_db else configuration.db_path db_path = None if configuration.use_memory_db else configuration.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql) db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
migrator = SQLiteMigrator( migrator = SQLiteMigrator(db=db)
db_path=db.db_path,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=configuration.log_sql,
)
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()

View File

@ -44,13 +44,7 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=app_config) logger = InvokeAILogger.get_logger(config=app_config)
db_path = None if app_config.use_memory_db else app_config.db_path db_path = None if app_config.use_memory_db else app_config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql) db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql)
migrator = SQLiteMigrator( migrator = SQLiteMigrator(db=db)
db_path=db.db_path,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=app_config.log_sql,
)
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()

View File

@ -35,13 +35,7 @@ def store(datadir: Any) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db_path = None if config.use_memory_db else config.db_path db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql) db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SQLiteMigrator( migrator = SQLiteMigrator(db=db)
db_path=db.db_path,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=config.log_sql,
)
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()

View File

@ -1,5 +1,4 @@
import sqlite3 import sqlite3
import threading
from contextlib import closing from contextlib import closing
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
@ -8,7 +7,7 @@ from tempfile import TemporaryDirectory
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import ( from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
MigrateCallback, MigrateCallback,
Migration, Migration,
@ -27,14 +26,9 @@ def logger() -> Logger:
@pytest.fixture @pytest.fixture
def lock() -> threading.RLock: def migrator(logger: Logger) -> SQLiteMigrator:
return threading.RLock() db = SqliteDatabase(db_path=None, logger=logger, verbose=False)
return SQLiteMigrator(db=db)
@pytest.fixture
def migrator(logger: Logger, lock: threading.RLock) -> SQLiteMigrator:
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
return SQLiteMigrator(conn=conn, db_path=None, lock=lock, logger=logger)
@pytest.fixture @pytest.fixture
@ -176,50 +170,55 @@ def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op:
def test_migrator_creates_migrations_table(migrator: SQLiteMigrator) -> None: def test_migrator_creates_migrations_table(migrator: SQLiteMigrator) -> None:
migrator._create_migrations_table(migrator._cursor) cursor = migrator._db.conn.cursor()
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';") migrator._create_migrations_table(cursor)
assert migrator._cursor.fetchone() is not None cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';")
assert cursor.fetchone() is not None
def test_migrator_migration_sets_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None: def test_migrator_migration_sets_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
migrator._create_migrations_table(migrator._cursor) cursor = migrator._db.conn.cursor()
migrator._create_migrations_table(cursor)
migrator.register_migration(migration_no_op) migrator.register_migration(migration_no_op)
migrator.run_migrations() migrator.run_migrations()
migrator._cursor.execute("SELECT MAX(version) FROM migrations;") cursor.execute("SELECT MAX(version) FROM migrations;")
assert migrator._cursor.fetchone()[0] == 1 assert cursor.fetchone()[0] == 1
def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None: def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
assert migrator._get_current_version(migrator._cursor) == 0 cursor = migrator._db.conn.cursor()
migrator._create_migrations_table(migrator._cursor) assert migrator._get_current_version(cursor) == 0
assert migrator._get_current_version(migrator._cursor) == 0 migrator._create_migrations_table(cursor)
assert migrator._get_current_version(cursor) == 0
migrator.register_migration(migration_no_op) migrator.register_migration(migration_no_op)
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 1 assert migrator._get_current_version(cursor) == 1
def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None: def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
migrator._create_migrations_table(migrator._cursor) cursor = migrator._db.conn.cursor()
migrator._run_migration(migration_create_test_table, migrator._cursor) migrator._create_migrations_table(cursor)
assert migrator._get_current_version(migrator._cursor) == 1 migrator._run_migration(migration_create_test_table, cursor)
migrator._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';") assert migrator._get_current_version(cursor) == 1
assert migrator._cursor.fetchone() is not None cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert cursor.fetchone() is not None
def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None: def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None:
cursor = migrator._db.conn.cursor()
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)] migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
for migration in migrations: for migration in migrations:
migrator.register_migration(migration) migrator.register_migration(migration)
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 3 assert migrator._get_current_version(cursor) == 3
def test_migrator_runs_all_migrations_file(logger: Logger, lock: threading.RLock) -> None: def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
with TemporaryDirectory() as tempdir: with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db" original_db_path = Path(tempdir) / "invokeai.db"
# The Migrator closes the database when it finishes; we cannot use a context manager. # The Migrator closes the database when it finishes; we cannot use a context manager.
original_db_conn = sqlite3.connect(original_db_path) db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
migrator = SQLiteMigrator(conn=original_db_conn, db_path=original_db_path, lock=lock, logger=logger) migrator = SQLiteMigrator(db=db)
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)] migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
for migration in migrations: for migration in migrations:
migrator.register_migration(migration) migrator.register_migration(migration)
@ -272,18 +271,20 @@ def test_migrator_finalizes() -> None:
def test_migrator_makes_no_changes_on_failed_migration( def test_migrator_makes_no_changes_on_failed_migration(
migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None: ) -> None:
cursor = migrator._db.conn.cursor()
migrator.register_migration(migration_no_op) migrator.register_migration(migration_no_op)
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 1 assert migrator._get_current_version(cursor) == 1
migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback)) migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback))
with pytest.raises(MigrationError, match="Bad migration"): with pytest.raises(MigrationError, match="Bad migration"):
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 1 assert migrator._get_current_version(cursor) == 1
def test_idempotent_migrations(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None: def test_idempotent_migrations(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
cursor = migrator._db.conn.cursor()
migrator.register_migration(migration_create_test_table) migrator.register_migration(migration_create_test_table)
migrator.run_migrations() migrator.run_migrations()
# not throwing is sufficient # not throwing is sufficient
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 1 assert migrator._get_current_version(cursor) == 1