mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

feat(db): add SQLiteMigrator to perform db migrations

parent ef807cf63a
commit f2c6819d68
@@ -20,63 +20,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
        self._conn = db.conn
        self._cursor = self._conn.cursor()

        try:
            self._lock.acquire()
            self._create_tables()
            self._conn.commit()
        finally:
            self._lock.release()

    def _create_tables(self) -> None:
        """Creates the `board_images` junction table."""

        # Create the `board_images` junction table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS board_images (
                board_id TEXT NOT NULL,
                image_name TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between boards and images using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (image_name),
                FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
                FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
            );
            """
        )

        # Add index for board id
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
            """
        )

        # Add index for board id, sorted by created_at
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
            """
        )

        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
            AFTER UPDATE
            ON board_images FOR EACH ROW
            BEGIN
                UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE board_id = old.board_id AND image_name = old.image_name;
            END;
            """
        )

    def add_image_to_board(
        self,
        board_id: str,
@@ -28,52 +28,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
        self._conn = db.conn
        self._cursor = self._conn.cursor()

        try:
            self._lock.acquire()
            self._create_tables()
            self._conn.commit()
        finally:
            self._lock.release()

    def _create_tables(self) -> None:
        """Creates the `boards` table and `board_images` junction table."""

        # Create the `boards` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS boards (
                board_id TEXT NOT NULL PRIMARY KEY,
                board_name TEXT NOT NULL,
                cover_image_name TEXT,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
            );
            """
        )

        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
            """
        )

        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
            AFTER UPDATE
            ON boards FOR EACH ROW
            BEGIN
                UPDATE boards SET updated_at = current_timestamp
                    WHERE board_id = old.board_id;
            END;
            """
        )

    def delete(self, board_id: str) -> None:
        try:
            self._lock.acquire()
@@ -32,101 +32,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        self._conn = db.conn
        self._cursor = self._conn.cursor()

        try:
            self._lock.acquire()
            self._create_tables()
            self._conn.commit()
        finally:
            self._lock.release()

    def _create_tables(self) -> None:
        """Creates the `images` table."""

        # Create the `images` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS images (
                image_name TEXT NOT NULL PRIMARY KEY,
                -- This is an enum in python, unrestricted string here for flexibility
                image_origin TEXT NOT NULL,
                -- This is an enum in python, unrestricted string here for flexibility
                image_category TEXT NOT NULL,
                width INTEGER NOT NULL,
                height INTEGER NOT NULL,
                session_id TEXT,
                node_id TEXT,
                metadata TEXT,
                is_intermediate BOOLEAN DEFAULT FALSE,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME
            );
            """
        )

        self._cursor.execute("PRAGMA table_info(images)")
        columns = [column[1] for column in self._cursor.fetchall()]

        if "starred" not in columns:
            self._cursor.execute(
                """--sql
                ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
                """
            )

        # Create the `images` table indices.
        self._cursor.execute(
            """--sql
            CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
            """
        )
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
            """
        )
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
            """
        )
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
            """
        )

        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
            """
        )

        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
            AFTER UPDATE
            ON images FOR EACH ROW
            BEGIN
                UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE image_name = old.image_name;
            END;
            """
        )

        self._cursor.execute("PRAGMA table_info(images)")
        columns = [column[1] for column in self._cursor.fetchall()]
        if "has_workflow" not in columns:
            self._cursor.execute(
                """--sql
                ALTER TABLE images
                ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;
                """
            )

    def get(self, image_name: str) -> ImageRecord:
        try:
            self._lock.acquire()
@@ -9,9 +9,6 @@ from typing import List, Optional, Union

from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType

# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2.0"


class DuplicateModelException(Exception):
    """Raised on an attempt to add a model with the same key twice."""
@@ -32,12 +29,6 @@ class ConfigFileVersionMismatchException(Exception):

class ModelRecordServiceBase(ABC):
    """Abstract base class for storage and retrieval of model configs."""

    @property
    @abstractmethod
    def version(self) -> str:
        """Return the config file/database schema version."""
        pass

    @abstractmethod
    def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
        """
@@ -54,7 +54,6 @@ from invokeai.backend.model_manager.config import (

from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
    CONFIG_FILE_VERSION,
    DuplicateModelException,
    ModelRecordServiceBase,
    UnknownModelException,
@@ -78,85 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        self._db = db
        self._cursor = self._db.conn.cursor()

        with self._db.lock:
            # Enable foreign keys
            self._db.conn.execute("PRAGMA foreign_keys = ON;")
            self._create_tables()
            self._db.conn.commit()
            assert (
                str(self.version) == CONFIG_FILE_VERSION
            ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"

    def _create_tables(self) -> None:
        """Create sqlite3 tables."""
        # model_config table breaks out the fields that are common to all config objects
        # and puts class-specific ones in a serialized json object
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS model_config (
                id TEXT NOT NULL PRIMARY KEY,
                -- The next 3 fields are enums in python, unrestricted string here
                base TEXT NOT NULL,
                type TEXT NOT NULL,
                name TEXT NOT NULL,
                path TEXT NOT NULL,
                original_hash TEXT, -- could be null
                -- Serialized JSON representation of the whole config object,
                -- which will contain additional fields from subclasses
                config TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- unique constraint on combo of name, base and type
                UNIQUE(name, base, type)
            );
            """
        )

        # metadata table
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS model_manager_metadata (
                metadata_key TEXT NOT NULL PRIMARY KEY,
                metadata_value TEXT NOT NULL
            );
            """
        )

        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS model_config_updated_at
            AFTER UPDATE
            ON model_config FOR EACH ROW
            BEGIN
                UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE id = old.id;
            END;
            """
        )

        # Add indexes for searchable fields
        for stmt in [
            "CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
            "CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
            "CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
            "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
        ]:
            self._cursor.execute(stmt)

        # Add our version to the metadata table
        self._cursor.execute(
            """--sql
            INSERT OR IGNORE into model_manager_metadata (
                metadata_key,
                metadata_value
            )
            VALUES (?,?);
            """,
            ("version", CONFIG_FILE_VERSION),
        )

    def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
        """
        Add a model to the database.
@@ -214,22 +134,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

        return self.get_model(key)

    @property
    def version(self) -> str:
        """Return the version of the database schema."""
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT metadata_value FROM model_manager_metadata
                WHERE metadata_key=?;
                """,
                ("version",),
            )
            rows = self._cursor.fetchone()
            if not rows:
                raise KeyError("Models database does not have metadata key 'version'")
            return rows[0]

    def del_model(self, key: str) -> None:
        """
        Delete a model.
@@ -50,7 +50,6 @@ class SqliteSessionQueue(SessionQueueBase):
        self.__lock = db.lock
        self.__conn = db.conn
        self.__cursor = self.__conn.cursor()
        self._create_tables()

    def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
        return event[1]["event"] in match_in
@@ -98,123 +97,6 @@ class SqliteSessionQueue(SessionQueueBase):
        except SessionQueueItemNotFoundError:
            return

    def _create_tables(self) -> None:
        """Creates the session queue tables, indices, and triggers."""
        try:
            self.__lock.acquire()
            self.__cursor.execute(
                """--sql
                CREATE TABLE IF NOT EXISTS session_queue (
                    item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
                    batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
                    queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
                    session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
                    field_values TEXT, -- NULL if no values are associated with this queue item
                    session TEXT NOT NULL, -- the session to be executed
                    status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
                    priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
                    error TEXT, -- any errors associated with this queue item
                    created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                    updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
                    started_at DATETIME, -- updated via trigger
                    completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
                    -- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
                    -- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
                );
                """
            )

            self.__cursor.execute(
                """--sql
                CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
                """
            )

            self.__cursor.execute(
                """--sql
                CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
                """
            )

            self.__cursor.execute(
                """--sql
                CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
                """
            )

            self.__cursor.execute(
                """--sql
                CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
                """
            )

            self.__cursor.execute(
                """--sql
                CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
                """
            )

            self.__cursor.execute(
                """--sql
                CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
                AFTER UPDATE OF status ON session_queue
                FOR EACH ROW
                WHEN
                    NEW.status = 'completed'
                    OR NEW.status = 'failed'
                    OR NEW.status = 'canceled'
                BEGIN
                    UPDATE session_queue
                    SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE item_id = NEW.item_id;
                END;
                """
            )

            self.__cursor.execute(
                """--sql
                CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
                AFTER UPDATE OF status ON session_queue
                FOR EACH ROW
                WHEN
                    NEW.status = 'in_progress'
                BEGIN
                    UPDATE session_queue
                    SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE item_id = NEW.item_id;
                END;
                """
            )

            self.__cursor.execute(
                """--sql
                CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
                AFTER UPDATE
                ON session_queue FOR EACH ROW
                BEGIN
                    UPDATE session_queue
                    SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE item_id = old.item_id;
                END;
                """
            )

            self.__cursor.execute("PRAGMA table_info(session_queue)")
            columns = [column[1] for column in self.__cursor.fetchall()]
            if "workflow" not in columns:
                self.__cursor.execute(
                    """--sql
                    ALTER TABLE session_queue ADD COLUMN workflow TEXT;
                    """
                )

            self.__conn.commit()
        except Exception:
            self.__conn.rollback()
            raise
        finally:
            self.__lock.release()

    def _set_in_progress_to_canceled(self) -> None:
        """
        Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
invokeai/app/services/shared/sqlite/migrations/migration_1.py (new file, 370 lines)
@@ -0,0 +1,370 @@
import sqlite3

from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration


def _migrate(cursor: sqlite3.Cursor) -> None:
    """Migration callback for database version 1."""

    _create_board_images(cursor)
    _create_boards(cursor)
    _create_images(cursor)
    _create_model_config(cursor)
    _create_session_queue(cursor)
    _create_workflow_images(cursor)
    _create_workflows(cursor)


def _create_board_images(cursor: sqlite3.Cursor) -> None:
    """Creates the `board_images` table, indices and triggers."""
    tables = [
        """--sql
        CREATE TABLE IF NOT EXISTS board_images (
            board_id TEXT NOT NULL,
            image_name TEXT NOT NULL,
            created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- updated via trigger
            updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- Soft delete, currently unused
            deleted_at DATETIME,
            -- enforce one-to-many relationship between boards and images using PK
            -- (we can extend this to many-to-many later)
            PRIMARY KEY (image_name),
            FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
            FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
        );
        """
    ]

    indices = [
        "CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);",
        "CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);",
    ]

    triggers = [
        """--sql
        CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
        AFTER UPDATE
        ON board_images FOR EACH ROW
        BEGIN
            UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                WHERE board_id = old.board_id AND image_name = old.image_name;
        END;
        """
    ]

    for stmt in tables + indices + triggers:
        cursor.execute(stmt)


def _create_boards(cursor: sqlite3.Cursor) -> None:
    """Creates the `boards` table, indices and triggers."""
    tables = [
        """--sql
        CREATE TABLE IF NOT EXISTS boards (
            board_id TEXT NOT NULL PRIMARY KEY,
            board_name TEXT NOT NULL,
            cover_image_name TEXT,
            created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- Updated via trigger
            updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- Soft delete, currently unused
            deleted_at DATETIME,
            FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
        );
        """
    ]

    indices = ["CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);"]

    triggers = [
        """--sql
        CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
        AFTER UPDATE
        ON boards FOR EACH ROW
        BEGIN
            UPDATE boards SET updated_at = current_timestamp
                WHERE board_id = old.board_id;
        END;
        """
    ]

    for stmt in tables + indices + triggers:
        cursor.execute(stmt)


def _create_images(cursor: sqlite3.Cursor) -> None:
    """Creates the `images` table, indices and triggers. Adds the `starred` column."""

    tables = [
        """--sql
        CREATE TABLE IF NOT EXISTS images (
            image_name TEXT NOT NULL PRIMARY KEY,
            -- This is an enum in python, unrestricted string here for flexibility
            image_origin TEXT NOT NULL,
            -- This is an enum in python, unrestricted string here for flexibility
            image_category TEXT NOT NULL,
            width INTEGER NOT NULL,
            height INTEGER NOT NULL,
            session_id TEXT,
            node_id TEXT,
            metadata TEXT,
            is_intermediate BOOLEAN DEFAULT FALSE,
            created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- Updated via trigger
            updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- Soft delete, currently unused
            deleted_at DATETIME
        );
        """
    ]

    indices = [
        "CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);",
        "CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);",
        "CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);",
        "CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);",
    ]

    triggers = [
        """--sql
        CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
        AFTER UPDATE
        ON images FOR EACH ROW
        BEGIN
            UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                WHERE image_name = old.image_name;
        END;
        """
    ]

    # Add the 'starred' column to `images` if it doesn't exist
    cursor.execute("PRAGMA table_info(images)")
    columns = [column[1] for column in cursor.fetchall()]

    if "starred" not in columns:
        tables.append("ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;")
        indices.append("CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);")

    for stmt in tables + indices + triggers:
        cursor.execute(stmt)


def _create_model_config(cursor: sqlite3.Cursor) -> None:
    """Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""

    tables = [
        """--sql
        CREATE TABLE IF NOT EXISTS model_config (
            id TEXT NOT NULL PRIMARY KEY,
            -- The next 3 fields are enums in python, unrestricted string here
            base TEXT NOT NULL,
            type TEXT NOT NULL,
            name TEXT NOT NULL,
            path TEXT NOT NULL,
            original_hash TEXT, -- could be null
            -- Serialized JSON representation of the whole config object,
            -- which will contain additional fields from subclasses
            config TEXT NOT NULL,
            created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- Updated via trigger
            updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- unique constraint on combo of name, base and type
            UNIQUE(name, base, type)
        );
        """,
        """--sql
        CREATE TABLE IF NOT EXISTS model_manager_metadata (
            metadata_key TEXT NOT NULL PRIMARY KEY,
            metadata_value TEXT NOT NULL
        );
        """,
    ]

    # Add trigger for `updated_at`.
    triggers = [
        """--sql
        CREATE TRIGGER IF NOT EXISTS model_config_updated_at
        AFTER UPDATE
        ON model_config FOR EACH ROW
        BEGIN
            UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                WHERE id = old.id;
        END;
        """
    ]

    # Add indexes for searchable fields
    indices = [
        "CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
        "CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
        "CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
        "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
    ]

    for stmt in tables + indices + triggers:
        cursor.execute(stmt)


def _create_session_queue(cursor: sqlite3.Cursor) -> None:
    """Creates the `session_queue` table, indices and triggers."""
    tables = [
        """--sql
        CREATE TABLE IF NOT EXISTS session_queue (
            item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
            batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
            queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
            session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
            field_values TEXT, -- NULL if no values are associated with this queue item
            session TEXT NOT NULL, -- the session to be executed
            status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
            priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
            error TEXT, -- any errors associated with this queue item
            created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
            started_at DATETIME, -- updated via trigger
            completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
            -- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
            -- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
        );
        """
    ]

    indices = [
        "CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);",
        "CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);",
        "CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);",
        "CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);",
        "CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);",
    ]

    triggers = [
        """--sql
        CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
        AFTER UPDATE OF status ON session_queue
        FOR EACH ROW
        WHEN
            NEW.status = 'completed'
            OR NEW.status = 'failed'
            OR NEW.status = 'canceled'
        BEGIN
            UPDATE session_queue
            SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
            WHERE item_id = NEW.item_id;
        END;
        """,
        """--sql
        CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
        AFTER UPDATE OF status ON session_queue
        FOR EACH ROW
        WHEN
            NEW.status = 'in_progress'
        BEGIN
            UPDATE session_queue
            SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
            WHERE item_id = NEW.item_id;
        END;
        """,
        """--sql
        CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
        AFTER UPDATE
        ON session_queue FOR EACH ROW
        BEGIN
            UPDATE session_queue
            SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
            WHERE item_id = old.item_id;
        END;
        """,
    ]

    for stmt in tables + indices + triggers:
        cursor.execute(stmt)
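# A minimal sketch (not part of this file) of how the status triggers above
# behave, assuming an in-memory database; the values 'b1', 'default' and 's1'
# are illustrative only:
#
#     conn = sqlite3.connect(":memory:")
#     cur = conn.cursor()
#     _create_session_queue(cur)
#     cur.execute(
#         "INSERT INTO session_queue (batch_id, queue_id, session_id, session)"
#         " VALUES ('b1', 'default', 's1', '{}');"
#     )
#     # Moving status to a terminal state fires tg_session_queue_completed_at,
#     # which stamps completed_at.
#     cur.execute("UPDATE session_queue SET status = 'completed' WHERE session_id = 's1';")
#     cur.execute("SELECT completed_at FROM session_queue WHERE session_id = 's1';")
#     assert cur.fetchone()[0] is not None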
def _create_workflow_images(cursor: sqlite3.Cursor) -> None:
    """Creates the `workflow_images` table, indices and triggers."""
    tables = [
        """--sql
        CREATE TABLE IF NOT EXISTS workflow_images (
            workflow_id TEXT NOT NULL,
            image_name TEXT NOT NULL,
            created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- updated via trigger
            updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- Soft delete, currently unused
            deleted_at DATETIME,
            -- enforce one-to-many relationship between workflows and images using PK
            -- (we can extend this to many-to-many later)
            PRIMARY KEY (image_name),
            FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
            FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
        );
        """
    ]

    indices = [
        "CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);",
        "CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);",
    ]

    triggers = [
        """--sql
        CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
        AFTER UPDATE
        ON workflow_images FOR EACH ROW
        BEGIN
            UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
        END;
        """
    ]

    for stmt in tables + indices + triggers:
        cursor.execute(stmt)


def _create_workflows(cursor: sqlite3.Cursor) -> None:
    """Creates the `workflows` table and triggers."""
    tables = [
        """--sql
        CREATE TABLE IF NOT EXISTS workflows (
            workflow TEXT NOT NULL,
            workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
            created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
        );
        """
    ]

    triggers = [
        """--sql
        CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
        AFTER UPDATE
        ON workflows FOR EACH ROW
        BEGIN
            UPDATE workflows
            SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
            WHERE workflow_id = old.workflow_id;
        END;
        """
    ]

    for stmt in tables + triggers:
        cursor.execute(stmt)
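# A minimal sketch (not part of this file) of the generated column above:
# workflow_id is derived from the stored JSON document, so only the document
# itself needs to be written. The id value 'wf-1' is illustrative:
#
#     conn = sqlite3.connect(":memory:")
#     cur = conn.cursor()
#     _create_workflows(cur)
#     cur.execute("INSERT INTO workflows (workflow) VALUES ('{\"id\": \"wf-1\"}');")
#     cur.execute("SELECT workflow_id FROM workflows;")
#     assert cur.fetchone()[0] == "wf-1"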
migration_1 = Migration(db_version=1, app_version="3.4.0", migrate=_migrate)
"""
Database version 1 (initial state).

This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
version to not use migrations to manage the database.

As such, this migration does include some ALTER statements, and the SQL statements are written
to be idempotent.

- Create `board_images` junction table
- Create `boards` table
- Create `images` table, add `starred` column
- Create `model_config` table
- Create `session_queue` table
- Create `workflow_images` junction table
- Create `workflows` table
"""
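A quick way to see the idempotency claim above in practice is to run the version-1 callback twice against a scratch database; a sketch (illustrative, not part of the commit), assuming an in-memory connection:

    import sqlite3

    from invokeai.app.services.shared.sqlite.migrations.migration_1 import _migrate

    conn = sqlite3.connect(":memory:")
    _migrate(conn.cursor())
    _migrate(conn.cursor())  # safe: CREATEs use IF NOT EXISTS and the ALTER is guarded
    conn.commit()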
invokeai/app/services/shared/sqlite/migrations/migration_2.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import sqlite3

from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration


def _migrate(cursor: sqlite3.Cursor) -> None:
    """Migration callback for database version 2."""

    _add_images_has_workflow(cursor)
    _add_session_queue_workflow(cursor)
    _drop_old_workflow_tables(cursor)
    _add_workflow_library(cursor)
    _drop_model_manager_metadata(cursor)


def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
    """Adds the `has_workflow` column to the `images` table."""
    cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")


def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
    """Adds the `workflow` column to the `session_queue` table."""
    cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")


def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
    """Drops the `workflows` and `workflow_images` tables."""
    cursor.execute("DROP TABLE workflow_images;")
    cursor.execute("DROP TABLE workflows;")


def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
    """Adds the `workflow_library` table, indices and triggers."""
    tables = [
        """--sql
        CREATE TABLE workflow_library (
            workflow_id TEXT NOT NULL PRIMARY KEY,
            workflow TEXT NOT NULL,
            created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- updated via trigger
            updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- updated manually when retrieving workflow
            opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            -- Generated columns, needed for indexing and searching
            category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
            name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
            description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
        );
        """,
    ]

    indices = [
        "CREATE INDEX idx_workflow_library_created_at ON workflow_library(created_at);",
        "CREATE INDEX idx_workflow_library_updated_at ON workflow_library(updated_at);",
        "CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);",
        "CREATE INDEX idx_workflow_library_category ON workflow_library(category);",
        "CREATE INDEX idx_workflow_library_name ON workflow_library(name);",
        "CREATE INDEX idx_workflow_library_description ON workflow_library(description);",
    ]

    triggers = [
        """--sql
        CREATE TRIGGER tg_workflow_library_updated_at
        AFTER UPDATE
        ON workflow_library FOR EACH ROW
        BEGIN
            UPDATE workflow_library
            SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
            WHERE workflow_id = old.workflow_id;
        END;
        """
    ]

    for stmt in tables + indices + triggers:
        cursor.execute(stmt)


def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
    """Drops the `model_manager_metadata` table."""
    cursor.execute("DROP TABLE model_manager_metadata;")


migration_2 = Migration(
    db_version=2,
    app_version="3.5.0",
    migrate=_migrate,
)
"""
Database version 2.

Introduced in v3.5.0 for the new workflow library.

- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
"""
@@ -4,24 +4,27 @@ from logging import Logger
 from pathlib import Path

 from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
+from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
 from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
+from invokeai.app.services.shared.sqlite.sqlite_migrator import SQLiteMigrator


 class SqliteDatabase:
+    database: Path | str
+
     def __init__(self, config: InvokeAIAppConfig, logger: Logger):
         self._logger = logger
         self._config = config

         if self._config.use_memory_db:
-            self.db_path = sqlite_memory
+            self.database = sqlite_memory
             logger.info("Using in-memory database")
         else:
-            db_path = self._config.db_path
-            db_path.parent.mkdir(parents=True, exist_ok=True)
-            self.db_path = str(db_path)
-            self._logger.info(f"Using database at {self.db_path}")
+            self.database = self._config.db_path
+            self.database.parent.mkdir(parents=True, exist_ok=True)
+            self._logger.info(f"Using database at {self.database}")

-        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
+        self.conn = sqlite3.connect(database=self.database, check_same_thread=False)
         self.lock = threading.RLock()
         self.conn.row_factory = sqlite3.Row

@@ -30,15 +33,20 @@ class SqliteDatabase:

         self.conn.execute("PRAGMA foreign_keys = ON;")

+        migrator = SQLiteMigrator(conn=self.conn, database=self.database, lock=self.lock, logger=self._logger)
+        migrator.register_migration(migration_1)
+        migrator.register_migration(migration_2)
+        migrator.run_migrations()
+
     def clean(self) -> None:
         with self.lock:
             try:
-                if self.db_path == sqlite_memory:
+                if self.database == sqlite_memory:
                     return
-                initial_db_size = Path(self.db_path).stat().st_size
+                initial_db_size = Path(self.database).stat().st_size
                 self.conn.execute("VACUUM;")
                 self.conn.commit()
-                final_db_size = Path(self.db_path).stat().st_size
+                final_db_size = Path(self.database).stat().st_size
                 freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
                 if freed_space_in_mb > 0:
                     self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
invokeai/app/services/shared/sqlite/sqlite_migrator.py (new file, 210 lines)
@@ -0,0 +1,210 @@
import shutil
import sqlite3
import threading
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Callable, Optional, TypeAlias

from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory

MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]


class MigrationError(Exception):
    """Raised when a migration fails."""


class MigrationVersionError(ValueError, MigrationError):
    """Raised when a migration version is invalid."""


class Migration:
    """Represents a migration for a SQLite database.

    :param db_version: The database schema version this migration results in.
    :param app_version: The app version this migration is introduced in.
    :param migrate: The callback to run to perform the migration. The callback will be passed a
        cursor to the database. The migrator will manage locking database access and committing the
        transaction; the callback should not do either of these things.
    """

    def __init__(
        self,
        db_version: int,
        app_version: str,
        migrate: MigrateCallback,
    ) -> None:
        self.db_version = db_version
        self.app_version = app_version
        self.migrate = migrate
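# A minimal sketch (not part of this file) of a migration written to this
# contract: the callback receives only a cursor and leaves locking and
# committing to the migrator. The version numbers and table name below are
# illustrative:
#
#     def _migrate_example(cursor: sqlite3.Cursor) -> None:
#         cursor.execute("CREATE TABLE IF NOT EXISTS example (id INTEGER PRIMARY KEY);")
#
#     example_migration = Migration(db_version=3, app_version="3.6.0", migrate=_migrate_example)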
class SQLiteMigrator:
    """
    Manages migrations for a SQLite database.

    :param conn: The database connection.
    :param database: The path to the database file, or ":memory:" for an in-memory database.
    :param lock: A lock to use when accessing the database.
    :param logger: The logger to use.

    Migrations should be registered with :meth:`register_migration`. Migrations will be run in
    order of their version number. If the database is already at the latest version, no migrations
    will be run.
    """

    def __init__(self, conn: sqlite3.Connection, database: Path | str, lock: threading.RLock, logger: Logger) -> None:
        self._logger = logger
        self._conn = conn
        self._cursor = self._conn.cursor()
        self._lock = lock
        self._database = database
        self._migrations: set[Migration] = set()

    def register_migration(self, migration: Migration) -> None:
        """Registers a migration."""
        if not isinstance(migration.db_version, int) or migration.db_version < 1:
            raise MigrationVersionError(f"Invalid migration version {migration.db_version}")
        if any(m.db_version == migration.db_version for m in self._migrations):
            raise MigrationVersionError(f"Migration version {migration.db_version} already registered")
        self._migrations.add(migration)
        self._logger.debug(f"Registered migration {migration.db_version}")

    def run_migrations(self) -> None:
        """Migrates the database to the latest version."""
        with self._lock:
            self._create_version_table()
            sorted_migrations = sorted(self._migrations, key=lambda m: m.db_version)
            current_version = self._get_current_version()

            if len(sorted_migrations) == 0:
                self._logger.debug("No migrations registered")
                return

            latest_version = sorted_migrations[-1].db_version
            if current_version == latest_version:
                self._logger.debug("Database is up to date, no migrations to run")
                return

            if current_version > latest_version:
                raise MigrationError(
                    f"Database version {current_version} is greater than the latest migration version {latest_version}"
                )

            self._logger.info("Database update needed")

            # Only make a backup if using a file database (not memory)
            backup_path: Optional[Path] = None
            if isinstance(self._database, Path):
                backup_path = self._backup_db(self._database)
            else:
                self._logger.info("Using in-memory database, skipping backup")

            for migration in sorted_migrations:
                try:
                    self._run_migration(migration)
                except MigrationError:
                    if backup_path is not None:
                        self._logger.error(f"Restoring from {backup_path}")
                        self._restore_db(backup_path)
                    raise
            self._logger.info("Database updated successfully")

    def _run_migration(self, migration: Migration) -> None:
        """Runs a single migration."""
        with self._lock:
            current_version = self._get_current_version()
            try:
                if current_version >= migration.db_version:
                    return
                migration.migrate(self._cursor)
                # Migration callbacks only get a cursor; they cannot commit the transaction.
                self._conn.commit()
                self._set_version(db_version=migration.db_version, app_version=migration.app_version)
                self._logger.debug(f"Successfully migrated database from {current_version} to {migration.db_version}")
            except Exception as e:
                msg = f"Error migrating database from {current_version} to {migration.db_version}: {e}"
                self._conn.rollback()
                self._logger.error(msg)
                raise MigrationError(msg) from e

    def _create_version_table(self) -> None:
        """Creates a version table for the database, if one does not already exist."""
        with self._lock:
            try:
                self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='version';")
                if self._cursor.fetchone() is not None:
                    return
                self._cursor.execute(
                    """--sql
                    CREATE TABLE IF NOT EXISTS version (
                        db_version INTEGER PRIMARY KEY,
                        app_version TEXT NOT NULL,
                        migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
                    );
                    """
                )
                self._cursor.execute("INSERT INTO version (db_version, app_version) VALUES (?,?);", (0, "0.0.0"))
                self._conn.commit()
                self._logger.debug("Created version table")
            except sqlite3.Error as e:
                msg = f"Problem creating version table: {e}"
                self._logger.error(msg)
                self._conn.rollback()
                raise MigrationError(msg) from e

    def _get_current_version(self) -> int:
        """Gets the current version of the database, or 0 if the version table does not exist."""
        with self._lock:
            try:
                self._cursor.execute("SELECT MAX(db_version) FROM version;")
                version = self._cursor.fetchone()[0]
                if version is None:
                    return 0
                return version
            except sqlite3.OperationalError as e:
                if "no such table" in str(e):
                    return 0
                raise

    def _set_version(self, db_version: int, app_version: str) -> None:
        """Adds a version entry to the database's version table."""
        with self._lock:
            try:
                self._cursor.execute(
                    "INSERT INTO version (db_version, app_version) VALUES (?,?);", (db_version, app_version)
                )
                self._conn.commit()
            except sqlite3.Error as e:
                msg = f"Problem setting database version: {e}"
                self._logger.error(msg)
                self._conn.rollback()
                raise MigrationError(msg) from e

    def _backup_db(self, db_path: Path | str) -> Path:
        """Backs up the database, returning the path to the backup file."""
        if db_path == sqlite_memory:
            raise MigrationError("Cannot back up memory database")
        if not isinstance(db_path, Path):
            raise MigrationError(f'Database path must be "{sqlite_memory}" or a Path')
        with self._lock:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            backup_path = db_path.parent / f"{db_path.stem}_{timestamp}.db"
            self._logger.info(f"Backing up database to {backup_path}")
            backup_conn = sqlite3.connect(backup_path)
            with backup_conn:
                self._conn.backup(backup_conn)
            backup_conn.close()
            return backup_path

    def _restore_db(self, backup_path: Path) -> None:
        """Restores the database from a backup file, unless the database is a memory database."""
        if self._database == sqlite_memory:
            return
        with self._lock:
            self._logger.info(f"Restoring database from {backup_path}")
            self._conn.close()
            if not Path(backup_path).is_file():
                raise FileNotFoundError(f"Backup file {backup_path} does not exist")
            shutil.copy2(backup_path, self._database)
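End to end, this mirrors how SqliteDatabase wires the migrator at startup in the hunk above; a self-contained sketch (the logger name is illustrative), using an in-memory database so the backup step is skipped:

    import sqlite3
    import threading
    from logging import getLogger

    from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
    from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
    from invokeai.app.services.shared.sqlite.sqlite_migrator import SQLiteMigrator

    conn = sqlite3.connect(":memory:", check_same_thread=False)
    migrator = SQLiteMigrator(
        conn=conn, database=":memory:", lock=threading.RLock(), logger=getLogger("example")
    )
    migrator.register_migration(migration_1)
    migrator.register_migration(migration_2)
    migrator.run_migrations()  # creates the version table, then applies 1 and 2 in order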
@@ -26,7 +26,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
        self._lock = db.lock
        self._conn = db.conn
        self._cursor = self._conn.cursor()
        self._create_tables()

    def start(self, invoker: Invoker) -> None:
        self._invoker = invoker
@@ -233,87 +232,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
            raise
        finally:
            self._lock.release()

    def _create_tables(self) -> None:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                CREATE TABLE IF NOT EXISTS workflow_library (
                    workflow_id TEXT NOT NULL PRIMARY KEY,
                    workflow TEXT NOT NULL,
                    created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                    -- updated via trigger
                    updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                    -- updated manually when retrieving workflow
                    opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                    -- Generated columns, needed for indexing and searching
                    category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
                    name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
                    description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
                );
                """
            )

            self._cursor.execute(
                """--sql
                CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
                AFTER UPDATE
                ON workflow_library FOR EACH ROW
                BEGIN
                    UPDATE workflow_library
                    SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE workflow_id = old.workflow_id;
                END;
                """
            )

            self._cursor.execute(
                """--sql
                CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);
                """
            )
            self._cursor.execute(
                """--sql
                CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);
                """
            )
            self._cursor.execute(
                """--sql
                CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);
                """
            )
            self._cursor.execute(
                """--sql
                CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);
                """
            )
            self._cursor.execute(
                """--sql
                CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);
                """
            )
            self._cursor.execute(
                """--sql
                CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);
                """
            )

            # We do not need the original `workflows` table or `workflow_images` junction table.
            self._cursor.execute(
                """--sql
                DROP TABLE IF EXISTS workflow_images;
                """
            )
            self._cursor.execute(
                """--sql
                DROP TABLE IF EXISTS workflows;
                """
            )

            self._conn.commit()
        except Exception:
            self._conn.rollback()
            raise
        finally:
            self._lock.release()
tests/test_sqlite_migrator.py (new file, 173 lines)
@@ -0,0 +1,173 @@
import sqlite3
import threading
from copy import deepcopy
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable

import pytest

from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
from invokeai.app.services.shared.sqlite.sqlite_migrator import (
    Migration,
    MigrationError,
    MigrationVersionError,
    SQLiteMigrator,
)


@pytest.fixture
def migrator() -> SQLiteMigrator:
    conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
    return SQLiteMigrator(
        conn=conn, database=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
    )


@pytest.fixture
def good_migration() -> Migration:
    return Migration(db_version=1, app_version="1.0.0", migrate=lambda cursor: None)


@pytest.fixture
def failing_migration() -> Migration:
    def failing_migration(cursor: sqlite3.Cursor) -> None:
        raise Exception("Bad migration")

    return Migration(db_version=1, app_version="1.0.0", migrate=failing_migration)


def test_register_migration(migrator: SQLiteMigrator, good_migration: Migration):
    migration = good_migration
    migrator.register_migration(migration)
    assert migration in migrator._migrations
    with pytest.raises(MigrationError, match="Invalid migration version"):
        migrator.register_migration(Migration(db_version=0, app_version="0.0.0", migrate=lambda cursor: None))


def test_register_invalid_migration_version(migrator: SQLiteMigrator):
    with pytest.raises(MigrationError, match="Invalid migration version"):
        migrator.register_migration(Migration(db_version=0, app_version="0.0.0", migrate=lambda cursor: None))


def test_create_version_table(migrator: SQLiteMigrator):
    migrator._create_version_table()
    migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
    assert migrator._cursor.fetchone() is not None


def test_get_current_version(migrator: SQLiteMigrator):
    migrator._create_version_table()
    migrator._conn.commit()
    assert migrator._get_current_version() == 0  # initial version


def test_set_version(migrator: SQLiteMigrator):
    migrator._create_version_table()
    migrator._set_version(db_version=1, app_version="1.0.0")
    migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
    assert migrator._cursor.fetchone()[0] == 1
    migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;")
    assert migrator._cursor.fetchone()[0] == "1.0.0"


def test_run_migration(migrator: SQLiteMigrator):
    migrator._create_version_table()

    def migration_callback(cursor: sqlite3.Cursor) -> None:
        cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")

    migration = Migration(db_version=1, app_version="1.0.0", migrate=migration_callback)
    migrator._run_migration(migration)
    assert migrator._get_current_version() == 1
    migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;")
    assert migrator._cursor.fetchone()[0] == "1.0.0"


def test_run_migrations(migrator: SQLiteMigrator):
    migrator._create_version_table()

    def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
        def migrate(cursor: sqlite3.Cursor) -> None:
            cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")

        return migrate

    migrations = [Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in range(1, 4)]
    for migration in migrations:
        migrator.register_migration(migration)
    migrator.run_migrations()
    assert migrator._get_current_version() == 3


def test_backup_and_restore_db(migrator: SQLiteMigrator):
    with TemporaryDirectory() as tempdir:
        # must do this with a file database - we don't backup/restore for memory
        database = Path(tempdir) / "test.db"
        migrator._database = database
        migrator._cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
        migrator._conn.commit()
        backup_path = migrator._backup_db(migrator._database)
        migrator._cursor.execute("DROP TABLE test;")
        migrator._conn.commit()
        migrator._restore_db(backup_path)  # this closes the connection
        # reconnect to db
        restored_conn = sqlite3.connect(database)
        restored_cursor = restored_conn.cursor()
        restored_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
        assert restored_cursor.fetchone() is not None


def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator):
    with pytest.raises(MigrationError, match="Cannot back up memory database"):
        migrator._backup_db(sqlite_memory)


def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
    migrator._create_version_table()
    with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
        migrator._run_migration(failing_migration)
    assert migrator._get_current_version() == 0


def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
    migrator._create_version_table()
    migrator.register_migration(good_migration)
    with pytest.raises(MigrationVersionError, match="already registered"):
        migrator.register_migration(deepcopy(good_migration))


def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
    migrator._create_version_table()

    def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
        def migrate(cursor: sqlite3.Cursor) -> None:
            cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")

        return migrate

    migrations = [
        Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in reversed(range(1, 4))
    ]
    for migration in migrations:
        migrator.register_migration(migration)
    migrator.run_migrations()
    assert migrator._get_current_version() == 3


def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
    migrator._create_version_table()
    migrator.register_migration(good_migration)
    migrator._set_version(db_version=2, app_version="2.0.0")
    with pytest.raises(MigrationError, match="greater than the latest migration version"):
        migrator.run_migrations()
    assert migrator._get_current_version() == 2


def test_idempotent_migrations(migrator: SQLiteMigrator, good_migration: Migration):
    migrator._create_version_table()
    migrator.register_migration(good_migration)
    migrator.run_migrations()
    migrator.run_migrations()
    assert migrator._get_current_version() == 1