feat(db): add SQLiteMigrator to perform db migrations

This commit is contained in:
psychedelicious 2023-12-04 23:38:56 +11:00
parent ef807cf63a
commit f2c6819d68
13 changed files with 868 additions and 516 deletions

View File

@ -20,63 +20,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self._conn = db.conn self._conn = db.conn
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `board_images` junction table."""
# Create the `board_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for board id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
"""
)
# Add index for board id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
)
def add_image_to_board( def add_image_to_board(
self, self,
board_id: str, board_id: str,

View File

@ -28,52 +28,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self._conn = db.conn self._conn = db.conn
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `boards` table and `board_images` junction table."""
# Create the `boards` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
)
def delete(self, board_id: str) -> None: def delete(self, board_id: str) -> None:
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -32,101 +32,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._conn = db.conn self._conn = db.conn
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `images` table."""
# Create the `images` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "starred" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
)
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "has_workflow" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;
"""
)
def get(self, image_name: str) -> ImageRecord: def get(self, image_name: str) -> ImageRecord:
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -9,9 +9,6 @@ from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2.0"
class DuplicateModelException(Exception): class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice.""" """Raised on an attempt to add a model with the same key twice."""
@ -32,12 +29,6 @@ class ConfigFileVersionMismatchException(Exception):
class ModelRecordServiceBase(ABC): class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs.""" """Abstract base class for storage and retrieval of model configs."""
@property
@abstractmethod
def version(self) -> str:
"""Return the config file/database schema version."""
pass
@abstractmethod @abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
""" """

View File

@ -54,7 +54,6 @@ from invokeai.backend.model_manager.config import (
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import ( from .model_records_base import (
CONFIG_FILE_VERSION,
DuplicateModelException, DuplicateModelException,
ModelRecordServiceBase, ModelRecordServiceBase,
UnknownModelException, UnknownModelException,
@ -78,85 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db self._db = db
self._cursor = self._db.conn.cursor() self._cursor = self._db.conn.cursor()
with self._db.lock:
# Enable foreign keys
self._db.conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._db.conn.commit()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
# model_config table breaks out the fields that are common to all config objects
# and puts class-specific ones in a serialized json object
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
# metadata table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
# Add indexes for searchable fields
for stmt in [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]:
self._cursor.execute(stmt)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
INSERT OR IGNORE into model_manager_metadata (
metadata_key,
metadata_value
)
VALUES (?,?);
""",
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
""" """
Add a model to the database. Add a model to the database.
@ -214,22 +134,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return self.get_model(key) return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
WHERE metadata_key=?;
""",
("version",),
)
rows = self._cursor.fetchone()
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
def del_model(self, key: str) -> None: def del_model(self, key: str) -> None:
""" """
Delete a model. Delete a model.

View File

@ -50,7 +50,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock = db.lock self.__lock = db.lock
self.__conn = db.conn self.__conn = db.conn
self.__cursor = self.__conn.cursor() self.__cursor = self.__conn.cursor()
self._create_tables()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in return event[1]["event"] in match_in
@ -98,123 +97,6 @@ class SqliteSessionQueue(SessionQueueBase):
except SessionQueueItemNotFoundError: except SessionQueueItemNotFoundError:
return return
def _create_tables(self) -> None:
"""Creates the session queue tables, indicies, and triggers"""
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
"""
)
self.__cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in self.__cursor.fetchall()]
if "workflow" not in columns:
self.__cursor.execute(
"""--sql
ALTER TABLE session_queue ADD COLUMN workflow TEXT;
"""
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _set_in_progress_to_canceled(self) -> None: def _set_in_progress_to_canceled(self) -> None:
""" """
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue. Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.

View File

@ -0,0 +1,370 @@
import sqlite3
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 1."""
_create_board_images(cursor)
_create_boards(cursor)
_create_images(cursor)
_create_model_config(cursor)
_create_session_queue(cursor)
_create_workflow_images(cursor)
_create_workflows(cursor)
def _create_board_images(cursor: sqlite3.Cursor) -> None:
"""Creates the `board_images` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);",
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_boards(cursor: sqlite3.Cursor) -> None:
"""Creates the `boards` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
]
indices = ["CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);"]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_images(cursor: sqlite3.Cursor) -> None:
"""Creates the `images` table, indices and triggers. Adds the `starred` column."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
]
indices = [
"CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);",
"CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);",
"CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);",
"CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
]
# Add the 'starred' column to `images` if it doesn't exist
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "starred" not in columns:
tables.append("ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;")
indices.append("CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);")
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_model_config(cursor: sqlite3.Cursor) -> None:
"""Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
""",
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
""",
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_session_queue(cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);",
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
""",
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
""",
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
""",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflow_images(cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflow_images (
workflow_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between workflows and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);",
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
AFTER UPDATE
ON workflow_images FOR EACH ROW
BEGIN
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflows(cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflows (
workflow TEXT NOT NULL,
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
);
"""
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
AFTER UPDATE
ON workflows FOR EACH ROW
BEGIN
UPDATE workflows
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
]
for stmt in tables + triggers:
cursor.execute(stmt)
migration_1 = Migration(db_version=1, app_version="3.4.0", migrate=_migrate)
"""
Database version 1 (initial state).
This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
version to not use migrations to manage the database.
As such, this migration does include some ALTER statements, and the SQL statements are written
to be idempotent.
- Create `board_images` junction table
- Create `boards` table
- Create `images` table, add `starred` column
- Create `model_config` table
- Create `session_queue` table
- Create `workflow_images` junction table
- Create `workflows` table
"""

View File

@ -0,0 +1,97 @@
import sqlite3
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 2."""
_add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor)
_add_workflow_library(cursor)
_drop_model_manager_metadata(cursor)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `has_workflow` column to `images` table."""
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `workflow` column to `session_queue` table."""
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
"""Drops the `workflows` and `workflow_images` tables."""
cursor.execute("DROP TABLE workflow_images;")
cursor.execute("DROP TABLE workflows;")
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
tables = [
"""--sql
CREATE TABLE workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated manually when retrieving workflow
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Generated columns, needed for indexing and searching
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
);
""",
]
indices = [
"CREATE INDEX idx_workflow_library_created_at ON workflow_library(created_at);",
"CREATE INDEX idx_workflow_library_updated_at ON workflow_library(updated_at);",
"CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);",
"CREATE INDEX idx_workflow_library_category ON workflow_library(category);",
"CREATE INDEX idx_workflow_library_name ON workflow_library(name);",
"CREATE INDEX idx_workflow_library_description ON workflow_library(description);",
]
triggers = [
"""--sql
CREATE TRIGGER tg_workflow_library_updated_at
AFTER UPDATE
ON workflow_library FOR EACH ROW
BEGIN
UPDATE workflow_library
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE model_manager_metadata;")
migration_2 = Migration(
db_version=2,
app_version="3.5.0",
migrate=_migrate,
)
"""
Database version 2.
Introduced in v3.5.0 for the new workflow library.
- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
"""

View File

@ -4,24 +4,27 @@ from logging import Logger
from pathlib import Path from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
from invokeai.app.services.shared.sqlite.sqlite_migrator import SQLiteMigrator
class SqliteDatabase: class SqliteDatabase:
database: Path | str
def __init__(self, config: InvokeAIAppConfig, logger: Logger): def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger self._logger = logger
self._config = config self._config = config
if self._config.use_memory_db: if self._config.use_memory_db:
self.db_path = sqlite_memory self.database = sqlite_memory
logger.info("Using in-memory database") logger.info("Using in-memory database")
else: else:
db_path = self._config.db_path self.database = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True) self.database.parent.mkdir(parents=True, exist_ok=True)
self.db_path = str(db_path) self._logger.info(f"Using database at {self.database}")
self._logger.info(f"Using database at {self.db_path}")
self.conn = sqlite3.connect(self.db_path, check_same_thread=False) self.conn = sqlite3.connect(database=self.database, check_same_thread=False)
self.lock = threading.RLock() self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row self.conn.row_factory = sqlite3.Row
@ -30,15 +33,20 @@ class SqliteDatabase:
self.conn.execute("PRAGMA foreign_keys = ON;") self.conn.execute("PRAGMA foreign_keys = ON;")
migrator = SQLiteMigrator(conn=self.conn, database=self.database, lock=self.lock, logger=self._logger)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
def clean(self) -> None: def clean(self) -> None:
with self.lock: with self.lock:
try: try:
if self.db_path == sqlite_memory: if self.database == sqlite_memory:
return return
initial_db_size = Path(self.db_path).stat().st_size initial_db_size = Path(self.database).stat().st_size
self.conn.execute("VACUUM;") self.conn.execute("VACUUM;")
self.conn.commit() self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size final_db_size = Path(self.database).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2) freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0: if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)") self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")

View File

@ -0,0 +1,210 @@
import shutil
import sqlite3
import threading
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Callable, Optional, TypeAlias
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
class MigrationError(Exception):
"""Raised when a migration fails."""
class MigrationVersionError(ValueError, MigrationError):
"""Raised when a migration version is invalid."""
class Migration:
"""Represents a migration for a SQLite database.
:param db_version: The database schema version this migration results in.
:param app_version: The app version this migration is introduced in.
:param migrate: The callback to run to perform the migration. The callback will be passed a
cursor to the database. The migrator will manage locking database access and committing the
transaction; the callback should not do either of these things.
"""
def __init__(
self,
db_version: int,
app_version: str,
migrate: MigrateCallback,
) -> None:
self.db_version = db_version
self.app_version = app_version
self.migrate = migrate
class SQLiteMigrator:
"""
Manages migrations for a SQLite database.
:param conn: The database connection.
:param database: The path to the database file, or ":memory:" for an in-memory database.
:param lock: A lock to use when accessing the database.
:param logger: The logger to use.
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
order of their version number. If the database is already at the latest version, no migrations
will be run.
"""
def __init__(self, conn: sqlite3.Connection, database: Path | str, lock: threading.RLock, logger: Logger) -> None:
self._logger = logger
self._conn = conn
self._cursor = self._conn.cursor()
self._lock = lock
self._database = database
self._migrations: set[Migration] = set()
def register_migration(self, migration: Migration) -> None:
"""Registers a migration."""
if not isinstance(migration.db_version, int) or migration.db_version < 1:
raise MigrationVersionError(f"Invalid migration version {migration.db_version}")
if any(m.db_version == migration.db_version for m in self._migrations):
raise MigrationVersionError(f"Migration version {migration.db_version} already registered")
self._migrations.add(migration)
self._logger.debug(f"Registered migration {migration.db_version}")
def run_migrations(self) -> None:
"""Migrates the database to the latest version."""
with self._lock:
self._create_version_table()
sorted_migrations = sorted(self._migrations, key=lambda m: m.db_version)
current_version = self._get_current_version()
if len(sorted_migrations) == 0:
self._logger.debug("No migrations registered")
return
latest_version = sorted_migrations[-1].db_version
if current_version == latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return
if current_version > latest_version:
raise MigrationError(
f"Database version {current_version} is greater than the latest migration version {latest_version}"
)
self._logger.info("Database update needed")
# Only make a backup if using a file database (not memory)
backup_path: Optional[Path] = None
if isinstance(self._database, Path):
backup_path = self._backup_db(self._database)
else:
self._logger.info("Using in-memory database, skipping backup")
for migration in sorted_migrations:
try:
self._run_migration(migration)
except MigrationError:
if backup_path is not None:
self._logger.error(f" Restoring from {backup_path}")
self._restore_db(backup_path)
raise
self._logger.info("Database updated successfully")
def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration."""
with self._lock:
current_version = self._get_current_version()
try:
if current_version >= migration.db_version:
return
migration.migrate(self._cursor)
# Migration callbacks only get a cursor; they cannot commit the transaction.
self._conn.commit()
self._set_version(db_version=migration.db_version, app_version=migration.app_version)
self._logger.debug(f"Successfully migrated database from {current_version} to {migration.db_version}")
except Exception as e:
msg = f"Error migrating database from {current_version} to {migration.db_version}: {e}"
self._conn.rollback()
self._logger.error(msg)
raise MigrationError(msg) from e
def _create_version_table(self) -> None:
"""Creates a version table for the database, if one does not already exist."""
with self._lock:
try:
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='version';")
if self._cursor.fetchone() is not None:
return
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS version (
db_version INTEGER PRIMARY KEY,
app_version TEXT NOT NULL,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
self._cursor.execute("INSERT INTO version (db_version, app_version) VALUES (?,?);", (0, "0.0.0"))
self._conn.commit()
self._logger.debug("Created version table")
except sqlite3.Error as e:
msg = f"Problem creation version table: {e}"
self._logger.error(msg)
self._conn.rollback()
raise MigrationError(msg) from e
def _get_current_version(self) -> int:
"""Gets the current version of the database, or 0 if the version table does not exist."""
with self._lock:
try:
self._cursor.execute("SELECT MAX(db_version) FROM version;")
version = self._cursor.fetchone()[0]
if version is None:
return 0
return version
except sqlite3.OperationalError as e:
if "no such table" in str(e):
return 0
raise
def _set_version(self, db_version: int, app_version: str) -> None:
"""Adds a version entry to the table's version table."""
with self._lock:
try:
self._cursor.execute(
"INSERT INTO version (db_version, app_version) VALUES (?,?);", (db_version, app_version)
)
self._conn.commit()
except sqlite3.Error as e:
msg = f"Problem setting database version: {e}"
self._logger.error(msg)
self._conn.rollback()
raise MigrationError(msg) from e
def _backup_db(self, db_path: Path | str) -> Path:
"""Backs up the databse, returning the path to the backup file."""
if db_path == sqlite_memory:
raise MigrationError("Cannot back up memory database")
if not isinstance(db_path, Path):
raise MigrationError(f'Database path must be "{sqlite_memory}" or a Path')
with self._lock:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = db_path.parent / f"{db_path.stem}_{timestamp}.db"
self._logger.info(f"Backing up database to {backup_path}")
backup_conn = sqlite3.connect(backup_path)
with backup_conn:
self._conn.backup(backup_conn)
backup_conn.close()
return backup_path
def _restore_db(self, backup_path: Path) -> None:
"""Restores the database from a backup file, unless the database is a memory database."""
if self._database == sqlite_memory:
return
with self._lock:
self._logger.info(f"Restoring database from {backup_path}")
self._conn.close()
if not Path(backup_path).is_file():
raise FileNotFoundError(f"Backup file {backup_path} does not exist")
shutil.copy2(backup_path, self._database)

View File

@ -26,7 +26,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self._lock = db.lock self._lock = db.lock
self._conn = db.conn self._conn = db.conn
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._create_tables()
def start(self, invoker: Invoker) -> None: def start(self, invoker: Invoker) -> None:
self._invoker = invoker self._invoker = invoker
@ -233,87 +232,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
raise raise
finally: finally:
self._lock.release() self._lock.release()
def _create_tables(self) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated manually when retrieving workflow
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Generated columns, needed for indexing and searching
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
AFTER UPDATE
ON workflow_library FOR EACH ROW
BEGIN
UPDATE workflow_library
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);
"""
)
# We do not need the original `workflows` table or `workflow_images` junction table.
self._cursor.execute(
"""--sql
DROP TABLE IF EXISTS workflow_images;
"""
)
self._cursor.execute(
"""--sql
DROP TABLE IF EXISTS workflows;
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()

View File

@ -0,0 +1,173 @@
import sqlite3
import threading
from copy import deepcopy
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable
import pytest
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
from invokeai.app.services.shared.sqlite.sqlite_migrator import (
Migration,
MigrationError,
MigrationVersionError,
SQLiteMigrator,
)
@pytest.fixture
def migrator() -> SQLiteMigrator:
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
return SQLiteMigrator(
conn=conn, database=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
)
@pytest.fixture
def good_migration() -> Migration:
return Migration(db_version=1, app_version="1.0.0", migrate=lambda cursor: None)
@pytest.fixture
def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor) -> None:
raise Exception("Bad migration")
return Migration(db_version=1, app_version="1.0.0", migrate=failing_migration)
def test_register_migration(migrator: SQLiteMigrator, good_migration: Migration):
migration = good_migration
migrator.register_migration(migration)
assert migration in migrator._migrations
with pytest.raises(MigrationError, match="Invalid migration version"):
migrator.register_migration(Migration(db_version=0, app_version="0.0.0", migrate=lambda cursor: None))
def test_register_invalid_migration_version(migrator: SQLiteMigrator):
with pytest.raises(MigrationError, match="Invalid migration version"):
migrator.register_migration(Migration(db_version=0, app_version="0.0.0", migrate=lambda cursor: None))
def test_create_version_table(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
assert migrator._cursor.fetchone() is not None
def test_get_current_version(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._conn.commit()
assert migrator._get_current_version() == 0 # initial version
def test_set_version(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._set_version(db_version=1, app_version="1.0.0")
migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
assert migrator._cursor.fetchone()[0] == 1
migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;")
assert migrator._cursor.fetchone()[0] == "1.0.0"
def test_run_migration(migrator: SQLiteMigrator):
migrator._create_version_table()
def migration_callback(cursor: sqlite3.Cursor) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
migration = Migration(db_version=1, app_version="1.0.0", migrate=migration_callback)
migrator._run_migration(migration)
assert migrator._get_current_version() == 1
migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;")
assert migrator._cursor.fetchone()[0] == "1.0.0"
def test_run_migrations(migrator: SQLiteMigrator):
migrator._create_version_table()
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
def migrate(cursor: sqlite3.Cursor) -> None:
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
return migrate
migrations = [Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in range(1, 4)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
assert migrator._get_current_version() == 3
def test_backup_and_restore_db(migrator: SQLiteMigrator):
with TemporaryDirectory() as tempdir:
# must do this with a file database - we don't backup/restore for memory
database = Path(tempdir) / "test.db"
migrator._database = database
migrator._cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
migrator._conn.commit()
backup_path = migrator._backup_db(migrator._database)
migrator._cursor.execute("DROP TABLE test;")
migrator._conn.commit()
migrator._restore_db(backup_path) # this closes the connection
# reconnect to db
restored_conn = sqlite3.connect(database)
restored_cursor = restored_conn.cursor()
restored_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert restored_cursor.fetchone() is not None
def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator):
with pytest.raises(MigrationError, match="Cannot back up memory database"):
migrator._backup_db(sqlite_memory)
def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
migrator._create_version_table()
with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
migrator._run_migration(failing_migration)
assert migrator._get_current_version() == 0
def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_version_table()
migrator.register_migration(good_migration)
with pytest.raises(MigrationVersionError, match="already registered"):
migrator.register_migration(deepcopy(good_migration))
def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
migrator._create_version_table()
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
def migrate(cursor: sqlite3.Cursor) -> None:
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
return migrate
migrations = [
Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in reversed(range(1, 4))
]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
assert migrator._get_current_version() == 3
def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_version_table()
migrator.register_migration(good_migration)
migrator._set_version(db_version=2, app_version="2.0.0")
with pytest.raises(MigrationError, match="greater than the latest migration version"):
migrator.run_migrations()
assert migrator._get_current_version() == 2
def test_idempotent_migrations(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_version_table()
migrator.register_migration(good_migration)
migrator.run_migrations()
migrator.run_migrations()
assert migrator._get_current_version() == 1