fix(db): fix windows db migrator tests

- Ensure db files are closed before manipulating them
- Use contextlib.closing() so that sqlite connections are closed on existing the context
This commit is contained in:
psychedelicious@windows 2023-12-11 16:12:03 +11:00 committed by psychedelicious
parent 26ab917021
commit f1b6f78319
2 changed files with 15 additions and 10 deletions

View File

@ -190,12 +190,12 @@ class SQLiteMigrator:
self._run_migrations(temp_db_cursor) self._run_migrations(temp_db_cursor)
# Close the connections, copy the original database as a backup, and move the temp database to the # Close the connections, copy the original database as a backup, and move the temp database to the
# original database's path. # original database's path.
temp_db_conn.close()
self._conn.close()
backup_db_path = self._finalize_migration( backup_db_path = self._finalize_migration(
temp_db_path=temp_db_path, temp_db_path=temp_db_path,
original_db_path=self._db_path, original_db_path=self._db_path,
) )
temp_db_conn.close()
self._conn.close()
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}") self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
else: else:
# We are using a memory database. No special backup or special handling needed. # We are using a memory database. No special backup or special handling needed.

View File

@ -1,5 +1,6 @@
import sqlite3 import sqlite3
import threading import threading
from contextlib import closing
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
@ -217,7 +218,7 @@ def test_migrator_runs_all_migrations_file(logger: Logger, lock: threading.RLock
for migration in migrations: for migration in migrations:
migrator.register_migration(migration) migrator.register_migration(migration)
migrator.run_migrations() migrator.run_migrations()
with sqlite3.connect(original_db_path) as original_db_conn: with closing(sqlite3.connect(original_db_path)) as original_db_conn:
original_db_cursor = original_db_conn.cursor() original_db_cursor = original_db_conn.cursor()
assert SQLiteMigrator._get_current_version(original_db_cursor) == 3 assert SQLiteMigrator._get_current_version(original_db_cursor) == 3
@ -225,7 +226,7 @@ def test_migrator_runs_all_migrations_file(logger: Logger, lock: threading.RLock
def test_migrator_creates_temp_db(): def test_migrator_creates_temp_db():
with TemporaryDirectory() as tempdir: with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db" original_db_path = Path(tempdir) / "invokeai.db"
with sqlite3.connect(original_db_path): with closing(sqlite3.connect(original_db_path)):
# create the db file so _create_temp_db has something to copy # create the db file so _create_temp_db has something to copy
pass pass
temp_db_path = SQLiteMigrator._create_temp_db(original_db_path) temp_db_path = SQLiteMigrator._create_temp_db(original_db_path)
@ -238,18 +239,22 @@ def test_migrator_finalizes():
original_db_path = Path(tempdir) / "invokeai.db" original_db_path = Path(tempdir) / "invokeai.db"
temp_db_path = SQLiteMigrator._get_temp_db_path(original_db_path) temp_db_path = SQLiteMigrator._get_temp_db_path(original_db_path)
backup_db_path = SQLiteMigrator._get_backup_db_path(original_db_path) backup_db_path = SQLiteMigrator._get_backup_db_path(original_db_path)
with sqlite3.connect(original_db_path) as original_db_conn, sqlite3.connect(temp_db_path) as temp_db_conn: with closing(sqlite3.connect(original_db_path)) as original_db_conn, closing(
sqlite3.connect(temp_db_path)
) as temp_db_conn:
original_db_cursor = original_db_conn.cursor() original_db_cursor = original_db_conn.cursor()
original_db_cursor.execute("CREATE TABLE original_db_test (id INTEGER PRIMARY KEY);") original_db_cursor.execute("CREATE TABLE original_db_test (id INTEGER PRIMARY KEY);")
original_db_conn.commit() original_db_conn.commit()
temp_db_cursor = temp_db_conn.cursor() temp_db_cursor = temp_db_conn.cursor()
temp_db_cursor.execute("CREATE TABLE temp_db_test (id INTEGER PRIMARY KEY);") temp_db_cursor.execute("CREATE TABLE temp_db_test (id INTEGER PRIMARY KEY);")
temp_db_conn.commit() temp_db_conn.commit()
SQLiteMigrator._finalize_migration( SQLiteMigrator._finalize_migration(
original_db_path=original_db_path, original_db_path=original_db_path,
temp_db_path=temp_db_path, temp_db_path=temp_db_path,
) )
with sqlite3.connect(backup_db_path) as backup_db_conn, sqlite3.connect(temp_db_path) as temp_db_conn: with closing(sqlite3.connect(backup_db_path)) as backup_db_conn, closing(
sqlite3.connect(temp_db_path)
) as temp_db_conn:
backup_db_cursor = backup_db_conn.cursor() backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='original_db_test';") backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='original_db_test';")
assert backup_db_cursor.fetchone() is not None assert backup_db_cursor.fetchone() is not None