diff --git a/invokeai/app/services/board_image_records/board_image_records_sqlite.py b/invokeai/app/services/board_image_records/board_image_records_sqlite.py index 54d9a0af04..cde810a739 100644 --- a/invokeai/app/services/board_image_records/board_image_records_sqlite.py +++ b/invokeai/app/services/board_image_records/board_image_records_sqlite.py @@ -20,63 +20,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): self._conn = db.conn self._cursor = self._conn.cursor() - try: - self._lock.acquire() - self._create_tables() - self._conn.commit() - finally: - self._lock.release() - - def _create_tables(self) -> None: - """Creates the `board_images` junction table.""" - - # Create the `board_images` junction table. - self._cursor.execute( - """--sql - CREATE TABLE IF NOT EXISTS board_images ( - board_id TEXT NOT NULL, - image_name TEXT NOT NULL, - created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- updated via trigger - updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- Soft delete, currently unused - deleted_at DATETIME, - -- enforce one-to-many relationship between boards and images using PK - -- (we can extend this to many-to-many later) - PRIMARY KEY (image_name), - FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE, - FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE - ); - """ - ) - - # Add index for board id - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id); - """ - ) - - # Add index for board id, sorted by created_at - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at); - """ - ) - - # Add trigger for `updated_at`. - self._cursor.execute( - """--sql - CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at - AFTER UPDATE - ON board_images FOR EACH ROW - BEGIN - UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') - WHERE board_id = old.board_id AND image_name = old.image_name; - END; - """ - ) - def add_image_to_board( self, board_id: str, diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py index 165ce8df0c..a3836cb6c7 100644 --- a/invokeai/app/services/board_records/board_records_sqlite.py +++ b/invokeai/app/services/board_records/board_records_sqlite.py @@ -28,52 +28,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): self._conn = db.conn self._cursor = self._conn.cursor() - try: - self._lock.acquire() - self._create_tables() - self._conn.commit() - finally: - self._lock.release() - - def _create_tables(self) -> None: - """Creates the `boards` table and `board_images` junction table.""" - - # Create the `boards` table. - self._cursor.execute( - """--sql - CREATE TABLE IF NOT EXISTS boards ( - board_id TEXT NOT NULL PRIMARY KEY, - board_name TEXT NOT NULL, - cover_image_name TEXT, - created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- Updated via trigger - updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- Soft delete, currently unused - deleted_at DATETIME, - FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL - ); - """ - ) - - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at); - """ - ) - - # Add trigger for `updated_at`. - self._cursor.execute( - """--sql - CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at - AFTER UPDATE - ON boards FOR EACH ROW - BEGIN - UPDATE boards SET updated_at = current_timestamp - WHERE board_id = old.board_id; - END; - """ - ) - def delete(self, board_id: str) -> None: try: self._lock.acquire() diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index b14c322f50..74f82e7d84 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -32,101 +32,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): self._conn = db.conn self._cursor = self._conn.cursor() - try: - self._lock.acquire() - self._create_tables() - self._conn.commit() - finally: - self._lock.release() - - def _create_tables(self) -> None: - """Creates the `images` table.""" - - # Create the `images` table. - self._cursor.execute( - """--sql - CREATE TABLE IF NOT EXISTS images ( - image_name TEXT NOT NULL PRIMARY KEY, - -- This is an enum in python, unrestricted string here for flexibility - image_origin TEXT NOT NULL, - -- This is an enum in python, unrestricted string here for flexibility - image_category TEXT NOT NULL, - width INTEGER NOT NULL, - height INTEGER NOT NULL, - session_id TEXT, - node_id TEXT, - metadata TEXT, - is_intermediate BOOLEAN DEFAULT FALSE, - created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- Updated via trigger - updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- Soft delete, currently unused - deleted_at DATETIME - ); - """ - ) - - self._cursor.execute("PRAGMA table_info(images)") - columns = [column[1] for column in self._cursor.fetchall()] - - if "starred" not in columns: - self._cursor.execute( - """--sql - ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE; - """ - ) - - # Create the `images` table indices. - self._cursor.execute( - """--sql - CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name); - """ - ) - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin); - """ - ) - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category); - """ - ) - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at); - """ - ) - - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred); - """ - ) - - # Add trigger for `updated_at`. - self._cursor.execute( - """--sql - CREATE TRIGGER IF NOT EXISTS tg_images_updated_at - AFTER UPDATE - ON images FOR EACH ROW - BEGIN - UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') - WHERE image_name = old.image_name; - END; - """ - ) - - self._cursor.execute("PRAGMA table_info(images)") - columns = [column[1] for column in self._cursor.fetchall()] - if "has_workflow" not in columns: - self._cursor.execute( - """--sql - ALTER TABLE images - ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE; - """ - ) - def get(self, image_name: str) -> ImageRecord: try: self._lock.acquire() diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 9b0612a846..679d05fccd 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -9,9 +9,6 @@ from typing import List, Optional, Union from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType -# should match the InvokeAI version when this is first released. -CONFIG_FILE_VERSION = "3.2.0" - class DuplicateModelException(Exception): """Raised on an attempt to add a model with the same key twice.""" @@ -32,12 +29,6 @@ class ConfigFileVersionMismatchException(Exception): class ModelRecordServiceBase(ABC): """Abstract base class for storage and retrieval of model configs.""" - @property - @abstractmethod - def version(self) -> str: - """Return the config file/database schema version.""" - pass - @abstractmethod def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: """ diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 69ac0e158f..83b4d5b627 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -54,7 +54,6 @@ from invokeai.backend.model_manager.config import ( from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( - CONFIG_FILE_VERSION, DuplicateModelException, ModelRecordServiceBase, UnknownModelException, @@ -78,85 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): self._db = db self._cursor = self._db.conn.cursor() - with self._db.lock: - # Enable foreign keys - self._db.conn.execute("PRAGMA foreign_keys = ON;") - self._create_tables() - self._db.conn.commit() - assert ( - str(self.version) == CONFIG_FILE_VERSION - ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}" - - def _create_tables(self) -> None: - """Create sqlite3 tables.""" - # model_config table breaks out the fields that are common to all config objects - # and puts class-specific ones in a serialized json object - self._cursor.execute( - """--sql - CREATE TABLE IF NOT EXISTS model_config ( - id TEXT NOT NULL PRIMARY KEY, - -- The next 3 fields are enums in python, unrestricted string here - base TEXT NOT NULL, - type TEXT NOT NULL, - name TEXT NOT NULL, - path TEXT NOT NULL, - original_hash TEXT, -- could be null - -- Serialized JSON representation of the whole config object, - -- which will contain additional fields from subclasses - config TEXT NOT NULL, - created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- Updated via trigger - updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- unique constraint on combo of name, base and type - UNIQUE(name, base, type) - ); - """ - ) - - # metadata table - self._cursor.execute( - """--sql - CREATE TABLE IF NOT EXISTS model_manager_metadata ( - metadata_key TEXT NOT NULL PRIMARY KEY, - metadata_value TEXT NOT NULL - ); - """ - ) - - # Add trigger for `updated_at`. - self._cursor.execute( - """--sql - CREATE TRIGGER IF NOT EXISTS model_config_updated_at - AFTER UPDATE - ON model_config FOR EACH ROW - BEGIN - UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') - WHERE id = old.id; - END; - """ - ) - - # Add indexes for searchable fields - for stmt in [ - "CREATE INDEX IF NOT EXISTS base_index ON model_config(base);", - "CREATE INDEX IF NOT EXISTS type_index ON model_config(type);", - "CREATE INDEX IF NOT EXISTS name_index ON model_config(name);", - "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);", - ]: - self._cursor.execute(stmt) - - # Add our version to the metadata table - self._cursor.execute( - """--sql - INSERT OR IGNORE into model_manager_metadata ( - metadata_key, - metadata_value - ) - VALUES (?,?); - """, - ("version", CONFIG_FILE_VERSION), - ) - def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: """ Add a model to the database. @@ -214,22 +134,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): return self.get_model(key) - @property - def version(self) -> str: - """Return the version of the database schema.""" - with self._db.lock: - self._cursor.execute( - """--sql - SELECT metadata_value FROM model_manager_metadata - WHERE metadata_key=?; - """, - ("version",), - ) - rows = self._cursor.fetchone() - if not rows: - raise KeyError("Models database does not have metadata key 'version'") - return rows[0] - def del_model(self, key: str) -> None: """ Delete a model. diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 71f28c102b..64642690e9 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -50,7 +50,6 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock = db.lock self.__conn = db.conn self.__cursor = self.__conn.cursor() - self._create_tables() def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: return event[1]["event"] in match_in @@ -98,123 +97,6 @@ class SqliteSessionQueue(SessionQueueBase): except SessionQueueItemNotFoundError: return - def _create_tables(self) -> None: - """Creates the session queue tables, indicies, and triggers""" - try: - self.__lock.acquire() - self.__cursor.execute( - """--sql - CREATE TABLE IF NOT EXISTS session_queue ( - item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination - batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to - queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to - session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access - field_values TEXT, -- NULL if no values are associated with this queue item - session TEXT NOT NULL, -- the session to be executed - status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled' - priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important - error TEXT, -- any errors associated with this queue item - created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger - started_at DATETIME, -- updated via trigger - completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup - -- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE... - -- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE - ); - """ - ) - - self.__cursor.execute( - """--sql - CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id); - """ - ) - - self.__cursor.execute( - """--sql - CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id); - """ - ) - - self.__cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id); - """ - ) - - self.__cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority); - """ - ) - - self.__cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status); - """ - ) - - self.__cursor.execute( - """--sql - CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at - AFTER UPDATE OF status ON session_queue - FOR EACH ROW - WHEN - NEW.status = 'completed' - OR NEW.status = 'failed' - OR NEW.status = 'canceled' - BEGIN - UPDATE session_queue - SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') - WHERE item_id = NEW.item_id; - END; - """ - ) - - self.__cursor.execute( - """--sql - CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at - AFTER UPDATE OF status ON session_queue - FOR EACH ROW - WHEN - NEW.status = 'in_progress' - BEGIN - UPDATE session_queue - SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') - WHERE item_id = NEW.item_id; - END; - """ - ) - - self.__cursor.execute( - """--sql - CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at - AFTER UPDATE - ON session_queue FOR EACH ROW - BEGIN - UPDATE session_queue - SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') - WHERE item_id = old.item_id; - END; - """ - ) - - self.__cursor.execute("PRAGMA table_info(session_queue)") - columns = [column[1] for column in self.__cursor.fetchall()] - if "workflow" not in columns: - self.__cursor.execute( - """--sql - ALTER TABLE session_queue ADD COLUMN workflow TEXT; - """ - ) - - self.__conn.commit() - except Exception: - self.__conn.rollback() - raise - finally: - self.__lock.release() - def _set_in_progress_to_canceled(self) -> None: """ Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue. diff --git a/invokeai/app/services/shared/sqlite/migrations/__init__.py b/invokeai/app/services/shared/sqlite/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/app/services/shared/sqlite/migrations/migration_1.py b/invokeai/app/services/shared/sqlite/migrations/migration_1.py new file mode 100644 index 0000000000..52d10e095d --- /dev/null +++ b/invokeai/app/services/shared/sqlite/migrations/migration_1.py @@ -0,0 +1,370 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration + + +def _migrate(cursor: sqlite3.Cursor) -> None: + """Migration callback for database version 1.""" + + _create_board_images(cursor) + _create_boards(cursor) + _create_images(cursor) + _create_model_config(cursor) + _create_session_queue(cursor) + _create_workflow_images(cursor) + _create_workflows(cursor) + + +def _create_board_images(cursor: sqlite3.Cursor) -> None: + """Creates the `board_images` table, indices and triggers.""" + tables = [ + """--sql + CREATE TABLE IF NOT EXISTS board_images ( + board_id TEXT NOT NULL, + image_name TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Soft delete, currently unused + deleted_at DATETIME, + -- enforce one-to-many relationship between boards and images using PK + -- (we can extend this to many-to-many later) + PRIMARY KEY (image_name), + FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE, + FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE + ); + """ + ] + + indices = [ + "CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);", + "CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);", + ] + + triggers = [ + """--sql + CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at + AFTER UPDATE + ON board_images FOR EACH ROW + BEGIN + UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE board_id = old.board_id AND image_name = old.image_name; + END; + """ + ] + + for stmt in tables + indices + triggers: + cursor.execute(stmt) + + +def _create_boards(cursor: sqlite3.Cursor) -> None: + """Creates the `boards` table, indices and triggers.""" + tables = [ + """--sql + CREATE TABLE IF NOT EXISTS boards ( + board_id TEXT NOT NULL PRIMARY KEY, + board_name TEXT NOT NULL, + cover_image_name TEXT, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Soft delete, currently unused + deleted_at DATETIME, + FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL + ); + """ + ] + + indices = ["CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);"] + + triggers = [ + """--sql + CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at + AFTER UPDATE + ON boards FOR EACH ROW + BEGIN + UPDATE boards SET updated_at = current_timestamp + WHERE board_id = old.board_id; + END; + """ + ] + + for stmt in tables + indices + triggers: + cursor.execute(stmt) + + +def _create_images(cursor: sqlite3.Cursor) -> None: + """Creates the `images` table, indices and triggers. Adds the `starred` column.""" + + tables = [ + """--sql + CREATE TABLE IF NOT EXISTS images ( + image_name TEXT NOT NULL PRIMARY KEY, + -- This is an enum in python, unrestricted string here for flexibility + image_origin TEXT NOT NULL, + -- This is an enum in python, unrestricted string here for flexibility + image_category TEXT NOT NULL, + width INTEGER NOT NULL, + height INTEGER NOT NULL, + session_id TEXT, + node_id TEXT, + metadata TEXT, + is_intermediate BOOLEAN DEFAULT FALSE, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Soft delete, currently unused + deleted_at DATETIME + ); + """ + ] + + indices = [ + "CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);", + "CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);", + "CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);", + "CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);", + ] + + triggers = [ + """--sql + CREATE TRIGGER IF NOT EXISTS tg_images_updated_at + AFTER UPDATE + ON images FOR EACH ROW + BEGIN + UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE image_name = old.image_name; + END; + """ + ] + + # Add the 'starred' column to `images` if it doesn't exist + cursor.execute("PRAGMA table_info(images)") + columns = [column[1] for column in cursor.fetchall()] + + if "starred" not in columns: + tables.append("ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;") + indices.append("CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);") + + for stmt in tables + indices + triggers: + cursor.execute(stmt) + + +def _create_model_config(cursor: sqlite3.Cursor) -> None: + """Creates the `model_config` table, `model_manager_metadata` table, indices and triggers.""" + + tables = [ + """--sql + CREATE TABLE IF NOT EXISTS model_config ( + id TEXT NOT NULL PRIMARY KEY, + -- The next 3 fields are enums in python, unrestricted string here + base TEXT NOT NULL, + type TEXT NOT NULL, + name TEXT NOT NULL, + path TEXT NOT NULL, + original_hash TEXT, -- could be null + -- Serialized JSON representation of the whole config object, + -- which will contain additional fields from subclasses + config TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- unique constraint on combo of name, base and type + UNIQUE(name, base, type) + ); + """, + """--sql + CREATE TABLE IF NOT EXISTS model_manager_metadata ( + metadata_key TEXT NOT NULL PRIMARY KEY, + metadata_value TEXT NOT NULL + ); + """, + ] + + # Add trigger for `updated_at`. + triggers = [ + """--sql + CREATE TRIGGER IF NOT EXISTS model_config_updated_at + AFTER UPDATE + ON model_config FOR EACH ROW + BEGIN + UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ] + + # Add indexes for searchable fields + indices = [ + "CREATE INDEX IF NOT EXISTS base_index ON model_config(base);", + "CREATE INDEX IF NOT EXISTS type_index ON model_config(type);", + "CREATE INDEX IF NOT EXISTS name_index ON model_config(name);", + "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);", + ] + + for stmt in tables + indices + triggers: + cursor.execute(stmt) + + +def _create_session_queue(cursor: sqlite3.Cursor) -> None: + tables = [ + """--sql + CREATE TABLE IF NOT EXISTS session_queue ( + item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination + batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to + queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to + session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access + field_values TEXT, -- NULL if no values are associated with this queue item + session TEXT NOT NULL, -- the session to be executed + status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled' + priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important + error TEXT, -- any errors associated with this queue item + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger + started_at DATETIME, -- updated via trigger + completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup + -- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE... + -- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE + ); + """ + ] + + indices = [ + "CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);", + "CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);", + "CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);", + "CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);", + ] + + triggers = [ + """--sql + CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at + AFTER UPDATE OF status ON session_queue + FOR EACH ROW + WHEN + NEW.status = 'completed' + OR NEW.status = 'failed' + OR NEW.status = 'canceled' + BEGIN + UPDATE session_queue + SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE item_id = NEW.item_id; + END; + """, + """--sql + CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at + AFTER UPDATE OF status ON session_queue + FOR EACH ROW + WHEN + NEW.status = 'in_progress' + BEGIN + UPDATE session_queue + SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE item_id = NEW.item_id; + END; + """, + """--sql + CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at + AFTER UPDATE + ON session_queue FOR EACH ROW + BEGIN + UPDATE session_queue + SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE item_id = old.item_id; + END; + """, + ] + + for stmt in tables + indices + triggers: + cursor.execute(stmt) + + +def _create_workflow_images(cursor: sqlite3.Cursor) -> None: + tables = [ + """--sql + CREATE TABLE IF NOT EXISTS workflow_images ( + workflow_id TEXT NOT NULL, + image_name TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Soft delete, currently unused + deleted_at DATETIME, + -- enforce one-to-many relationship between workflows and images using PK + -- (we can extend this to many-to-many later) + PRIMARY KEY (image_name), + FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE, + FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE + ); + """ + ] + + indices = [ + "CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);", + "CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);", + ] + + triggers = [ + """--sql + CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at + AFTER UPDATE + ON workflow_images FOR EACH ROW + BEGIN + UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE workflow_id = old.workflow_id AND image_name = old.image_name; + END; + """ + ] + + for stmt in tables + indices + triggers: + cursor.execute(stmt) + + +def _create_workflows(cursor: sqlite3.Cursor) -> None: + tables = [ + """--sql + CREATE TABLE IF NOT EXISTS workflows ( + workflow TEXT NOT NULL, + workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger + ); + """ + ] + + triggers = [ + """--sql + CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at + AFTER UPDATE + ON workflows FOR EACH ROW + BEGIN + UPDATE workflows + SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE workflow_id = old.workflow_id; + END; + """ + ] + + for stmt in tables + triggers: + cursor.execute(stmt) + + +migration_1 = Migration(db_version=1, app_version="3.4.0", migrate=_migrate) +""" +Database version 1 (initial state). + +This migration represents the state of the database circa InvokeAI v3.4.0, which was the last +version to not use migrations to manage the database. + +As such, this migration does include some ALTER statements, and the SQL statements are written +to be idempotent. + +- Create `board_images` junction table +- Create `boards` table +- Create `images` table, add `starred` column +- Create `model_config` table +- Create `session_queue` table +- Create `workflow_images` junction table +- Create `workflows` table +""" diff --git a/invokeai/app/services/shared/sqlite/migrations/migration_2.py b/invokeai/app/services/shared/sqlite/migrations/migration_2.py new file mode 100644 index 0000000000..da9d670917 --- /dev/null +++ b/invokeai/app/services/shared/sqlite/migrations/migration_2.py @@ -0,0 +1,97 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration + + +def _migrate(cursor: sqlite3.Cursor) -> None: + """Migration callback for database version 2.""" + + _add_images_has_workflow(cursor) + _add_session_queue_workflow(cursor) + _drop_old_workflow_tables(cursor) + _add_workflow_library(cursor) + _drop_model_manager_metadata(cursor) + + +def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None: + """Add the `has_workflow` column to `images` table.""" + cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;") + + +def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None: + """Add the `workflow` column to `session_queue` table.""" + cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;") + + +def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None: + """Drops the `workflows` and `workflow_images` tables.""" + cursor.execute("DROP TABLE workflow_images;") + cursor.execute("DROP TABLE workflows;") + + +def _add_workflow_library(cursor: sqlite3.Cursor) -> None: + """Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables.""" + tables = [ + """--sql + CREATE TABLE workflow_library ( + workflow_id TEXT NOT NULL PRIMARY KEY, + workflow TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- updated manually when retrieving workflow + opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Generated columns, needed for indexing and searching + category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL, + name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL, + description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL + ); + """, + ] + + indices = [ + "CREATE INDEX idx_workflow_library_created_at ON workflow_library(created_at);", + "CREATE INDEX idx_workflow_library_updated_at ON workflow_library(updated_at);", + "CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);", + "CREATE INDEX idx_workflow_library_category ON workflow_library(category);", + "CREATE INDEX idx_workflow_library_name ON workflow_library(name);", + "CREATE INDEX idx_workflow_library_description ON workflow_library(description);", + ] + + triggers = [ + """--sql + CREATE TRIGGER tg_workflow_library_updated_at + AFTER UPDATE + ON workflow_library FOR EACH ROW + BEGIN + UPDATE workflow_library + SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE workflow_id = old.workflow_id; + END; + """ + ] + + for stmt in tables + indices + triggers: + cursor.execute(stmt) + + +def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None: + """Drops the `model_manager_metadata` table.""" + cursor.execute("DROP TABLE model_manager_metadata;") + + +migration_2 = Migration( + db_version=2, + app_version="3.5.0", + migrate=_migrate, +) +""" +Database version 2. + +Introduced in v3.5.0 for the new workflow library. + +- Add `has_workflow` column to `images` table +- Add `workflow` column to `session_queue` table +- Drop `workflows` and `workflow_images` tables +- Add `workflow_library` table +""" diff --git a/invokeai/app/services/shared/sqlite/sqlite_database.py b/invokeai/app/services/shared/sqlite/sqlite_database.py index 006eb61cbd..5cc9b66b3b 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_database.py +++ b/invokeai/app/services/shared/sqlite/sqlite_database.py @@ -4,24 +4,27 @@ from logging import Logger from pathlib import Path from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1 +from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2 from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory +from invokeai.app.services.shared.sqlite.sqlite_migrator import SQLiteMigrator class SqliteDatabase: + database: Path | str + def __init__(self, config: InvokeAIAppConfig, logger: Logger): self._logger = logger self._config = config - if self._config.use_memory_db: - self.db_path = sqlite_memory + self.database = sqlite_memory logger.info("Using in-memory database") else: - db_path = self._config.db_path - db_path.parent.mkdir(parents=True, exist_ok=True) - self.db_path = str(db_path) - self._logger.info(f"Using database at {self.db_path}") + self.database = self._config.db_path + self.database.parent.mkdir(parents=True, exist_ok=True) + self._logger.info(f"Using database at {self.database}") - self.conn = sqlite3.connect(self.db_path, check_same_thread=False) + self.conn = sqlite3.connect(database=self.database, check_same_thread=False) self.lock = threading.RLock() self.conn.row_factory = sqlite3.Row @@ -30,15 +33,20 @@ class SqliteDatabase: self.conn.execute("PRAGMA foreign_keys = ON;") + migrator = SQLiteMigrator(conn=self.conn, database=self.database, lock=self.lock, logger=self._logger) + migrator.register_migration(migration_1) + migrator.register_migration(migration_2) + migrator.run_migrations() + def clean(self) -> None: with self.lock: try: - if self.db_path == sqlite_memory: + if self.database == sqlite_memory: return - initial_db_size = Path(self.db_path).stat().st_size + initial_db_size = Path(self.database).stat().st_size self.conn.execute("VACUUM;") self.conn.commit() - final_db_size = Path(self.db_path).stat().st_size + final_db_size = Path(self.database).stat().st_size freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2) if freed_space_in_mb > 0: self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)") diff --git a/invokeai/app/services/shared/sqlite/sqlite_migrator.py b/invokeai/app/services/shared/sqlite/sqlite_migrator.py new file mode 100644 index 0000000000..a5cf22e3c3 --- /dev/null +++ b/invokeai/app/services/shared/sqlite/sqlite_migrator.py @@ -0,0 +1,210 @@ +import shutil +import sqlite3 +import threading +from datetime import datetime +from logging import Logger +from pathlib import Path +from typing import Callable, Optional, TypeAlias + +from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory + +MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None] + + +class MigrationError(Exception): + """Raised when a migration fails.""" + + +class MigrationVersionError(ValueError, MigrationError): + """Raised when a migration version is invalid.""" + + +class Migration: + """Represents a migration for a SQLite database. + + :param db_version: The database schema version this migration results in. + :param app_version: The app version this migration is introduced in. + :param migrate: The callback to run to perform the migration. The callback will be passed a + cursor to the database. The migrator will manage locking database access and committing the + transaction; the callback should not do either of these things. + """ + + def __init__( + self, + db_version: int, + app_version: str, + migrate: MigrateCallback, + ) -> None: + self.db_version = db_version + self.app_version = app_version + self.migrate = migrate + + +class SQLiteMigrator: + """ + Manages migrations for a SQLite database. + + :param conn: The database connection. + :param database: The path to the database file, or ":memory:" for an in-memory database. + :param lock: A lock to use when accessing the database. + :param logger: The logger to use. + + Migrations should be registered with :meth:`register_migration`. Migrations will be run in + order of their version number. If the database is already at the latest version, no migrations + will be run. + """ + + def __init__(self, conn: sqlite3.Connection, database: Path | str, lock: threading.RLock, logger: Logger) -> None: + self._logger = logger + self._conn = conn + self._cursor = self._conn.cursor() + self._lock = lock + self._database = database + self._migrations: set[Migration] = set() + + def register_migration(self, migration: Migration) -> None: + """Registers a migration.""" + if not isinstance(migration.db_version, int) or migration.db_version < 1: + raise MigrationVersionError(f"Invalid migration version {migration.db_version}") + if any(m.db_version == migration.db_version for m in self._migrations): + raise MigrationVersionError(f"Migration version {migration.db_version} already registered") + self._migrations.add(migration) + self._logger.debug(f"Registered migration {migration.db_version}") + + def run_migrations(self) -> None: + """Migrates the database to the latest version.""" + with self._lock: + self._create_version_table() + sorted_migrations = sorted(self._migrations, key=lambda m: m.db_version) + current_version = self._get_current_version() + + if len(sorted_migrations) == 0: + self._logger.debug("No migrations registered") + return + + latest_version = sorted_migrations[-1].db_version + if current_version == latest_version: + self._logger.debug("Database is up to date, no migrations to run") + return + + if current_version > latest_version: + raise MigrationError( + f"Database version {current_version} is greater than the latest migration version {latest_version}" + ) + + self._logger.info("Database update needed") + + # Only make a backup if using a file database (not memory) + backup_path: Optional[Path] = None + if isinstance(self._database, Path): + backup_path = self._backup_db(self._database) + else: + self._logger.info("Using in-memory database, skipping backup") + + for migration in sorted_migrations: + try: + self._run_migration(migration) + except MigrationError: + if backup_path is not None: + self._logger.error(f" Restoring from {backup_path}") + self._restore_db(backup_path) + raise + self._logger.info("Database updated successfully") + + def _run_migration(self, migration: Migration) -> None: + """Runs a single migration.""" + with self._lock: + current_version = self._get_current_version() + try: + if current_version >= migration.db_version: + return + migration.migrate(self._cursor) + # Migration callbacks only get a cursor; they cannot commit the transaction. + self._conn.commit() + self._set_version(db_version=migration.db_version, app_version=migration.app_version) + self._logger.debug(f"Successfully migrated database from {current_version} to {migration.db_version}") + except Exception as e: + msg = f"Error migrating database from {current_version} to {migration.db_version}: {e}" + self._conn.rollback() + self._logger.error(msg) + raise MigrationError(msg) from e + + def _create_version_table(self) -> None: + """Creates a version table for the database, if one does not already exist.""" + with self._lock: + try: + self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='version';") + if self._cursor.fetchone() is not None: + return + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS version ( + db_version INTEGER PRIMARY KEY, + app_version TEXT NOT NULL, + migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) + ); + """ + ) + self._cursor.execute("INSERT INTO version (db_version, app_version) VALUES (?,?);", (0, "0.0.0")) + self._conn.commit() + self._logger.debug("Created version table") + except sqlite3.Error as e: + msg = f"Problem creation version table: {e}" + self._logger.error(msg) + self._conn.rollback() + raise MigrationError(msg) from e + + def _get_current_version(self) -> int: + """Gets the current version of the database, or 0 if the version table does not exist.""" + with self._lock: + try: + self._cursor.execute("SELECT MAX(db_version) FROM version;") + version = self._cursor.fetchone()[0] + if version is None: + return 0 + return version + except sqlite3.OperationalError as e: + if "no such table" in str(e): + return 0 + raise + + def _set_version(self, db_version: int, app_version: str) -> None: + """Adds a version entry to the table's version table.""" + with self._lock: + try: + self._cursor.execute( + "INSERT INTO version (db_version, app_version) VALUES (?,?);", (db_version, app_version) + ) + self._conn.commit() + except sqlite3.Error as e: + msg = f"Problem setting database version: {e}" + self._logger.error(msg) + self._conn.rollback() + raise MigrationError(msg) from e + + def _backup_db(self, db_path: Path | str) -> Path: + """Backs up the databse, returning the path to the backup file.""" + if db_path == sqlite_memory: + raise MigrationError("Cannot back up memory database") + if not isinstance(db_path, Path): + raise MigrationError(f'Database path must be "{sqlite_memory}" or a Path') + with self._lock: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_path = db_path.parent / f"{db_path.stem}_{timestamp}.db" + self._logger.info(f"Backing up database to {backup_path}") + backup_conn = sqlite3.connect(backup_path) + with backup_conn: + self._conn.backup(backup_conn) + backup_conn.close() + return backup_path + + def _restore_db(self, backup_path: Path) -> None: + """Restores the database from a backup file, unless the database is a memory database.""" + if self._database == sqlite_memory: + return + with self._lock: + self._logger.info(f"Restoring database from {backup_path}") + self._conn.close() + if not Path(backup_path).is_file(): + raise FileNotFoundError(f"Backup file {backup_path} does not exist") + shutil.copy2(backup_path, self._database) diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index ecbe7c0c9b..ef9e60fb9a 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -26,7 +26,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): self._lock = db.lock self._conn = db.conn self._cursor = self._conn.cursor() - self._create_tables() def start(self, invoker: Invoker) -> None: self._invoker = invoker @@ -233,87 +232,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): raise finally: self._lock.release() - - def _create_tables(self) -> None: - try: - self._lock.acquire() - self._cursor.execute( - """--sql - CREATE TABLE IF NOT EXISTS workflow_library ( - workflow_id TEXT NOT NULL PRIMARY KEY, - workflow TEXT NOT NULL, - created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- updated via trigger - updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- updated manually when retrieving workflow - opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- Generated columns, needed for indexing and searching - category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL, - name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL, - description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL - ); - """ - ) - - self._cursor.execute( - """--sql - CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at - AFTER UPDATE - ON workflow_library FOR EACH ROW - BEGIN - UPDATE workflow_library - SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') - WHERE workflow_id = old.workflow_id; - END; - """ - ) - - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at); - """ - ) - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at); - """ - ) - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at); - """ - ) - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category); - """ - ) - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name); - """ - ) - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description); - """ - ) - - # We do not need the original `workflows` table or `workflow_images` junction table. - self._cursor.execute( - """--sql - DROP TABLE IF EXISTS workflow_images; - """ - ) - self._cursor.execute( - """--sql - DROP TABLE IF EXISTS workflows; - """ - ) - - self._conn.commit() - except Exception: - self._conn.rollback() - raise - finally: - self._lock.release() diff --git a/tests/test_sqlite_migrator.py b/tests/test_sqlite_migrator.py new file mode 100644 index 0000000000..d4da61caa5 --- /dev/null +++ b/tests/test_sqlite_migrator.py @@ -0,0 +1,173 @@ +import sqlite3 +import threading +from copy import deepcopy +from logging import Logger +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Callable + +import pytest + +from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory +from invokeai.app.services.shared.sqlite.sqlite_migrator import ( + Migration, + MigrationError, + MigrationVersionError, + SQLiteMigrator, +) + + +@pytest.fixture +def migrator() -> SQLiteMigrator: + conn = sqlite3.connect(sqlite_memory, check_same_thread=False) + return SQLiteMigrator( + conn=conn, database=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator") + ) + + +@pytest.fixture +def good_migration() -> Migration: + return Migration(db_version=1, app_version="1.0.0", migrate=lambda cursor: None) + + +@pytest.fixture +def failing_migration() -> Migration: + def failing_migration(cursor: sqlite3.Cursor) -> None: + raise Exception("Bad migration") + + return Migration(db_version=1, app_version="1.0.0", migrate=failing_migration) + + +def test_register_migration(migrator: SQLiteMigrator, good_migration: Migration): + migration = good_migration + migrator.register_migration(migration) + assert migration in migrator._migrations + with pytest.raises(MigrationError, match="Invalid migration version"): + migrator.register_migration(Migration(db_version=0, app_version="0.0.0", migrate=lambda cursor: None)) + + +def test_register_invalid_migration_version(migrator: SQLiteMigrator): + with pytest.raises(MigrationError, match="Invalid migration version"): + migrator.register_migration(Migration(db_version=0, app_version="0.0.0", migrate=lambda cursor: None)) + + +def test_create_version_table(migrator: SQLiteMigrator): + migrator._create_version_table() + migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';") + assert migrator._cursor.fetchone() is not None + + +def test_get_current_version(migrator: SQLiteMigrator): + migrator._create_version_table() + migrator._conn.commit() + assert migrator._get_current_version() == 0 # initial version + + +def test_set_version(migrator: SQLiteMigrator): + migrator._create_version_table() + migrator._set_version(db_version=1, app_version="1.0.0") + migrator._cursor.execute("SELECT MAX(db_version) FROM version;") + assert migrator._cursor.fetchone()[0] == 1 + migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;") + assert migrator._cursor.fetchone()[0] == "1.0.0" + + +def test_run_migration(migrator: SQLiteMigrator): + migrator._create_version_table() + + def migration_callback(cursor: sqlite3.Cursor) -> None: + cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") + + migration = Migration(db_version=1, app_version="1.0.0", migrate=migration_callback) + migrator._run_migration(migration) + assert migrator._get_current_version() == 1 + migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;") + assert migrator._cursor.fetchone()[0] == "1.0.0" + + +def test_run_migrations(migrator: SQLiteMigrator): + migrator._create_version_table() + + def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]: + def migrate(cursor: sqlite3.Cursor) -> None: + cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);") + + return migrate + + migrations = [Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in range(1, 4)] + for migration in migrations: + migrator.register_migration(migration) + migrator.run_migrations() + assert migrator._get_current_version() == 3 + + +def test_backup_and_restore_db(migrator: SQLiteMigrator): + with TemporaryDirectory() as tempdir: + # must do this with a file database - we don't backup/restore for memory + database = Path(tempdir) / "test.db" + migrator._database = database + migrator._cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") + migrator._conn.commit() + backup_path = migrator._backup_db(migrator._database) + migrator._cursor.execute("DROP TABLE test;") + migrator._conn.commit() + migrator._restore_db(backup_path) # this closes the connection + # reconnect to db + restored_conn = sqlite3.connect(database) + restored_cursor = restored_conn.cursor() + restored_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';") + assert restored_cursor.fetchone() is not None + + +def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator): + with pytest.raises(MigrationError, match="Cannot back up memory database"): + migrator._backup_db(sqlite_memory) + + +def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration): + migrator._create_version_table() + with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"): + migrator._run_migration(failing_migration) + assert migrator._get_current_version() == 0 + + +def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration): + migrator._create_version_table() + migrator.register_migration(good_migration) + with pytest.raises(MigrationVersionError, match="already registered"): + migrator.register_migration(deepcopy(good_migration)) + + +def test_non_sequential_migration_registration(migrator: SQLiteMigrator): + migrator._create_version_table() + + def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]: + def migrate(cursor: sqlite3.Cursor) -> None: + cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);") + + return migrate + + migrations = [ + Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in reversed(range(1, 4)) + ] + for migration in migrations: + migrator.register_migration(migration) + migrator.run_migrations() + assert migrator._get_current_version() == 3 + + +def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration): + migrator._create_version_table() + migrator.register_migration(good_migration) + migrator._set_version(db_version=2, app_version="2.0.0") + with pytest.raises(MigrationError, match="greater than the latest migration version"): + migrator.run_migrations() + assert migrator._get_current_version() == 2 + + +def test_idempotent_migrations(migrator: SQLiteMigrator, good_migration: Migration): + migrator._create_version_table() + migrator.register_migration(good_migration) + migrator.run_migrations() + migrator.run_migrations() + assert migrator._get_current_version() == 1