mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): instantiate SqliteMigrator with a SqliteDatabase
Simplifies a couple things: - Init is more straightforward - It's clear in the migrator that the connection we are working with is related to the SqliteDatabase
This commit is contained in:
parent
417db71471
commit
3414437eea
@ -77,13 +77,7 @@ class ApiDependencies:
|
|||||||
db_path = None if config.use_memory_db else config.db_path
|
db_path = None if config.use_memory_db else config.db_path
|
||||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
||||||
|
|
||||||
migrator = SQLiteMigrator(
|
migrator = SQLiteMigrator(db=db)
|
||||||
db_path=db.db_path,
|
|
||||||
conn=db.conn,
|
|
||||||
lock=db.lock,
|
|
||||||
logger=logger,
|
|
||||||
log_sql=config.log_sql,
|
|
||||||
)
|
|
||||||
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
|
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
|
||||||
migrator.register_migration(migration_1)
|
migrator.register_migration(migration_1)
|
||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import shutil
|
import shutil
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from logging import Logger
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet
|
||||||
|
|
||||||
|
|
||||||
@ -30,26 +29,15 @@ class SQLiteMigrator:
|
|||||||
|
|
||||||
backup_path: Optional[Path] = None
|
backup_path: Optional[Path] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
self,
|
self._db = db
|
||||||
db_path: Path | None,
|
self._logger = db.logger
|
||||||
conn: sqlite3.Connection,
|
|
||||||
lock: threading.RLock,
|
|
||||||
logger: Logger,
|
|
||||||
log_sql: bool = False,
|
|
||||||
) -> None:
|
|
||||||
self._lock = lock
|
|
||||||
self._db_path = db_path
|
|
||||||
self._logger = logger
|
|
||||||
self._conn = conn
|
|
||||||
self._log_sql = log_sql
|
|
||||||
self._cursor = self._conn.cursor()
|
|
||||||
self._migration_set = MigrationSet()
|
self._migration_set = MigrationSet()
|
||||||
|
|
||||||
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
|
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
|
||||||
if self._db_path and self._get_temp_db_path(self._db_path).is_file():
|
if self._db.db_path and self._get_temp_db_path(self._db.db_path).is_file():
|
||||||
self._logger.warning("Previous migration failed! Trying again...")
|
self._logger.warning("Previous migration failed! Trying again...")
|
||||||
self._get_temp_db_path(self._db_path).unlink()
|
self._get_temp_db_path(self._db.db_path).unlink()
|
||||||
|
|
||||||
def register_migration(self, migration: Migration) -> None:
|
def register_migration(self, migration: Migration) -> None:
|
||||||
"""Registers a migration."""
|
"""Registers a migration."""
|
||||||
@ -58,43 +46,41 @@ class SQLiteMigrator:
|
|||||||
|
|
||||||
def run_migrations(self) -> bool:
|
def run_migrations(self) -> bool:
|
||||||
"""Migrates the database to the latest version."""
|
"""Migrates the database to the latest version."""
|
||||||
with self._lock:
|
with self._db.lock:
|
||||||
# This throws if there is a problem.
|
# This throws if there is a problem.
|
||||||
self._migration_set.validate_migration_chain()
|
self._migration_set.validate_migration_chain()
|
||||||
self._create_migrations_table(cursor=self._cursor)
|
cursor = self._db.conn.cursor()
|
||||||
|
self._create_migrations_table(cursor=cursor)
|
||||||
|
|
||||||
if self._migration_set.count == 0:
|
if self._migration_set.count == 0:
|
||||||
self._logger.debug("No migrations registered")
|
self._logger.debug("No migrations registered")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if self._get_current_version(self._cursor) == self._migration_set.latest_version:
|
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
|
||||||
self._logger.debug("Database is up to date, no migrations to run")
|
self._logger.debug("Database is up to date, no migrations to run")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._logger.info("Database update needed")
|
self._logger.info("Database update needed")
|
||||||
|
|
||||||
if self._db_path:
|
if self._db.db_path:
|
||||||
# We are using a file database. Create a copy of the database to run the migrations on.
|
# We are using a file database. Create a copy of the database to run the migrations on.
|
||||||
temp_db_path = self._create_temp_db(self._db_path)
|
temp_db_path = self._create_temp_db(self._db.db_path)
|
||||||
self._logger.info(f"Copied database to {temp_db_path} for migration")
|
self._logger.info(f"Copied database to {temp_db_path} for migration")
|
||||||
temp_db_conn = sqlite3.connect(temp_db_path)
|
temp_db = SqliteDatabase(db_path=temp_db_path, logger=self._logger, verbose=self._db.verbose)
|
||||||
# We have to re-set this because we just created a new connection.
|
temp_db_cursor = temp_db.conn.cursor()
|
||||||
if self._log_sql:
|
|
||||||
temp_db_conn.set_trace_callback(self._logger.debug)
|
|
||||||
temp_db_cursor = temp_db_conn.cursor()
|
|
||||||
self._run_migrations(temp_db_cursor)
|
self._run_migrations(temp_db_cursor)
|
||||||
# Close the connections, copy the original database as a backup, and move the temp database to the
|
# Close the connections, copy the original database as a backup, and move the temp database to the
|
||||||
# original database's path.
|
# original database's path.
|
||||||
temp_db_conn.close()
|
temp_db.close()
|
||||||
self._conn.close()
|
self._db.close()
|
||||||
backup_db_path = self._finalize_migration(
|
backup_db_path = self._finalize_migration(
|
||||||
temp_db_path=temp_db_path,
|
temp_db_path=temp_db_path,
|
||||||
original_db_path=self._db_path,
|
original_db_path=self._db.db_path,
|
||||||
)
|
)
|
||||||
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
|
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
|
||||||
else:
|
else:
|
||||||
# We are using a memory database. No special backup or special handling needed.
|
# We are using a memory database. No special backup or special handling needed.
|
||||||
self._run_migrations(self._cursor)
|
self._run_migrations(cursor)
|
||||||
|
|
||||||
self._logger.info("Database updated successfully")
|
self._logger.info("Database updated successfully")
|
||||||
return True
|
return True
|
||||||
@ -108,7 +94,7 @@ class SQLiteMigrator:
|
|||||||
|
|
||||||
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
|
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
|
||||||
"""Runs a single migration."""
|
"""Runs a single migration."""
|
||||||
with self._lock:
|
with self._db.lock:
|
||||||
try:
|
try:
|
||||||
if self._get_current_version(temp_db_cursor) != migration.from_version:
|
if self._get_current_version(temp_db_cursor) != migration.from_version:
|
||||||
raise MigrationError(
|
raise MigrationError(
|
||||||
@ -145,7 +131,7 @@ class SQLiteMigrator:
|
|||||||
|
|
||||||
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
||||||
"""Creates the migrations table for the database, if one does not already exist."""
|
"""Creates the migrations table for the database, if one does not already exist."""
|
||||||
with self._lock:
|
with self._db.lock:
|
||||||
try:
|
try:
|
||||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||||
if cursor.fetchone() is not None:
|
if cursor.fetchone() is not None:
|
||||||
|
@ -55,13 +55,7 @@ def mock_services() -> InvocationServices:
|
|||||||
logger = InvokeAILogger.get_logger()
|
logger = InvokeAILogger.get_logger()
|
||||||
db_path = None if configuration.use_memory_db else configuration.db_path
|
db_path = None if configuration.use_memory_db else configuration.db_path
|
||||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
|
||||||
migrator = SQLiteMigrator(
|
migrator = SQLiteMigrator(db=db)
|
||||||
db_path=db.db_path,
|
|
||||||
conn=db.conn,
|
|
||||||
lock=db.lock,
|
|
||||||
logger=logger,
|
|
||||||
log_sql=configuration.log_sql,
|
|
||||||
)
|
|
||||||
migrator.register_migration(migration_1)
|
migrator.register_migration(migration_1)
|
||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
@ -59,13 +59,7 @@ def mock_services() -> InvocationServices:
|
|||||||
logger = InvokeAILogger.get_logger()
|
logger = InvokeAILogger.get_logger()
|
||||||
db_path = None if configuration.use_memory_db else configuration.db_path
|
db_path = None if configuration.use_memory_db else configuration.db_path
|
||||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
|
||||||
migrator = SQLiteMigrator(
|
migrator = SQLiteMigrator(db=db)
|
||||||
db_path=db.db_path,
|
|
||||||
conn=db.conn,
|
|
||||||
lock=db.lock,
|
|
||||||
logger=logger,
|
|
||||||
log_sql=configuration.log_sql,
|
|
||||||
)
|
|
||||||
migrator.register_migration(migration_1)
|
migrator.register_migration(migration_1)
|
||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
@ -44,13 +44,7 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
|||||||
logger = InvokeAILogger.get_logger(config=app_config)
|
logger = InvokeAILogger.get_logger(config=app_config)
|
||||||
db_path = None if app_config.use_memory_db else app_config.db_path
|
db_path = None if app_config.use_memory_db else app_config.db_path
|
||||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql)
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql)
|
||||||
migrator = SQLiteMigrator(
|
migrator = SQLiteMigrator(db=db)
|
||||||
db_path=db.db_path,
|
|
||||||
conn=db.conn,
|
|
||||||
lock=db.lock,
|
|
||||||
logger=logger,
|
|
||||||
log_sql=app_config.log_sql,
|
|
||||||
)
|
|
||||||
migrator.register_migration(migration_1)
|
migrator.register_migration(migration_1)
|
||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
@ -35,13 +35,7 @@ def store(datadir: Any) -> ModelRecordServiceBase:
|
|||||||
logger = InvokeAILogger.get_logger(config=config)
|
logger = InvokeAILogger.get_logger(config=config)
|
||||||
db_path = None if config.use_memory_db else config.db_path
|
db_path = None if config.use_memory_db else config.db_path
|
||||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
||||||
migrator = SQLiteMigrator(
|
migrator = SQLiteMigrator(db=db)
|
||||||
db_path=db.db_path,
|
|
||||||
conn=db.conn,
|
|
||||||
lock=db.lock,
|
|
||||||
logger=logger,
|
|
||||||
log_sql=config.log_sql,
|
|
||||||
)
|
|
||||||
migrator.register_migration(migration_1)
|
migrator.register_migration(migration_1)
|
||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -8,7 +7,7 @@ from tempfile import TemporaryDirectory
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
|
||||||
MigrateCallback,
|
MigrateCallback,
|
||||||
Migration,
|
Migration,
|
||||||
@ -27,14 +26,9 @@ def logger() -> Logger:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def lock() -> threading.RLock:
|
def migrator(logger: Logger) -> SQLiteMigrator:
|
||||||
return threading.RLock()
|
db = SqliteDatabase(db_path=None, logger=logger, verbose=False)
|
||||||
|
return SQLiteMigrator(db=db)
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def migrator(logger: Logger, lock: threading.RLock) -> SQLiteMigrator:
|
|
||||||
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
|
||||||
return SQLiteMigrator(conn=conn, db_path=None, lock=lock, logger=logger)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -176,50 +170,55 @@ def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op:
|
|||||||
|
|
||||||
|
|
||||||
def test_migrator_creates_migrations_table(migrator: SQLiteMigrator) -> None:
|
def test_migrator_creates_migrations_table(migrator: SQLiteMigrator) -> None:
|
||||||
migrator._create_migrations_table(migrator._cursor)
|
cursor = migrator._db.conn.cursor()
|
||||||
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';")
|
migrator._create_migrations_table(cursor)
|
||||||
assert migrator._cursor.fetchone() is not None
|
cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||||
|
assert cursor.fetchone() is not None
|
||||||
|
|
||||||
|
|
||||||
def test_migrator_migration_sets_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
|
def test_migrator_migration_sets_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
|
||||||
migrator._create_migrations_table(migrator._cursor)
|
cursor = migrator._db.conn.cursor()
|
||||||
|
migrator._create_migrations_table(cursor)
|
||||||
migrator.register_migration(migration_no_op)
|
migrator.register_migration(migration_no_op)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
migrator._cursor.execute("SELECT MAX(version) FROM migrations;")
|
cursor.execute("SELECT MAX(version) FROM migrations;")
|
||||||
assert migrator._cursor.fetchone()[0] == 1
|
assert cursor.fetchone()[0] == 1
|
||||||
|
|
||||||
|
|
||||||
def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
|
def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
|
||||||
assert migrator._get_current_version(migrator._cursor) == 0
|
cursor = migrator._db.conn.cursor()
|
||||||
migrator._create_migrations_table(migrator._cursor)
|
assert migrator._get_current_version(cursor) == 0
|
||||||
assert migrator._get_current_version(migrator._cursor) == 0
|
migrator._create_migrations_table(cursor)
|
||||||
|
assert migrator._get_current_version(cursor) == 0
|
||||||
migrator.register_migration(migration_no_op)
|
migrator.register_migration(migration_no_op)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
assert migrator._get_current_version(migrator._cursor) == 1
|
assert migrator._get_current_version(cursor) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
|
def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
|
||||||
migrator._create_migrations_table(migrator._cursor)
|
cursor = migrator._db.conn.cursor()
|
||||||
migrator._run_migration(migration_create_test_table, migrator._cursor)
|
migrator._create_migrations_table(cursor)
|
||||||
assert migrator._get_current_version(migrator._cursor) == 1
|
migrator._run_migration(migration_create_test_table, cursor)
|
||||||
migrator._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
|
assert migrator._get_current_version(cursor) == 1
|
||||||
assert migrator._cursor.fetchone() is not None
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
|
||||||
|
assert cursor.fetchone() is not None
|
||||||
|
|
||||||
|
|
||||||
def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None:
|
def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None:
|
||||||
|
cursor = migrator._db.conn.cursor()
|
||||||
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
|
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
|
||||||
for migration in migrations:
|
for migration in migrations:
|
||||||
migrator.register_migration(migration)
|
migrator.register_migration(migration)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
assert migrator._get_current_version(migrator._cursor) == 3
|
assert migrator._get_current_version(cursor) == 3
|
||||||
|
|
||||||
|
|
||||||
def test_migrator_runs_all_migrations_file(logger: Logger, lock: threading.RLock) -> None:
|
def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
|
||||||
with TemporaryDirectory() as tempdir:
|
with TemporaryDirectory() as tempdir:
|
||||||
original_db_path = Path(tempdir) / "invokeai.db"
|
original_db_path = Path(tempdir) / "invokeai.db"
|
||||||
# The Migrator closes the database when it finishes; we cannot use a context manager.
|
# The Migrator closes the database when it finishes; we cannot use a context manager.
|
||||||
original_db_conn = sqlite3.connect(original_db_path)
|
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
|
||||||
migrator = SQLiteMigrator(conn=original_db_conn, db_path=original_db_path, lock=lock, logger=logger)
|
migrator = SQLiteMigrator(db=db)
|
||||||
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
|
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
|
||||||
for migration in migrations:
|
for migration in migrations:
|
||||||
migrator.register_migration(migration)
|
migrator.register_migration(migration)
|
||||||
@ -272,18 +271,20 @@ def test_migrator_finalizes() -> None:
|
|||||||
def test_migrator_makes_no_changes_on_failed_migration(
|
def test_migrator_makes_no_changes_on_failed_migration(
|
||||||
migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
|
migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
|
||||||
) -> None:
|
) -> None:
|
||||||
|
cursor = migrator._db.conn.cursor()
|
||||||
migrator.register_migration(migration_no_op)
|
migrator.register_migration(migration_no_op)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
assert migrator._get_current_version(migrator._cursor) == 1
|
assert migrator._get_current_version(cursor) == 1
|
||||||
migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback))
|
migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback))
|
||||||
with pytest.raises(MigrationError, match="Bad migration"):
|
with pytest.raises(MigrationError, match="Bad migration"):
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
assert migrator._get_current_version(migrator._cursor) == 1
|
assert migrator._get_current_version(cursor) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_idempotent_migrations(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
|
def test_idempotent_migrations(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
|
||||||
|
cursor = migrator._db.conn.cursor()
|
||||||
migrator.register_migration(migration_create_test_table)
|
migrator.register_migration(migration_create_test_table)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
# not throwing is sufficient
|
# not throwing is sufficient
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
assert migrator._get_current_version(migrator._cursor) == 1
|
assert migrator._get_current_version(cursor) == 1
|
||||||
|
Loading…
Reference in New Issue
Block a user