mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): add SQLiteMigrator to perform db migrations
This commit is contained in:
parent
ef807cf63a
commit
f2c6819d68
@ -20,63 +20,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
self._conn = db.conn
|
self._conn = db.conn
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._create_tables()
|
|
||||||
self._conn.commit()
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
def _create_tables(self) -> None:
|
|
||||||
"""Creates the `board_images` junction table."""
|
|
||||||
|
|
||||||
# Create the `board_images` junction table.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS board_images (
|
|
||||||
board_id TEXT NOT NULL,
|
|
||||||
image_name TEXT NOT NULL,
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- updated via trigger
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Soft delete, currently unused
|
|
||||||
deleted_at DATETIME,
|
|
||||||
-- enforce one-to-many relationship between boards and images using PK
|
|
||||||
-- (we can extend this to many-to-many later)
|
|
||||||
PRIMARY KEY (image_name),
|
|
||||||
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
|
|
||||||
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add index for board id
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add index for board id, sorted by created_at
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add trigger for `updated_at`.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
|
|
||||||
AFTER UPDATE
|
|
||||||
ON board_images FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
|
||||||
WHERE board_id = old.board_id AND image_name = old.image_name;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_image_to_board(
|
def add_image_to_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
board_id: str,
|
||||||
|
@ -28,52 +28,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
self._conn = db.conn
|
self._conn = db.conn
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._create_tables()
|
|
||||||
self._conn.commit()
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
def _create_tables(self) -> None:
|
|
||||||
"""Creates the `boards` table and `board_images` junction table."""
|
|
||||||
|
|
||||||
# Create the `boards` table.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS boards (
|
|
||||||
board_id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
board_name TEXT NOT NULL,
|
|
||||||
cover_image_name TEXT,
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Updated via trigger
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Soft delete, currently unused
|
|
||||||
deleted_at DATETIME,
|
|
||||||
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add trigger for `updated_at`.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
|
|
||||||
AFTER UPDATE
|
|
||||||
ON boards FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE boards SET updated_at = current_timestamp
|
|
||||||
WHERE board_id = old.board_id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete(self, board_id: str) -> None:
|
def delete(self, board_id: str) -> None:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
@ -32,101 +32,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
self._conn = db.conn
|
self._conn = db.conn
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._create_tables()
|
|
||||||
self._conn.commit()
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
def _create_tables(self) -> None:
|
|
||||||
"""Creates the `images` table."""
|
|
||||||
|
|
||||||
# Create the `images` table.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS images (
|
|
||||||
image_name TEXT NOT NULL PRIMARY KEY,
|
|
||||||
-- This is an enum in python, unrestricted string here for flexibility
|
|
||||||
image_origin TEXT NOT NULL,
|
|
||||||
-- This is an enum in python, unrestricted string here for flexibility
|
|
||||||
image_category TEXT NOT NULL,
|
|
||||||
width INTEGER NOT NULL,
|
|
||||||
height INTEGER NOT NULL,
|
|
||||||
session_id TEXT,
|
|
||||||
node_id TEXT,
|
|
||||||
metadata TEXT,
|
|
||||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Updated via trigger
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Soft delete, currently unused
|
|
||||||
deleted_at DATETIME
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cursor.execute("PRAGMA table_info(images)")
|
|
||||||
columns = [column[1] for column in self._cursor.fetchall()]
|
|
||||||
|
|
||||||
if "starred" not in columns:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the `images` table indices.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add trigger for `updated_at`.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
|
|
||||||
AFTER UPDATE
|
|
||||||
ON images FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
|
||||||
WHERE image_name = old.image_name;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cursor.execute("PRAGMA table_info(images)")
|
|
||||||
columns = [column[1] for column in self._cursor.fetchall()]
|
|
||||||
if "has_workflow" not in columns:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
ALTER TABLE images
|
|
||||||
ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
def get(self, image_name: str) -> ImageRecord:
|
def get(self, image_name: str) -> ImageRecord:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
@ -9,9 +9,6 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||||
|
|
||||||
# should match the InvokeAI version when this is first released.
|
|
||||||
CONFIG_FILE_VERSION = "3.2.0"
|
|
||||||
|
|
||||||
|
|
||||||
class DuplicateModelException(Exception):
|
class DuplicateModelException(Exception):
|
||||||
"""Raised on an attempt to add a model with the same key twice."""
|
"""Raised on an attempt to add a model with the same key twice."""
|
||||||
@ -32,12 +29,6 @@ class ConfigFileVersionMismatchException(Exception):
|
|||||||
class ModelRecordServiceBase(ABC):
|
class ModelRecordServiceBase(ABC):
|
||||||
"""Abstract base class for storage and retrieval of model configs."""
|
"""Abstract base class for storage and retrieval of model configs."""
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def version(self) -> str:
|
|
||||||
"""Return the config file/database schema version."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
|
@ -54,7 +54,6 @@ from invokeai.backend.model_manager.config import (
|
|||||||
|
|
||||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from .model_records_base import (
|
from .model_records_base import (
|
||||||
CONFIG_FILE_VERSION,
|
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
ModelRecordServiceBase,
|
ModelRecordServiceBase,
|
||||||
UnknownModelException,
|
UnknownModelException,
|
||||||
@ -78,85 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._db = db
|
self._db = db
|
||||||
self._cursor = self._db.conn.cursor()
|
self._cursor = self._db.conn.cursor()
|
||||||
|
|
||||||
with self._db.lock:
|
|
||||||
# Enable foreign keys
|
|
||||||
self._db.conn.execute("PRAGMA foreign_keys = ON;")
|
|
||||||
self._create_tables()
|
|
||||||
self._db.conn.commit()
|
|
||||||
assert (
|
|
||||||
str(self.version) == CONFIG_FILE_VERSION
|
|
||||||
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
|
||||||
|
|
||||||
def _create_tables(self) -> None:
|
|
||||||
"""Create sqlite3 tables."""
|
|
||||||
# model_config table breaks out the fields that are common to all config objects
|
|
||||||
# and puts class-specific ones in a serialized json object
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS model_config (
|
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
-- The next 3 fields are enums in python, unrestricted string here
|
|
||||||
base TEXT NOT NULL,
|
|
||||||
type TEXT NOT NULL,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
path TEXT NOT NULL,
|
|
||||||
original_hash TEXT, -- could be null
|
|
||||||
-- Serialized JSON representation of the whole config object,
|
|
||||||
-- which will contain additional fields from subclasses
|
|
||||||
config TEXT NOT NULL,
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Updated via trigger
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- unique constraint on combo of name, base and type
|
|
||||||
UNIQUE(name, base, type)
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# metadata table
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS model_manager_metadata (
|
|
||||||
metadata_key TEXT NOT NULL PRIMARY KEY,
|
|
||||||
metadata_value TEXT NOT NULL
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add trigger for `updated_at`.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
|
|
||||||
AFTER UPDATE
|
|
||||||
ON model_config FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
|
||||||
WHERE id = old.id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add indexes for searchable fields
|
|
||||||
for stmt in [
|
|
||||||
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
|
|
||||||
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
|
|
||||||
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
|
|
||||||
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
|
|
||||||
]:
|
|
||||||
self._cursor.execute(stmt)
|
|
||||||
|
|
||||||
# Add our version to the metadata table
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT OR IGNORE into model_manager_metadata (
|
|
||||||
metadata_key,
|
|
||||||
metadata_value
|
|
||||||
)
|
|
||||||
VALUES (?,?);
|
|
||||||
""",
|
|
||||||
("version", CONFIG_FILE_VERSION),
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Add a model to the database.
|
Add a model to the database.
|
||||||
@ -214,22 +134,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
|
|
||||||
return self.get_model(key)
|
return self.get_model(key)
|
||||||
|
|
||||||
@property
|
|
||||||
def version(self) -> str:
|
|
||||||
"""Return the version of the database schema."""
|
|
||||||
with self._db.lock:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT metadata_value FROM model_manager_metadata
|
|
||||||
WHERE metadata_key=?;
|
|
||||||
""",
|
|
||||||
("version",),
|
|
||||||
)
|
|
||||||
rows = self._cursor.fetchone()
|
|
||||||
if not rows:
|
|
||||||
raise KeyError("Models database does not have metadata key 'version'")
|
|
||||||
return rows[0]
|
|
||||||
|
|
||||||
def del_model(self, key: str) -> None:
|
def del_model(self, key: str) -> None:
|
||||||
"""
|
"""
|
||||||
Delete a model.
|
Delete a model.
|
||||||
|
@ -50,7 +50,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
self.__lock = db.lock
|
self.__lock = db.lock
|
||||||
self.__conn = db.conn
|
self.__conn = db.conn
|
||||||
self.__cursor = self.__conn.cursor()
|
self.__cursor = self.__conn.cursor()
|
||||||
self._create_tables()
|
|
||||||
|
|
||||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||||
return event[1]["event"] in match_in
|
return event[1]["event"] in match_in
|
||||||
@ -98,123 +97,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
except SessionQueueItemNotFoundError:
|
except SessionQueueItemNotFoundError:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _create_tables(self) -> None:
|
|
||||||
"""Creates the session queue tables, indicies, and triggers"""
|
|
||||||
try:
|
|
||||||
self.__lock.acquire()
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS session_queue (
|
|
||||||
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
|
|
||||||
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
|
|
||||||
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
|
|
||||||
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
|
|
||||||
field_values TEXT, -- NULL if no values are associated with this queue item
|
|
||||||
session TEXT NOT NULL, -- the session to be executed
|
|
||||||
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
|
|
||||||
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
|
|
||||||
error TEXT, -- any errors associated with this queue item
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
|
|
||||||
started_at DATETIME, -- updated via trigger
|
|
||||||
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
|
|
||||||
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
|
|
||||||
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
|
|
||||||
AFTER UPDATE OF status ON session_queue
|
|
||||||
FOR EACH ROW
|
|
||||||
WHEN
|
|
||||||
NEW.status = 'completed'
|
|
||||||
OR NEW.status = 'failed'
|
|
||||||
OR NEW.status = 'canceled'
|
|
||||||
BEGIN
|
|
||||||
UPDATE session_queue
|
|
||||||
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
|
||||||
WHERE item_id = NEW.item_id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
|
|
||||||
AFTER UPDATE OF status ON session_queue
|
|
||||||
FOR EACH ROW
|
|
||||||
WHEN
|
|
||||||
NEW.status = 'in_progress'
|
|
||||||
BEGIN
|
|
||||||
UPDATE session_queue
|
|
||||||
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
|
||||||
WHERE item_id = NEW.item_id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
|
|
||||||
AFTER UPDATE
|
|
||||||
ON session_queue FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE session_queue
|
|
||||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
|
||||||
WHERE item_id = old.item_id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__cursor.execute("PRAGMA table_info(session_queue)")
|
|
||||||
columns = [column[1] for column in self.__cursor.fetchall()]
|
|
||||||
if "workflow" not in columns:
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
ALTER TABLE session_queue ADD COLUMN workflow TEXT;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__conn.commit()
|
|
||||||
except Exception:
|
|
||||||
self.__conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self.__lock.release()
|
|
||||||
|
|
||||||
def _set_in_progress_to_canceled(self) -> None:
|
def _set_in_progress_to_canceled(self) -> None:
|
||||||
"""
|
"""
|
||||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||||
|
370
invokeai/app/services/shared/sqlite/migrations/migration_1.py
Normal file
370
invokeai/app/services/shared/sqlite/migrations/migration_1.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
import sqlite3
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Migration callback for database version 1."""
|
||||||
|
|
||||||
|
_create_board_images(cursor)
|
||||||
|
_create_boards(cursor)
|
||||||
|
_create_images(cursor)
|
||||||
|
_create_model_config(cursor)
|
||||||
|
_create_session_queue(cursor)
|
||||||
|
_create_workflow_images(cursor)
|
||||||
|
_create_workflows(cursor)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_board_images(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Creates the `board_images` table, indices and triggers."""
|
||||||
|
tables = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS board_images (
|
||||||
|
board_id TEXT NOT NULL,
|
||||||
|
image_name TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME,
|
||||||
|
-- enforce one-to-many relationship between boards and images using PK
|
||||||
|
-- (we can extend this to many-to-many later)
|
||||||
|
PRIMARY KEY (image_name),
|
||||||
|
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
indices = [
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);",
|
||||||
|
]
|
||||||
|
|
||||||
|
triggers = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON board_images FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE board_id = old.board_id AND image_name = old.image_name;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
for stmt in tables + indices + triggers:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_boards(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Creates the `boards` table, indices and triggers."""
|
||||||
|
tables = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS boards (
|
||||||
|
board_id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
board_name TEXT NOT NULL,
|
||||||
|
cover_image_name TEXT,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME,
|
||||||
|
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
indices = ["CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);"]
|
||||||
|
|
||||||
|
triggers = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON boards FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE boards SET updated_at = current_timestamp
|
||||||
|
WHERE board_id = old.board_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
for stmt in tables + indices + triggers:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_images(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Creates the `images` table, indices and triggers. Adds the `starred` column."""
|
||||||
|
|
||||||
|
tables = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS images (
|
||||||
|
image_name TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- This is an enum in python, unrestricted string here for flexibility
|
||||||
|
image_origin TEXT NOT NULL,
|
||||||
|
-- This is an enum in python, unrestricted string here for flexibility
|
||||||
|
image_category TEXT NOT NULL,
|
||||||
|
width INTEGER NOT NULL,
|
||||||
|
height INTEGER NOT NULL,
|
||||||
|
session_id TEXT,
|
||||||
|
node_id TEXT,
|
||||||
|
metadata TEXT,
|
||||||
|
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
indices = [
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);",
|
||||||
|
]
|
||||||
|
|
||||||
|
triggers = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON images FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE image_name = old.image_name;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add the 'starred' column to `images` if it doesn't exist
|
||||||
|
cursor.execute("PRAGMA table_info(images)")
|
||||||
|
columns = [column[1] for column in cursor.fetchall()]
|
||||||
|
|
||||||
|
if "starred" not in columns:
|
||||||
|
tables.append("ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;")
|
||||||
|
indices.append("CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);")
|
||||||
|
|
||||||
|
for stmt in tables + indices + triggers:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_model_config(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""
|
||||||
|
|
||||||
|
tables = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS model_config (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- The next 3 fields are enums in python, unrestricted string here
|
||||||
|
base TEXT NOT NULL,
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
path TEXT NOT NULL,
|
||||||
|
original_hash TEXT, -- could be null
|
||||||
|
-- Serialized JSON representation of the whole config object,
|
||||||
|
-- which will contain additional fields from subclasses
|
||||||
|
config TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- unique constraint on combo of name, base and type
|
||||||
|
UNIQUE(name, base, type)
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS model_manager_metadata (
|
||||||
|
metadata_key TEXT NOT NULL PRIMARY KEY,
|
||||||
|
metadata_value TEXT NOT NULL
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
triggers = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON model_config FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE id = old.id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add indexes for searchable fields
|
||||||
|
indices = [
|
||||||
|
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
|
||||||
|
]
|
||||||
|
|
||||||
|
for stmt in tables + indices + triggers:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_session_queue(cursor: sqlite3.Cursor) -> None:
|
||||||
|
tables = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS session_queue (
|
||||||
|
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
|
||||||
|
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
|
||||||
|
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
|
||||||
|
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
|
||||||
|
field_values TEXT, -- NULL if no values are associated with this queue item
|
||||||
|
session TEXT NOT NULL, -- the session to be executed
|
||||||
|
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
|
||||||
|
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
|
||||||
|
error TEXT, -- any errors associated with this queue item
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
|
||||||
|
started_at DATETIME, -- updated via trigger
|
||||||
|
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
|
||||||
|
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
|
||||||
|
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
indices = [
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);",
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);",
|
||||||
|
]
|
||||||
|
|
||||||
|
triggers = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
|
||||||
|
AFTER UPDATE OF status ON session_queue
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN
|
||||||
|
NEW.status = 'completed'
|
||||||
|
OR NEW.status = 'failed'
|
||||||
|
OR NEW.status = 'canceled'
|
||||||
|
BEGIN
|
||||||
|
UPDATE session_queue
|
||||||
|
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE item_id = NEW.item_id;
|
||||||
|
END;
|
||||||
|
""",
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
|
||||||
|
AFTER UPDATE OF status ON session_queue
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN
|
||||||
|
NEW.status = 'in_progress'
|
||||||
|
BEGIN
|
||||||
|
UPDATE session_queue
|
||||||
|
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE item_id = NEW.item_id;
|
||||||
|
END;
|
||||||
|
""",
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON session_queue FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE session_queue
|
||||||
|
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE item_id = old.item_id;
|
||||||
|
END;
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
for stmt in tables + indices + triggers:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_workflow_images(cursor: sqlite3.Cursor) -> None:
|
||||||
|
tables = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS workflow_images (
|
||||||
|
workflow_id TEXT NOT NULL,
|
||||||
|
image_name TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME,
|
||||||
|
-- enforce one-to-many relationship between workflows and images using PK
|
||||||
|
-- (we can extend this to many-to-many later)
|
||||||
|
PRIMARY KEY (image_name),
|
||||||
|
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
indices = [
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);",
|
||||||
|
]
|
||||||
|
|
||||||
|
triggers = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON workflow_images FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
for stmt in tables + indices + triggers:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_workflows(cursor: sqlite3.Cursor) -> None:
|
||||||
|
tables = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS workflows (
|
||||||
|
workflow TEXT NOT NULL,
|
||||||
|
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
triggers = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON workflows FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE workflows
|
||||||
|
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE workflow_id = old.workflow_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
for stmt in tables + triggers:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
migration_1 = Migration(db_version=1, app_version="3.4.0", migrate=_migrate)
|
||||||
|
"""
|
||||||
|
Database version 1 (initial state).
|
||||||
|
|
||||||
|
This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
|
||||||
|
version to not use migrations to manage the database.
|
||||||
|
|
||||||
|
As such, this migration does include some ALTER statements, and the SQL statements are written
|
||||||
|
to be idempotent.
|
||||||
|
|
||||||
|
- Create `board_images` junction table
|
||||||
|
- Create `boards` table
|
||||||
|
- Create `images` table, add `starred` column
|
||||||
|
- Create `model_config` table
|
||||||
|
- Create `session_queue` table
|
||||||
|
- Create `workflow_images` junction table
|
||||||
|
- Create `workflows` table
|
||||||
|
"""
|
@ -0,0 +1,97 @@
|
|||||||
|
import sqlite3
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Migration callback for database version 2."""
|
||||||
|
|
||||||
|
_add_images_has_workflow(cursor)
|
||||||
|
_add_session_queue_workflow(cursor)
|
||||||
|
_drop_old_workflow_tables(cursor)
|
||||||
|
_add_workflow_library(cursor)
|
||||||
|
_drop_model_manager_metadata(cursor)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Add the `has_workflow` column to `images` table."""
|
||||||
|
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
|
||||||
|
|
||||||
|
|
||||||
|
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Add the `workflow` column to `session_queue` table."""
|
||||||
|
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
|
||||||
|
|
||||||
|
|
||||||
|
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Drops the `workflows` and `workflow_images` tables."""
|
||||||
|
cursor.execute("DROP TABLE workflow_images;")
|
||||||
|
cursor.execute("DROP TABLE workflows;")
|
||||||
|
|
||||||
|
|
||||||
|
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
|
||||||
|
tables = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE workflow_library (
|
||||||
|
workflow_id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
workflow TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- updated manually when retrieving workflow
|
||||||
|
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Generated columns, needed for indexing and searching
|
||||||
|
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
|
||||||
|
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
|
||||||
|
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
indices = [
|
||||||
|
"CREATE INDEX idx_workflow_library_created_at ON workflow_library(created_at);",
|
||||||
|
"CREATE INDEX idx_workflow_library_updated_at ON workflow_library(updated_at);",
|
||||||
|
"CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);",
|
||||||
|
"CREATE INDEX idx_workflow_library_category ON workflow_library(category);",
|
||||||
|
"CREATE INDEX idx_workflow_library_name ON workflow_library(name);",
|
||||||
|
"CREATE INDEX idx_workflow_library_description ON workflow_library(description);",
|
||||||
|
]
|
||||||
|
|
||||||
|
triggers = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER tg_workflow_library_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON workflow_library FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE workflow_library
|
||||||
|
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE workflow_id = old.workflow_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
for stmt in tables + indices + triggers:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Drops the `model_manager_metadata` table."""
|
||||||
|
cursor.execute("DROP TABLE model_manager_metadata;")
|
||||||
|
|
||||||
|
|
||||||
|
migration_2 = Migration(
|
||||||
|
db_version=2,
|
||||||
|
app_version="3.5.0",
|
||||||
|
migrate=_migrate,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
Database version 2.
|
||||||
|
|
||||||
|
Introduced in v3.5.0 for the new workflow library.
|
||||||
|
|
||||||
|
- Add `has_workflow` column to `images` table
|
||||||
|
- Add `workflow` column to `session_queue` table
|
||||||
|
- Drop `workflows` and `workflow_images` tables
|
||||||
|
- Add `workflow_library` table
|
||||||
|
"""
|
@ -4,24 +4,27 @@ from logging import Logger
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
|
||||||
|
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_migrator import SQLiteMigrator
|
||||||
|
|
||||||
|
|
||||||
class SqliteDatabase:
|
class SqliteDatabase:
|
||||||
|
database: Path | str
|
||||||
|
|
||||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
|
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
if self._config.use_memory_db:
|
if self._config.use_memory_db:
|
||||||
self.db_path = sqlite_memory
|
self.database = sqlite_memory
|
||||||
logger.info("Using in-memory database")
|
logger.info("Using in-memory database")
|
||||||
else:
|
else:
|
||||||
db_path = self._config.db_path
|
self.database = self._config.db_path
|
||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
self.database.parent.mkdir(parents=True, exist_ok=True)
|
||||||
self.db_path = str(db_path)
|
self._logger.info(f"Using database at {self.database}")
|
||||||
self._logger.info(f"Using database at {self.db_path}")
|
|
||||||
|
|
||||||
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
self.conn = sqlite3.connect(database=self.database, check_same_thread=False)
|
||||||
self.lock = threading.RLock()
|
self.lock = threading.RLock()
|
||||||
self.conn.row_factory = sqlite3.Row
|
self.conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
@ -30,15 +33,20 @@ class SqliteDatabase:
|
|||||||
|
|
||||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
|
||||||
|
migrator = SQLiteMigrator(conn=self.conn, database=self.database, lock=self.lock, logger=self._logger)
|
||||||
|
migrator.register_migration(migration_1)
|
||||||
|
migrator.register_migration(migration_2)
|
||||||
|
migrator.run_migrations()
|
||||||
|
|
||||||
def clean(self) -> None:
|
def clean(self) -> None:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
try:
|
try:
|
||||||
if self.db_path == sqlite_memory:
|
if self.database == sqlite_memory:
|
||||||
return
|
return
|
||||||
initial_db_size = Path(self.db_path).stat().st_size
|
initial_db_size = Path(self.database).stat().st_size
|
||||||
self.conn.execute("VACUUM;")
|
self.conn.execute("VACUUM;")
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
final_db_size = Path(self.db_path).stat().st_size
|
final_db_size = Path(self.database).stat().st_size
|
||||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||||
if freed_space_in_mb > 0:
|
if freed_space_in_mb > 0:
|
||||||
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||||
|
210
invokeai/app/services/shared/sqlite/sqlite_migrator.py
Normal file
210
invokeai/app/services/shared/sqlite/sqlite_migrator.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
import shutil
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
|
from logging import Logger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Optional, TypeAlias
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||||
|
|
||||||
|
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationError(Exception):
|
||||||
|
"""Raised when a migration fails."""
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationVersionError(ValueError, MigrationError):
|
||||||
|
"""Raised when a migration version is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class Migration:
|
||||||
|
"""Represents a migration for a SQLite database.
|
||||||
|
|
||||||
|
:param db_version: The database schema version this migration results in.
|
||||||
|
:param app_version: The app version this migration is introduced in.
|
||||||
|
:param migrate: The callback to run to perform the migration. The callback will be passed a
|
||||||
|
cursor to the database. The migrator will manage locking database access and committing the
|
||||||
|
transaction; the callback should not do either of these things.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_version: int,
|
||||||
|
app_version: str,
|
||||||
|
migrate: MigrateCallback,
|
||||||
|
) -> None:
|
||||||
|
self.db_version = db_version
|
||||||
|
self.app_version = app_version
|
||||||
|
self.migrate = migrate
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteMigrator:
|
||||||
|
"""
|
||||||
|
Manages migrations for a SQLite database.
|
||||||
|
|
||||||
|
:param conn: The database connection.
|
||||||
|
:param database: The path to the database file, or ":memory:" for an in-memory database.
|
||||||
|
:param lock: A lock to use when accessing the database.
|
||||||
|
:param logger: The logger to use.
|
||||||
|
|
||||||
|
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
|
||||||
|
order of their version number. If the database is already at the latest version, no migrations
|
||||||
|
will be run.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, conn: sqlite3.Connection, database: Path | str, lock: threading.RLock, logger: Logger) -> None:
|
||||||
|
self._logger = logger
|
||||||
|
self._conn = conn
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
self._lock = lock
|
||||||
|
self._database = database
|
||||||
|
self._migrations: set[Migration] = set()
|
||||||
|
|
||||||
|
def register_migration(self, migration: Migration) -> None:
|
||||||
|
"""Registers a migration."""
|
||||||
|
if not isinstance(migration.db_version, int) or migration.db_version < 1:
|
||||||
|
raise MigrationVersionError(f"Invalid migration version {migration.db_version}")
|
||||||
|
if any(m.db_version == migration.db_version for m in self._migrations):
|
||||||
|
raise MigrationVersionError(f"Migration version {migration.db_version} already registered")
|
||||||
|
self._migrations.add(migration)
|
||||||
|
self._logger.debug(f"Registered migration {migration.db_version}")
|
||||||
|
|
||||||
|
def run_migrations(self) -> None:
|
||||||
|
"""Migrates the database to the latest version."""
|
||||||
|
with self._lock:
|
||||||
|
self._create_version_table()
|
||||||
|
sorted_migrations = sorted(self._migrations, key=lambda m: m.db_version)
|
||||||
|
current_version = self._get_current_version()
|
||||||
|
|
||||||
|
if len(sorted_migrations) == 0:
|
||||||
|
self._logger.debug("No migrations registered")
|
||||||
|
return
|
||||||
|
|
||||||
|
latest_version = sorted_migrations[-1].db_version
|
||||||
|
if current_version == latest_version:
|
||||||
|
self._logger.debug("Database is up to date, no migrations to run")
|
||||||
|
return
|
||||||
|
|
||||||
|
if current_version > latest_version:
|
||||||
|
raise MigrationError(
|
||||||
|
f"Database version {current_version} is greater than the latest migration version {latest_version}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._logger.info("Database update needed")
|
||||||
|
|
||||||
|
# Only make a backup if using a file database (not memory)
|
||||||
|
backup_path: Optional[Path] = None
|
||||||
|
if isinstance(self._database, Path):
|
||||||
|
backup_path = self._backup_db(self._database)
|
||||||
|
else:
|
||||||
|
self._logger.info("Using in-memory database, skipping backup")
|
||||||
|
|
||||||
|
for migration in sorted_migrations:
|
||||||
|
try:
|
||||||
|
self._run_migration(migration)
|
||||||
|
except MigrationError:
|
||||||
|
if backup_path is not None:
|
||||||
|
self._logger.error(f" Restoring from {backup_path}")
|
||||||
|
self._restore_db(backup_path)
|
||||||
|
raise
|
||||||
|
self._logger.info("Database updated successfully")
|
||||||
|
|
||||||
|
def _run_migration(self, migration: Migration) -> None:
|
||||||
|
"""Runs a single migration."""
|
||||||
|
with self._lock:
|
||||||
|
current_version = self._get_current_version()
|
||||||
|
try:
|
||||||
|
if current_version >= migration.db_version:
|
||||||
|
return
|
||||||
|
migration.migrate(self._cursor)
|
||||||
|
# Migration callbacks only get a cursor; they cannot commit the transaction.
|
||||||
|
self._conn.commit()
|
||||||
|
self._set_version(db_version=migration.db_version, app_version=migration.app_version)
|
||||||
|
self._logger.debug(f"Successfully migrated database from {current_version} to {migration.db_version}")
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"Error migrating database from {current_version} to {migration.db_version}: {e}"
|
||||||
|
self._conn.rollback()
|
||||||
|
self._logger.error(msg)
|
||||||
|
raise MigrationError(msg) from e
|
||||||
|
|
||||||
|
def _create_version_table(self) -> None:
|
||||||
|
"""Creates a version table for the database, if one does not already exist."""
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='version';")
|
||||||
|
if self._cursor.fetchone() is not None:
|
||||||
|
return
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS version (
|
||||||
|
db_version INTEGER PRIMARY KEY,
|
||||||
|
app_version TEXT NOT NULL,
|
||||||
|
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self._cursor.execute("INSERT INTO version (db_version, app_version) VALUES (?,?);", (0, "0.0.0"))
|
||||||
|
self._conn.commit()
|
||||||
|
self._logger.debug("Created version table")
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
msg = f"Problem creation version table: {e}"
|
||||||
|
self._logger.error(msg)
|
||||||
|
self._conn.rollback()
|
||||||
|
raise MigrationError(msg) from e
|
||||||
|
|
||||||
|
def _get_current_version(self) -> int:
|
||||||
|
"""Gets the current version of the database, or 0 if the version table does not exist."""
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
self._cursor.execute("SELECT MAX(db_version) FROM version;")
|
||||||
|
version = self._cursor.fetchone()[0]
|
||||||
|
if version is None:
|
||||||
|
return 0
|
||||||
|
return version
|
||||||
|
except sqlite3.OperationalError as e:
|
||||||
|
if "no such table" in str(e):
|
||||||
|
return 0
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _set_version(self, db_version: int, app_version: str) -> None:
|
||||||
|
"""Adds a version entry to the table's version table."""
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
self._cursor.execute(
|
||||||
|
"INSERT INTO version (db_version, app_version) VALUES (?,?);", (db_version, app_version)
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
msg = f"Problem setting database version: {e}"
|
||||||
|
self._logger.error(msg)
|
||||||
|
self._conn.rollback()
|
||||||
|
raise MigrationError(msg) from e
|
||||||
|
|
||||||
|
def _backup_db(self, db_path: Path | str) -> Path:
|
||||||
|
"""Backs up the databse, returning the path to the backup file."""
|
||||||
|
if db_path == sqlite_memory:
|
||||||
|
raise MigrationError("Cannot back up memory database")
|
||||||
|
if not isinstance(db_path, Path):
|
||||||
|
raise MigrationError(f'Database path must be "{sqlite_memory}" or a Path')
|
||||||
|
with self._lock:
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
backup_path = db_path.parent / f"{db_path.stem}_{timestamp}.db"
|
||||||
|
self._logger.info(f"Backing up database to {backup_path}")
|
||||||
|
backup_conn = sqlite3.connect(backup_path)
|
||||||
|
with backup_conn:
|
||||||
|
self._conn.backup(backup_conn)
|
||||||
|
backup_conn.close()
|
||||||
|
return backup_path
|
||||||
|
|
||||||
|
def _restore_db(self, backup_path: Path) -> None:
|
||||||
|
"""Restores the database from a backup file, unless the database is a memory database."""
|
||||||
|
if self._database == sqlite_memory:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._logger.info(f"Restoring database from {backup_path}")
|
||||||
|
self._conn.close()
|
||||||
|
if not Path(backup_path).is_file():
|
||||||
|
raise FileNotFoundError(f"Backup file {backup_path} does not exist")
|
||||||
|
shutil.copy2(backup_path, self._database)
|
@ -26,7 +26,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
self._lock = db.lock
|
self._lock = db.lock
|
||||||
self._conn = db.conn
|
self._conn = db.conn
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._create_tables()
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
def start(self, invoker: Invoker) -> None:
|
||||||
self._invoker = invoker
|
self._invoker = invoker
|
||||||
@ -233,87 +232,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def _create_tables(self) -> None:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS workflow_library (
|
|
||||||
workflow_id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
workflow TEXT NOT NULL,
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- updated via trigger
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- updated manually when retrieving workflow
|
|
||||||
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Generated columns, needed for indexing and searching
|
|
||||||
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
|
|
||||||
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
|
|
||||||
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
|
|
||||||
AFTER UPDATE
|
|
||||||
ON workflow_library FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE workflow_library
|
|
||||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
|
||||||
WHERE workflow_id = old.workflow_id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# We do not need the original `workflows` table or `workflow_images` junction table.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
DROP TABLE IF EXISTS workflow_images;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
DROP TABLE IF EXISTS workflows;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self._conn.commit()
|
|
||||||
except Exception:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
173
tests/test_sqlite_migrator.py
Normal file
173
tests/test_sqlite_migrator.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from copy import deepcopy
|
||||||
|
from logging import Logger
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_migrator import (
|
||||||
|
Migration,
|
||||||
|
MigrationError,
|
||||||
|
MigrationVersionError,
|
||||||
|
SQLiteMigrator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def migrator() -> SQLiteMigrator:
|
||||||
|
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
|
return SQLiteMigrator(
|
||||||
|
conn=conn, database=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def good_migration() -> Migration:
|
||||||
|
return Migration(db_version=1, app_version="1.0.0", migrate=lambda cursor: None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def failing_migration() -> Migration:
|
||||||
|
def failing_migration(cursor: sqlite3.Cursor) -> None:
|
||||||
|
raise Exception("Bad migration")
|
||||||
|
|
||||||
|
return Migration(db_version=1, app_version="1.0.0", migrate=failing_migration)
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_migration(migrator: SQLiteMigrator, good_migration: Migration):
|
||||||
|
migration = good_migration
|
||||||
|
migrator.register_migration(migration)
|
||||||
|
assert migration in migrator._migrations
|
||||||
|
with pytest.raises(MigrationError, match="Invalid migration version"):
|
||||||
|
migrator.register_migration(Migration(db_version=0, app_version="0.0.0", migrate=lambda cursor: None))
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_invalid_migration_version(migrator: SQLiteMigrator):
|
||||||
|
with pytest.raises(MigrationError, match="Invalid migration version"):
|
||||||
|
migrator.register_migration(Migration(db_version=0, app_version="0.0.0", migrate=lambda cursor: None))
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_version_table(migrator: SQLiteMigrator):
|
||||||
|
migrator._create_version_table()
|
||||||
|
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
|
||||||
|
assert migrator._cursor.fetchone() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_current_version(migrator: SQLiteMigrator):
|
||||||
|
migrator._create_version_table()
|
||||||
|
migrator._conn.commit()
|
||||||
|
assert migrator._get_current_version() == 0 # initial version
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_version(migrator: SQLiteMigrator):
|
||||||
|
migrator._create_version_table()
|
||||||
|
migrator._set_version(db_version=1, app_version="1.0.0")
|
||||||
|
migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
|
||||||
|
assert migrator._cursor.fetchone()[0] == 1
|
||||||
|
migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;")
|
||||||
|
assert migrator._cursor.fetchone()[0] == "1.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_migration(migrator: SQLiteMigrator):
|
||||||
|
migrator._create_version_table()
|
||||||
|
|
||||||
|
def migration_callback(cursor: sqlite3.Cursor) -> None:
|
||||||
|
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||||
|
|
||||||
|
migration = Migration(db_version=1, app_version="1.0.0", migrate=migration_callback)
|
||||||
|
migrator._run_migration(migration)
|
||||||
|
assert migrator._get_current_version() == 1
|
||||||
|
migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;")
|
||||||
|
assert migrator._cursor.fetchone()[0] == "1.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_migrations(migrator: SQLiteMigrator):
|
||||||
|
migrator._create_version_table()
|
||||||
|
|
||||||
|
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
|
||||||
|
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||||
|
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
|
||||||
|
|
||||||
|
return migrate
|
||||||
|
|
||||||
|
migrations = [Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in range(1, 4)]
|
||||||
|
for migration in migrations:
|
||||||
|
migrator.register_migration(migration)
|
||||||
|
migrator.run_migrations()
|
||||||
|
assert migrator._get_current_version() == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_backup_and_restore_db(migrator: SQLiteMigrator):
|
||||||
|
with TemporaryDirectory() as tempdir:
|
||||||
|
# must do this with a file database - we don't backup/restore for memory
|
||||||
|
database = Path(tempdir) / "test.db"
|
||||||
|
migrator._database = database
|
||||||
|
migrator._cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||||
|
migrator._conn.commit()
|
||||||
|
backup_path = migrator._backup_db(migrator._database)
|
||||||
|
migrator._cursor.execute("DROP TABLE test;")
|
||||||
|
migrator._conn.commit()
|
||||||
|
migrator._restore_db(backup_path) # this closes the connection
|
||||||
|
# reconnect to db
|
||||||
|
restored_conn = sqlite3.connect(database)
|
||||||
|
restored_cursor = restored_conn.cursor()
|
||||||
|
restored_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
|
||||||
|
assert restored_cursor.fetchone() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator):
|
||||||
|
with pytest.raises(MigrationError, match="Cannot back up memory database"):
|
||||||
|
migrator._backup_db(sqlite_memory)
|
||||||
|
|
||||||
|
|
||||||
|
def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
|
||||||
|
migrator._create_version_table()
|
||||||
|
with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
|
||||||
|
migrator._run_migration(failing_migration)
|
||||||
|
assert migrator._get_current_version() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
|
||||||
|
migrator._create_version_table()
|
||||||
|
migrator.register_migration(good_migration)
|
||||||
|
with pytest.raises(MigrationVersionError, match="already registered"):
|
||||||
|
migrator.register_migration(deepcopy(good_migration))
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
|
||||||
|
migrator._create_version_table()
|
||||||
|
|
||||||
|
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
|
||||||
|
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||||
|
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
|
||||||
|
|
||||||
|
return migrate
|
||||||
|
|
||||||
|
migrations = [
|
||||||
|
Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in reversed(range(1, 4))
|
||||||
|
]
|
||||||
|
for migration in migrations:
|
||||||
|
migrator.register_migration(migration)
|
||||||
|
migrator.run_migrations()
|
||||||
|
assert migrator._get_current_version() == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
|
||||||
|
migrator._create_version_table()
|
||||||
|
migrator.register_migration(good_migration)
|
||||||
|
migrator._set_version(db_version=2, app_version="2.0.0")
|
||||||
|
with pytest.raises(MigrationError, match="greater than the latest migration version"):
|
||||||
|
migrator.run_migrations()
|
||||||
|
assert migrator._get_current_version() == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_idempotent_migrations(migrator: SQLiteMigrator, good_migration: Migration):
|
||||||
|
migrator._create_version_table()
|
||||||
|
migrator.register_migration(good_migration)
|
||||||
|
migrator.run_migrations()
|
||||||
|
migrator.run_migrations()
|
||||||
|
assert migrator._get_current_version() == 1
|
Loading…
Reference in New Issue
Block a user