mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): instantiate SqliteMigrator with a SqliteDatabase
Simplifies a couple things: - Init is more straightforward - It's clear in the migrator that the connection we are working with is related to the SqliteDatabase
This commit is contained in:
@ -1,11 +1,10 @@
|
||||
import shutil
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet
|
||||
|
||||
|
||||
@ -30,26 +29,15 @@ class SQLiteMigrator:
|
||||
|
||||
backup_path: Optional[Path] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: Path | None,
|
||||
conn: sqlite3.Connection,
|
||||
lock: threading.RLock,
|
||||
logger: Logger,
|
||||
log_sql: bool = False,
|
||||
) -> None:
|
||||
self._lock = lock
|
||||
self._db_path = db_path
|
||||
self._logger = logger
|
||||
self._conn = conn
|
||||
self._log_sql = log_sql
|
||||
self._cursor = self._conn.cursor()
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
self._db = db
|
||||
self._logger = db.logger
|
||||
self._migration_set = MigrationSet()
|
||||
|
||||
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
|
||||
if self._db_path and self._get_temp_db_path(self._db_path).is_file():
|
||||
if self._db.db_path and self._get_temp_db_path(self._db.db_path).is_file():
|
||||
self._logger.warning("Previous migration failed! Trying again...")
|
||||
self._get_temp_db_path(self._db_path).unlink()
|
||||
self._get_temp_db_path(self._db.db_path).unlink()
|
||||
|
||||
def register_migration(self, migration: Migration) -> None:
|
||||
"""Registers a migration."""
|
||||
@ -58,43 +46,41 @@ class SQLiteMigrator:
|
||||
|
||||
def run_migrations(self) -> bool:
|
||||
"""Migrates the database to the latest version."""
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
# This throws if there is a problem.
|
||||
self._migration_set.validate_migration_chain()
|
||||
self._create_migrations_table(cursor=self._cursor)
|
||||
cursor = self._db.conn.cursor()
|
||||
self._create_migrations_table(cursor=cursor)
|
||||
|
||||
if self._migration_set.count == 0:
|
||||
self._logger.debug("No migrations registered")
|
||||
return False
|
||||
|
||||
if self._get_current_version(self._cursor) == self._migration_set.latest_version:
|
||||
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
|
||||
self._logger.debug("Database is up to date, no migrations to run")
|
||||
return False
|
||||
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
if self._db_path:
|
||||
if self._db.db_path:
|
||||
# We are using a file database. Create a copy of the database to run the migrations on.
|
||||
temp_db_path = self._create_temp_db(self._db_path)
|
||||
temp_db_path = self._create_temp_db(self._db.db_path)
|
||||
self._logger.info(f"Copied database to {temp_db_path} for migration")
|
||||
temp_db_conn = sqlite3.connect(temp_db_path)
|
||||
# We have to re-set this because we just created a new connection.
|
||||
if self._log_sql:
|
||||
temp_db_conn.set_trace_callback(self._logger.debug)
|
||||
temp_db_cursor = temp_db_conn.cursor()
|
||||
temp_db = SqliteDatabase(db_path=temp_db_path, logger=self._logger, verbose=self._db.verbose)
|
||||
temp_db_cursor = temp_db.conn.cursor()
|
||||
self._run_migrations(temp_db_cursor)
|
||||
# Close the connections, copy the original database as a backup, and move the temp database to the
|
||||
# original database's path.
|
||||
temp_db_conn.close()
|
||||
self._conn.close()
|
||||
temp_db.close()
|
||||
self._db.close()
|
||||
backup_db_path = self._finalize_migration(
|
||||
temp_db_path=temp_db_path,
|
||||
original_db_path=self._db_path,
|
||||
original_db_path=self._db.db_path,
|
||||
)
|
||||
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
|
||||
else:
|
||||
# We are using a memory database. No special backup or special handling needed.
|
||||
self._run_migrations(self._cursor)
|
||||
self._run_migrations(cursor)
|
||||
|
||||
self._logger.info("Database updated successfully")
|
||||
return True
|
||||
@ -108,7 +94,7 @@ class SQLiteMigrator:
|
||||
|
||||
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
|
||||
"""Runs a single migration."""
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
if self._get_current_version(temp_db_cursor) != migration.from_version:
|
||||
raise MigrationError(
|
||||
@ -145,7 +131,7 @@ class SQLiteMigrator:
|
||||
|
||||
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the migrations table for the database, if one does not already exist."""
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
if cursor.fetchone() is not None:
|
||||
|
Reference in New Issue
Block a user