fix(db): fix windows db migrator tests

- Ensure db files are closed before manipulating them
- Use contextlib.closing() so that sqlite connections are closed on existing the context
This commit is contained in:
psychedelicious@windows 2023-12-11 16:12:03 +11:00 committed by psychedelicious
parent 26ab917021
commit f1b6f78319
2 changed files with 15 additions and 10 deletions

View File

@ -190,12 +190,12 @@ class SQLiteMigrator:
self._run_migrations(temp_db_cursor)
# Close the connections, copy the original database as a backup, and move the temp database to the
# original database's path.
temp_db_conn.close()
self._conn.close()
backup_db_path = self._finalize_migration(
temp_db_path=temp_db_path,
original_db_path=self._db_path,
)
temp_db_conn.close()
self._conn.close()
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
else:
# We are using a memory database. No special backup or special handling needed.

View File

@ -1,5 +1,6 @@
import sqlite3
import threading
from contextlib import closing
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
@ -217,7 +218,7 @@ def test_migrator_runs_all_migrations_file(logger: Logger, lock: threading.RLock
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
with sqlite3.connect(original_db_path) as original_db_conn:
with closing(sqlite3.connect(original_db_path)) as original_db_conn:
original_db_cursor = original_db_conn.cursor()
assert SQLiteMigrator._get_current_version(original_db_cursor) == 3
@ -225,7 +226,7 @@ def test_migrator_runs_all_migrations_file(logger: Logger, lock: threading.RLock
def test_migrator_creates_temp_db():
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
with sqlite3.connect(original_db_path):
with closing(sqlite3.connect(original_db_path)):
# create the db file so _create_temp_db has something to copy
pass
temp_db_path = SQLiteMigrator._create_temp_db(original_db_path)
@ -238,18 +239,22 @@ def test_migrator_finalizes():
original_db_path = Path(tempdir) / "invokeai.db"
temp_db_path = SQLiteMigrator._get_temp_db_path(original_db_path)
backup_db_path = SQLiteMigrator._get_backup_db_path(original_db_path)
with sqlite3.connect(original_db_path) as original_db_conn, sqlite3.connect(temp_db_path) as temp_db_conn:
with closing(sqlite3.connect(original_db_path)) as original_db_conn, closing(
sqlite3.connect(temp_db_path)
) as temp_db_conn:
original_db_cursor = original_db_conn.cursor()
original_db_cursor.execute("CREATE TABLE original_db_test (id INTEGER PRIMARY KEY);")
original_db_conn.commit()
temp_db_cursor = temp_db_conn.cursor()
temp_db_cursor.execute("CREATE TABLE temp_db_test (id INTEGER PRIMARY KEY);")
temp_db_conn.commit()
SQLiteMigrator._finalize_migration(
original_db_path=original_db_path,
temp_db_path=temp_db_path,
)
with sqlite3.connect(backup_db_path) as backup_db_conn, sqlite3.connect(temp_db_path) as temp_db_conn:
SQLiteMigrator._finalize_migration(
original_db_path=original_db_path,
temp_db_path=temp_db_path,
)
with closing(sqlite3.connect(backup_db_path)) as backup_db_conn, closing(
sqlite3.connect(temp_db_path)
) as temp_db_conn:
backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='original_db_test';")
assert backup_db_cursor.fetchone() is not None