feat(db): move sqlite_migrator into its own module

psychedelicious
2023-12-11 16:41:30 +11:00
parent fa7d002175
commit 290851016e
8 changed files with 117 additions and 113 deletions

View File

@@ -0,0 +1,374 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
def _migrate(cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 1."""
_create_board_images(cursor)
_create_boards(cursor)
_create_images(cursor)
_create_model_config(cursor)
_create_session_queue(cursor)
_create_workflow_images(cursor)
_create_workflows(cursor)
def _create_board_images(cursor: sqlite3.Cursor) -> None:
"""Creates the `board_images` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);",
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_boards(cursor: sqlite3.Cursor) -> None:
"""Creates the `boards` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
]
indices = ["CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);"]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
            UPDATE boards SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_images(cursor: sqlite3.Cursor) -> None:
"""Creates the `images` table, indices and triggers. Adds the `starred` column."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
]
indices = [
"CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);",
"CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);",
"CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);",
"CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
]
# Add the 'starred' column to `images` if it doesn't exist
cursor.execute("PRAGMA table_info(images)")
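    # PRAGMA table_info returns one row per column: (cid, name, type, notnull, dflt_value, pk); index 1 is the name.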
columns = [column[1] for column in cursor.fetchall()]
if "starred" not in columns:
tables.append("ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;")
indices.append("CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);")
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_model_config(cursor: sqlite3.Cursor) -> None:
"""Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
""",
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
""",
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_session_queue(cursor: sqlite3.Cursor) -> None:
    """Creates the `session_queue` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);",
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
""",
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
""",
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
""",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
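# Illustrative SQL showing the status triggers in action (not executed here):
#   UPDATE session_queue SET status = 'in_progress' WHERE item_id = 1;  -- tg_session_queue_started_at stamps started_at
#   UPDATE session_queue SET status = 'completed' WHERE item_id = 1;    -- tg_session_queue_completed_at stamps completed_at
# tg_session_queue_updated_at also bumps updated_at on every UPDATE (recursive triggers are off by default in SQLite).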
def _create_workflow_images(cursor: sqlite3.Cursor) -> None:
    """Creates the `workflow_images` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflow_images (
workflow_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between workflows and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);",
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
AFTER UPDATE
ON workflow_images FOR EACH ROW
BEGIN
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflows(cursor: sqlite3.Cursor) -> None:
    """Creates the `workflows` table and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflows (
workflow TEXT NOT NULL,
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
);
"""
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
AFTER UPDATE
ON workflows FOR EACH ROW
BEGIN
UPDATE workflows
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
]
for stmt in tables + triggers:
cursor.execute(stmt)
migration_1 = Migration(
from_version=0,
to_version=1,
migrate=_migrate,
)
"""
Database version 1 (initial state).
This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
version that did not use migrations to manage the database.
As such, this migration includes some ALTER statements, and the SQL statements are written
to be idempotent.
- Create `board_images` junction table
- Create `boards` table
- Create `images` table, add `starred` column
- Create `model_config` table
- Create `session_queue` table
- Create `workflow_images` junction table
- Create `workflows` table
"""

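Because every statement in this migration is written to be idempotent, the callback above can be exercised directly against a scratch database. A minimal sketch (test-style, not part of this commit):

import sqlite3

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
_migrate(cursor)
_migrate(cursor)  # safe to run twice: every statement is IF NOT EXISTS or guarded by a PRAGMA check
cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name;")
print([row[0] for row in cursor.fetchall()])
# ['board_images', 'boards', 'images', 'model_config', 'model_manager_metadata', 'session_queue', 'workflow_images', 'workflows']
conn.close()
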
View File

@@ -0,0 +1,107 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
def _migrate(cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 2."""
_add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor)
_add_workflow_library(cursor)
_drop_model_manager_metadata(cursor)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `has_workflow` column to `images` table."""
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "has_workflow" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `workflow` column to `session_queue` table."""
cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in cursor.fetchall()]
if "workflow" not in columns:
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
"""Drops the `workflows` and `workflow_images` tables."""
cursor.execute("DROP TABLE IF EXISTS workflow_images;")
cursor.execute("DROP TABLE IF EXISTS workflows;")
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated manually when retrieving workflow
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Generated columns, needed for indexing and searching
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
);
""",
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
AFTER UPDATE
ON workflow_library FOR EACH ROW
BEGIN
UPDATE workflow_library
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
migration_2 = Migration(
from_version=1,
to_version=2,
migrate=_migrate,
)
"""
Database version 2.
Introduced in v3.5.0 for the new workflow library.
Migration:
- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
"""

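Together with migration_1, this forms the chain 0 -> 1 -> 2 that the migrator walks at startup. A sketch of the registration (the migrator normally does this internally via register_migration; MigrationSet, defined in sqlite_migrator_common below, is used directly here for illustration):

migrations = MigrationSet()
migrations.register(migration_1)
migrations.register(migration_2)
migrations.validate_migration_chain()  # raises MigrationError if a link is missing
assert migrations.latest_version == 2
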
View File

@@ -0,0 +1,41 @@
import sqlite3
from logging import Logger
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
def migrate_embedded_workflows(
cursor: sqlite3.Cursor,
logger: Logger,
image_files: ImageFileStorageBase,
) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
This migration callback checks each image for the presence of an embedded workflow, then updates its entry
in the database accordingly.
"""
    # Get all image names from the database
cursor.execute("SELECT image_name FROM images")
image_names: list[str] = [image[0] for image in cursor.fetchall()]
total_image_names = len(image_names)
if not total_image_names:
return
logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
pil_image = image_files.get(image_name)
if "invokeai_workflow" in pil_image.info:
to_migrate.append((True, image_name))
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
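
Because a MigrateCallback receives only a cursor, the logger and image_files dependencies must be bound ahead of time, as the Migration docstring below recommends. A sketch using functools.partial (the logger and image_files instances are assumed to come from the application):

from functools import partial

migration_2.register_post_callback(
    partial(migrate_embedded_workflows, logger=logger, image_files=image_files)
)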

View File

@@ -0,0 +1,109 @@
import sqlite3
from typing import Callable, Optional, TypeAlias
from pydantic import BaseModel, Field, model_validator
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
class MigrationError(RuntimeError):
"""Raised when a migration fails."""
class MigrationVersionError(ValueError):
"""Raised when a migration version is invalid."""
class Migration(BaseModel):
"""
Represents a migration for a SQLite database.
Migration callbacks will be provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator.
Pre- and post-migration callbacks may be registered with :meth:`register_pre_callback` or
:meth:`register_post_callback`.
If a migration has additional dependencies, it is recommended to use functools.partial to provide
the dependencies and register the partial as the migration callback.
"""
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
migrate: MigrateCallback = Field(description="The callback to run to perform the migration")
pre_migrate: list[MigrateCallback] = Field(
default=[], description="A list of callbacks to run before the migration"
)
post_migrate: list[MigrateCallback] = Field(
default=[], description="A list of callbacks to run after the migration"
)
@model_validator(mode="after")
def validate_to_version(self) -> "Migration":
if self.to_version <= self.from_version:
raise ValueError("to_version must be greater than from_version")
return self
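    # For illustration: this validator rejects a migration that does not advance the version.
    # E.g. Migration(from_version=1, to_version=1, migrate=cb) raises a pydantic ValidationError
    # (wrapping the ValueError above), where `cb` is any MigrateCallback. Hypothetical example.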
def __hash__(self) -> int:
        # Pydantic models are not hashable by default, so we implement __hash__ to use this class in a set.
return hash((self.from_version, self.to_version))
def register_pre_callback(self, callback: MigrateCallback) -> None:
"""Registers a pre-migration callback."""
self.pre_migrate.append(callback)
def register_post_callback(self, callback: MigrateCallback) -> None:
"""Registers a post-migration callback."""
self.post_migrate.append(callback)
class MigrationSet:
"""A set of Migrations. Performs validation during migration registration and provides utility methods."""
def __init__(self) -> None:
self._migrations: set[Migration] = set()
def register(self, migration: Migration) -> None:
"""Registers a migration."""
if any(m.from_version == migration.from_version for m in self._migrations):
raise MigrationVersionError(f"Migration from {migration.from_version} already registered")
if any(m.to_version == migration.to_version for m in self._migrations):
raise MigrationVersionError(f"Migration to {migration.to_version} already registered")
self._migrations.add(migration)
def get(self, from_version: int) -> Optional[Migration]:
"""Gets the migration that may be run on the given database version."""
# register() ensures that there is only one migration with a given from_version, so this is safe.
return next((m for m in self._migrations if m.from_version == from_version), None)
def validate_migration_chain(self) -> None:
"""
Validates that the migrations form a single chain of migrations from version 0 to the latest version.
Raises a MigrationError if there is a problem.
"""
if self.count == 0:
return
if self.latest_version == 0:
return
next_migration = self.get(from_version=0)
if next_migration is None:
raise MigrationError("Migration chain is fragmented")
touched_count = 1
while next_migration is not None:
next_migration = self.get(next_migration.to_version)
if next_migration is not None:
touched_count += 1
if touched_count != self.count:
raise MigrationError("Migration chain is fragmented")
@property
def count(self) -> int:
"""The count of registered migrations."""
return len(self._migrations)
@property
def latest_version(self) -> int:
"""Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
if self.count == 0:
return 0
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version

View File

@@ -0,0 +1,212 @@
import shutil
import sqlite3
import threading
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Optional
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet
class SQLiteMigrator:
"""
Manages migrations for a SQLite database.
:param db_path: The path to the database to migrate, or None if using an in-memory database.
:param conn: The connection to the database.
:param lock: A lock to use when running migrations.
:param logger: A logger to use for logging.
:param log_sql: Whether to log SQL statements. Only used when the log level is set to debug.
Migrations should be registered with :meth:`register_migration`.
During migration, a copy of the current database is made and the migrations are run on the copy. If the migration
is successful, the original database is backed up and the migrated database is moved to the original database's
path. If the migration fails, the original database is left untouched and the migrated database is deleted.
If the database is in-memory, no backup is made; the migration is run in-place.
"""
backup_path: Optional[Path] = None
def __init__(
self,
db_path: Path | None,
conn: sqlite3.Connection,
lock: threading.RLock,
logger: Logger,
log_sql: bool = False,
) -> None:
self._lock = lock
self._db_path = db_path
self._logger = logger
self._conn = conn
self._log_sql = log_sql
self._cursor = self._conn.cursor()
self._migration_set = MigrationSet()
        # The presence of a temp database file indicates a catastrophic failure of a previous migration.
if self._db_path and self._get_temp_db_path(self._db_path).is_file():
self._logger.warning("Previous migration failed! Trying again...")
self._get_temp_db_path(self._db_path).unlink()
def register_migration(self, migration: Migration) -> None:
"""Registers a migration."""
self._migration_set.register(migration)
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
def run_migrations(self) -> bool:
"""Migrates the database to the latest version."""
with self._lock:
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
self._create_migrations_table(cursor=self._cursor)
if self._migration_set.count == 0:
self._logger.debug("No migrations registered")
return False
if self._get_current_version(self._cursor) == self._migration_set.latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return False
self._logger.info("Database update needed")
if self._db_path:
# We are using a file database. Create a copy of the database to run the migrations on.
temp_db_path = self._create_temp_db(self._db_path)
self._logger.info(f"Copied database to {temp_db_path} for migration")
temp_db_conn = sqlite3.connect(temp_db_path)
# We have to re-set this because we just created a new connection.
if self._log_sql:
temp_db_conn.set_trace_callback(self._logger.debug)
temp_db_cursor = temp_db_conn.cursor()
self._run_migrations(temp_db_cursor)
# Close the connections, copy the original database as a backup, and move the temp database to the
# original database's path.
temp_db_conn.close()
self._conn.close()
backup_db_path = self._finalize_migration(
temp_db_path=temp_db_path,
original_db_path=self._db_path,
)
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
else:
                # We are using an in-memory database; no backup or special handling is needed.
self._run_migrations(self._cursor)
self._logger.info("Database updated successfully")
return True
def _run_migrations(self, temp_db_cursor: sqlite3.Cursor) -> None:
"""Runs all migrations in a loop."""
next_migration = self._migration_set.get(from_version=self._get_current_version(temp_db_cursor))
while next_migration is not None:
self._run_migration(next_migration, temp_db_cursor)
next_migration = self._migration_set.get(self._get_current_version(temp_db_cursor))
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
"""Runs a single migration."""
with self._lock:
try:
if self._get_current_version(temp_db_cursor) != migration.from_version:
raise MigrationError(
f"Database is at version {self._get_current_version(temp_db_cursor)}, expected {migration.from_version}"
)
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
# Run pre-migration callbacks
if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate:
callback(temp_db_cursor)
# Run the actual migration
migration.migrate(temp_db_cursor)
temp_db_cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
# Run post-migration callbacks
if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate:
callback(temp_db_cursor)
# Migration callbacks only get a cursor. Commit this migration.
temp_db_cursor.connection.commit()
self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
)
except Exception as e:
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
temp_db_cursor.connection.rollback()
self._logger.error(msg)
raise MigrationError(msg) from e
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the migrations table for the database, if one does not already exist."""
with self._lock:
try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None:
return
cursor.execute(
"""--sql
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
cursor.connection.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
cursor.connection.rollback()
raise MigrationError(msg) from e
@classmethod
def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:
"""Gets the current version of the database, or 0 if the migrations table does not exist."""
try:
cursor.execute("SELECT MAX(version) FROM migrations;")
version: int = cursor.fetchone()[0]
if version is None:
return 0
return version
except sqlite3.OperationalError as e:
if "no such table" in str(e):
return 0
raise
@classmethod
def _create_temp_db(cls, original_db_path: Path) -> Path:
"""Copies the current database to a new file for migration."""
temp_db_path = cls._get_temp_db_path(original_db_path)
shutil.copy2(original_db_path, temp_db_path)
return temp_db_path
@classmethod
def _finalize_migration(
cls,
temp_db_path: Path,
original_db_path: Path,
) -> Path:
"""Renames the original database as a backup and renames the migrated database to the original name."""
backup_db_path = cls._get_backup_db_path(original_db_path)
original_db_path.rename(backup_db_path)
temp_db_path.rename(original_db_path)
return backup_db_path
@classmethod
def _get_temp_db_path(cls, original_db_path: Path) -> Path:
"""Gets the path to the temp database."""
return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp")
@classmethod
def _get_backup_db_path(cls, original_db_path: Path) -> Path:
"""Gets the path to the final backup database."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"
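
A sketch of end-to-end usage against an in-memory database, following the class docstring above (migration_1 and migration_2 are assumed to be imported from the migration modules in this commit):

import logging
import sqlite3
import threading

conn = sqlite3.connect(":memory:")
migrator = SQLiteMigrator(
    db_path=None,  # in-memory: the migration runs in place and no backup is made
    conn=conn,
    lock=threading.RLock(),
    logger=logging.getLogger("sqlite_migrator"),
)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()  # returns True if any migration was run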