mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): add SQLiteMigrator to perform db migrations (#5227)
## What type of PR is this? (check all applicable)

- [x] Refactor
- [x] Feature
- [ ] Bug Fix
- [x] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Have you discussed this change with the InvokeAI team?

- [x] Yes
- [ ] No, because:

## Description

This PR enhances our SQLite database with migration logic.

### `SqliteMigrator` class

The new `SqliteMigrator` class handles safely running database migrations. It is set up during app startup by the new `init_db` helper, which creates the `SqliteDatabase`, registers all migrations, and immediately runs them.

### `Migration` class

Migrations are represented by a `Migration` class, which has 3 attributes:

- `from_version: int`: The database version on which this migration may be run.
- `to_version: int`: The database version that results from this migration.
- `callback: MigrateCallback`: A callable that performs the migration. It receives a cursor _only_, but can do anything it wants to do.

A convention is established for these callbacks. All schema-creating SQL now lives in a migration callback. We haven't needed to make any data migrations yet, but when we do, they will also be handled within one of these callbacks.

### Migration Flow

First, migrations are registered with the `SqliteMigrator` via its `register_migration` method, which performs some basic checks of the migration versions. After registering all migrations, they are run with the `run_migrations` method. This does a few things:

- Creates a `version` table in the DB, if it doesn't already exist. This table has `db_version INTEGER`, `app_version TEXT` and `migrated_at DATETIME` columns.
- Sorts the migrations by version.
- Checks whether any migrations need to run.
- Backs up the database (if it's a file database). The migration bails out if this fails.
- Runs each migration, restoring from the backup if there is a problem.

A minimal registration sketch appears after this PR description.

### Included Migrations

Migrations live in `invokeai/app/services/shared/sqlite_migrator/migrations`.

#### `migration_1.py`

All\* schema SQL up to 3.4.0post2 is in `migration_1.py`. Running only this migration should result in a database that is identical to the one you get from starting up 3.4.0post2. SQL in this migration is **idempotent** (same as it was when the SQL was spread across the various services).

#### `migration_2.py`

Schema changes through 3.5.0 (the upcoming release) are in `migration_2.py`. SQL in this migration is **not idempotent**. Future migrations need not be idempotent, as the migration logic ensures each will only be run once.

### \*Caveat - ItemStorage

This class provides a generic document-db-like interface for storing objects. Our `graph_executions` and `graphs` tables are created and managed by this service. This PR does not touch this class and therefore does not touch either of those two tables. We can decide how to handle those tables in the future as the need arises.

### Change to Model Manager Metadata table

I noticed that there is a `model_manager_metadata` table which included the app version, and whose `version` property wasn't accessed outside the service. I believe the new `version` table fulfills the purpose of this table, and have removed it. @lstein Please let me know if this is not right.

## QA Instructions, Screenshots, Recordings

1. Case 1 - Upgrade
   - Back up your 3.4.0post2 database
   - Run this PR - it should upgrade your database, and everything should work exactly as it did before
2. Case 2 - New Install
   - Move your database out of the invoke root so that the app creates a new one on startup
   - Run this PR - it should work just like a new install
3. Case 3 - With an In-Memory Database
   - Enable the in-memory database (set `use_memory_db` under `Paths` in `invokeai.yaml` to `true`)
   - Run this PR - it should work just like a new install

## Added/updated tests?

- [x] Yes: Fairly comprehensive tests are added for the `SqliteMigrator`.
- [ ] No: _please replace this line with details on why tests have not been included_
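For orientation, here is a minimal sketch of the registration flow described above, wired together from the classes added in this diff. `ExampleCallback`, `make_db`, and the `example` table are hypothetical, invented purely for illustration:

```py
import sqlite3
from logging import Logger
from pathlib import Path

from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


class ExampleCallback:
    # Hypothetical callback: it receives an open cursor and must not commit its transaction.
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        cursor.execute("CREATE TABLE IF NOT EXISTS example (id TEXT NOT NULL PRIMARY KEY);")


def make_db(db_path: Path | None, logger: Logger) -> SqliteDatabase:
    db = SqliteDatabase(db_path=db_path, logger=logger)  # db_path=None -> in-memory DB
    migrator = SqliteMigrator(db=db)
    # Each migration advances the version by exactly one; duplicate versions are rejected.
    migrator.register_migration(Migration(from_version=0, to_version=1, callback=ExampleCallback()))
    migrator.run_migrations()
    return db
```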
This commit is contained in: commit 6c6c45c3da
invokeai/app/api/dependencies.py
@@ -2,6 +2,7 @@

 from logging import Logger

+from invokeai.app.services.shared.sqlite.sqlite_util import init_db
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.version.invokeai_version import __version__

@@ -30,7 +31,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio
 from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
 from ..services.shared.default_graphs import create_system_graphs
 from ..services.shared.graph import GraphExecutionState, LibraryGraph
-from ..services.shared.sqlite.sqlite_database import SqliteDatabase
 from ..services.urls.urls_default import LocalUrlService
 from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
 from .events import FastAPIEventService
@@ -67,8 +67,9 @@ class ApiDependencies:
         logger.debug(f"Internet connectivity is {config.internet_available}")

         output_folder = config.output_path
+        image_files = DiskImageFileStorage(f"{output_folder}/images")

-        db = SqliteDatabase(config, logger)
+        db = init_db(config=config, logger=logger, image_files=image_files)

         configuration = config
         logger = logger
@@ -80,7 +81,6 @@ class ApiDependencies:
         events = FastAPIEventService(event_handler_id)
         graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
         graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
-        image_files = DiskImageFileStorage(f"{output_folder}/images")
         image_records = SqliteImageRecordStorage(db=db)
         images = ImageService()
         invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
invokeai/app/services/board_image_records/board_image_records_sqlite.py
@@ -20,63 +20,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
         self._conn = db.conn
         self._cursor = self._conn.cursor()

-        try:
-            self._lock.acquire()
-            self._create_tables()
-            self._conn.commit()
-        finally:
-            self._lock.release()
-
-    def _create_tables(self) -> None:
-        """Creates the `board_images` junction table."""
-
-        # Create the `board_images` junction table.
-        self._cursor.execute(
-            """--sql
-            CREATE TABLE IF NOT EXISTS board_images (
-                board_id TEXT NOT NULL,
-                image_name TEXT NOT NULL,
-                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- updated via trigger
-                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- Soft delete, currently unused
-                deleted_at DATETIME,
-                -- enforce one-to-many relationship between boards and images using PK
-                -- (we can extend this to many-to-many later)
-                PRIMARY KEY (image_name),
-                FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
-                FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
-            );
-            """
-        )
-
-        # Add index for board id
-        self._cursor.execute(
-            """--sql
-            CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
-            """
-        )
-
-        # Add index for board id, sorted by created_at
-        self._cursor.execute(
-            """--sql
-            CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
-            """
-        )
-
-        # Add trigger for `updated_at`.
-        self._cursor.execute(
-            """--sql
-            CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
-            AFTER UPDATE
-            ON board_images FOR EACH ROW
-            BEGIN
-                UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE board_id = old.board_id AND image_name = old.image_name;
-            END;
-            """
-        )
-
     def add_image_to_board(
         self,
         board_id: str,
invokeai/app/services/board_records/board_records_sqlite.py
@@ -28,52 +28,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
         self._conn = db.conn
         self._cursor = self._conn.cursor()

-        try:
-            self._lock.acquire()
-            self._create_tables()
-            self._conn.commit()
-        finally:
-            self._lock.release()
-
-    def _create_tables(self) -> None:
-        """Creates the `boards` table and `board_images` junction table."""
-
-        # Create the `boards` table.
-        self._cursor.execute(
-            """--sql
-            CREATE TABLE IF NOT EXISTS boards (
-                board_id TEXT NOT NULL PRIMARY KEY,
-                board_name TEXT NOT NULL,
-                cover_image_name TEXT,
-                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- Updated via trigger
-                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- Soft delete, currently unused
-                deleted_at DATETIME,
-                FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
-            );
-            """
-        )
-
-        self._cursor.execute(
-            """--sql
-            CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
-            """
-        )
-
-        # Add trigger for `updated_at`.
-        self._cursor.execute(
-            """--sql
-            CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
-            AFTER UPDATE
-            ON boards FOR EACH ROW
-            BEGIN
-                UPDATE boards SET updated_at = current_timestamp
-                    WHERE board_id = old.board_id;
-            END;
-            """
-        )
-
     def delete(self, board_id: str) -> None:
         try:
             self._lock.acquire()
invokeai/app/services/image_records/image_records_sqlite.py
@@ -32,101 +32,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
         self._conn = db.conn
         self._cursor = self._conn.cursor()

-        try:
-            self._lock.acquire()
-            self._create_tables()
-            self._conn.commit()
-        finally:
-            self._lock.release()
-
-    def _create_tables(self) -> None:
-        """Creates the `images` table."""
-
-        # Create the `images` table.
-        self._cursor.execute(
-            """--sql
-            CREATE TABLE IF NOT EXISTS images (
-                image_name TEXT NOT NULL PRIMARY KEY,
-                -- This is an enum in python, unrestricted string here for flexibility
-                image_origin TEXT NOT NULL,
-                -- This is an enum in python, unrestricted string here for flexibility
-                image_category TEXT NOT NULL,
-                width INTEGER NOT NULL,
-                height INTEGER NOT NULL,
-                session_id TEXT,
-                node_id TEXT,
-                metadata TEXT,
-                is_intermediate BOOLEAN DEFAULT FALSE,
-                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- Updated via trigger
-                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- Soft delete, currently unused
-                deleted_at DATETIME
-            );
-            """
-        )
-
-        self._cursor.execute("PRAGMA table_info(images)")
-        columns = [column[1] for column in self._cursor.fetchall()]
-
-        if "starred" not in columns:
-            self._cursor.execute(
-                """--sql
-                ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
-                """
-            )
-
-        # Create the `images` table indices.
-        self._cursor.execute(
-            """--sql
-            CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
-            """
-        )
-        self._cursor.execute(
-            """--sql
-            CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
-            """
-        )
-        self._cursor.execute(
-            """--sql
-            CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
-            """
-        )
-        self._cursor.execute(
-            """--sql
-            CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
-            """
-        )
-
-        self._cursor.execute(
-            """--sql
-            CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
-            """
-        )
-
-        # Add trigger for `updated_at`.
-        self._cursor.execute(
-            """--sql
-            CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
-            AFTER UPDATE
-            ON images FOR EACH ROW
-            BEGIN
-                UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE image_name = old.image_name;
-            END;
-            """
-        )
-
-        self._cursor.execute("PRAGMA table_info(images)")
-        columns = [column[1] for column in self._cursor.fetchall()]
-        if "has_workflow" not in columns:
-            self._cursor.execute(
-                """--sql
-                ALTER TABLE images
-                ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;
-                """
-            )
-
     def get(self, image_name: str) -> ImageRecord:
         try:
             self._lock.acquire()
invokeai/app/services/model_records/model_records_base.py
@@ -9,9 +9,6 @@ from typing import List, Optional, Union

 from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType

-# should match the InvokeAI version when this is first released.
-CONFIG_FILE_VERSION = "3.2.0"
-

 class DuplicateModelException(Exception):
     """Raised on an attempt to add a model with the same key twice."""
@@ -32,12 +29,6 @@ class ConfigFileVersionMismatchException(Exception):
 class ModelRecordServiceBase(ABC):
     """Abstract base class for storage and retrieval of model configs."""

-    @property
-    @abstractmethod
-    def version(self) -> str:
-        """Return the config file/database schema version."""
-        pass
-
     @abstractmethod
     def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
         """
invokeai/app/services/model_records/model_records_sql.py
@@ -54,7 +54,6 @@ from invokeai.backend.model_manager.config import (

 from ..shared.sqlite.sqlite_database import SqliteDatabase
 from .model_records_base import (
-    CONFIG_FILE_VERSION,
     DuplicateModelException,
     ModelRecordServiceBase,
     UnknownModelException,
@@ -78,86 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
         self._db = db
         self._cursor = self._db.conn.cursor()

-        with self._db.lock:
-            # Enable foreign keys
-            self._db.conn.execute("PRAGMA foreign_keys = ON;")
-            self._create_tables()
-            self._db.conn.commit()
-        assert (
-            str(self.version) == CONFIG_FILE_VERSION
-        ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
-
-    def _create_tables(self) -> None:
-        """Create sqlite3 tables."""
-        # model_config table breaks out the fields that are common to all config objects
-        # and puts class-specific ones in a serialized json object
-        self._cursor.execute(
-            """--sql
-            CREATE TABLE IF NOT EXISTS model_config (
-                id TEXT NOT NULL PRIMARY KEY,
-                -- The next 3 fields are enums in python, unrestricted string here
-                base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
-                type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
-                name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
-                path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
-                format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
-                original_hash TEXT, -- could be null
-                -- Serialized JSON representation of the whole config object,
-                -- which will contain additional fields from subclasses
-                config TEXT NOT NULL,
-                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- Updated via trigger
-                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- unique constraint on combo of name, base and type
-                UNIQUE(name, base, type)
-            );
-            """
-        )
-
-        # metadata table
-        self._cursor.execute(
-            """--sql
-            CREATE TABLE IF NOT EXISTS model_manager_metadata (
-                metadata_key TEXT NOT NULL PRIMARY KEY,
-                metadata_value TEXT NOT NULL
-            );
-            """
-        )
-
-        # Add trigger for `updated_at`.
-        self._cursor.execute(
-            """--sql
-            CREATE TRIGGER IF NOT EXISTS model_config_updated_at
-            AFTER UPDATE
-            ON model_config FOR EACH ROW
-            BEGIN
-                UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE id = old.id;
-            END;
-            """
-        )
-
-        # Add indexes for searchable fields
-        for stmt in [
-            "CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
-            "CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
-            "CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
-            "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
-        ]:
-            self._cursor.execute(stmt)
-
-        # Add our version to the metadata table
-        self._cursor.execute(
-            """--sql
-            INSERT OR IGNORE into model_manager_metadata (
-                metadata_key,
-                metadata_value
-            )
-            VALUES (?,?);
-            """,
-            ("version", CONFIG_FILE_VERSION),
-        )
-
     def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
         """
         Add a model to the database.
@@ -207,22 +126,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

         return self.get_model(key)

-    @property
-    def version(self) -> str:
-        """Return the version of the database schema."""
-        with self._db.lock:
-            self._cursor.execute(
-                """--sql
-                SELECT metadata_value FROM model_manager_metadata
-                WHERE metadata_key=?;
-                """,
-                ("version",),
-            )
-            rows = self._cursor.fetchone()
-            if not rows:
-                raise KeyError("Models database does not have metadata key 'version'")
-            return rows[0]
-
     def del_model(self, key: str) -> None:
         """
         Delete a model.
invokeai/app/services/session_queue/session_queue_sqlite.py
@@ -50,7 +50,6 @@ class SqliteSessionQueue(SessionQueueBase):
         self.__lock = db.lock
         self.__conn = db.conn
         self.__cursor = self.__conn.cursor()
-        self._create_tables()

     def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
         return event[1]["event"] in match_in
@@ -98,123 +97,6 @@ class SqliteSessionQueue(SessionQueueBase):
         except SessionQueueItemNotFoundError:
             return

-    def _create_tables(self) -> None:
-        """Creates the session queue tables, indicies, and triggers"""
-        try:
-            self.__lock.acquire()
-            self.__cursor.execute(
-                """--sql
-                CREATE TABLE IF NOT EXISTS session_queue (
-                    item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
-                    batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
-                    queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
-                    session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
-                    field_values TEXT, -- NULL if no values are associated with this queue item
-                    session TEXT NOT NULL, -- the session to be executed
-                    status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
-                    priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
-                    error TEXT, -- any errors associated with this queue item
-                    created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                    updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
-                    started_at DATETIME, -- updated via trigger
-                    completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-                    -- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-                    -- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
-                );
-                """
-            )
-
-            self.__cursor.execute(
-                """--sql
-                CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
-                """
-            )
-
-            self.__cursor.execute(
-                """--sql
-                CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
-                """
-            )
-
-            self.__cursor.execute(
-                """--sql
-                CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
-                """
-            )
-
-            self.__cursor.execute(
-                """--sql
-                CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
-                """
-            )
-
-            self.__cursor.execute(
-                """--sql
-                CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
-                """
-            )
-
-            self.__cursor.execute(
-                """--sql
-                CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
-                AFTER UPDATE OF status ON session_queue
-                FOR EACH ROW
-                WHEN
-                    NEW.status = 'completed'
-                    OR NEW.status = 'failed'
-                    OR NEW.status = 'canceled'
-                BEGIN
-                    UPDATE session_queue
-                    SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE item_id = NEW.item_id;
-                END;
-                """
-            )
-
-            self.__cursor.execute(
-                """--sql
-                CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
-                AFTER UPDATE OF status ON session_queue
-                FOR EACH ROW
-                WHEN
-                    NEW.status = 'in_progress'
-                BEGIN
-                    UPDATE session_queue
-                    SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE item_id = NEW.item_id;
-                END;
-                """
-            )
-
-            self.__cursor.execute(
-                """--sql
-                CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
-                AFTER UPDATE
-                ON session_queue FOR EACH ROW
-                BEGIN
-                    UPDATE session_queue
-                    SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE item_id = old.item_id;
-                END;
-                """
-            )
-
-            self.__cursor.execute("PRAGMA table_info(session_queue)")
-            columns = [column[1] for column in self.__cursor.fetchall()]
-            if "workflow" not in columns:
-                self.__cursor.execute(
-                    """--sql
-                    ALTER TABLE session_queue ADD COLUMN workflow TEXT;
-                    """
-                )
-
-            self.__conn.commit()
-        except Exception:
-            self.__conn.rollback()
-            raise
-        finally:
-            self.__lock.release()
-
     def _set_in_progress_to_canceled(self) -> None:
         """
         Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
invokeai/app/services/shared/sqlite/sqlite_database.py
@@ -3,45 +3,65 @@ import threading
 from logging import Logger
 from pathlib import Path

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory


 class SqliteDatabase:
-    def __init__(self, config: InvokeAIAppConfig, logger: Logger):
-        self._logger = logger
-        self._config = config
+    """
+    Manages a connection to an SQLite database.

-        if self._config.use_memory_db:
-            self.db_path = sqlite_memory
-            logger.info("Using in-memory database")
+    :param db_path: Path to the database file. If None, an in-memory database is used.
+    :param logger: Logger to use for logging.
+    :param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
+
+    This is a light wrapper around the `sqlite3` module, providing a few conveniences:
+    - The database file is written to disk if it does not exist.
+    - Foreign key constraints are enabled by default.
+    - The connection is configured to use the `sqlite3.Row` row factory.
+
+    In addition to the constructor args, the instance provides the following attributes and methods:
+    - `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory.
+    - `lock`: A shared re-entrant lock, used to approximate thread safety.
+    - `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space.
+    """
+
+    def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
+        """Initializes the database. This is used internally by the class constructor."""
+        self.logger = logger
+        self.db_path = db_path
+        self.verbose = verbose
+
+        if not self.db_path:
+            logger.info("Initializing in-memory database")
         else:
-            db_path = self._config.db_path
-            db_path.parent.mkdir(parents=True, exist_ok=True)
-            self.db_path = str(db_path)
-            self._logger.info(f"Using database at {self.db_path}")
+            self.db_path.parent.mkdir(parents=True, exist_ok=True)
+            self.logger.info(f"Initializing database at {self.db_path}")

-        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
+        self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
         self.lock = threading.RLock()
         self.conn.row_factory = sqlite3.Row

-        if self._config.log_sql:
-            self.conn.set_trace_callback(self._logger.debug)
+        if self.verbose:
+            self.conn.set_trace_callback(self.logger.debug)

         self.conn.execute("PRAGMA foreign_keys = ON;")

     def clean(self) -> None:
+        """
+        Cleans the database by running the VACUUM command, reporting on the freed space.
+        """
+        # No need to clean in-memory database
+        if not self.db_path:
+            return
         with self.lock:
             try:
-                if self.db_path == sqlite_memory:
-                    return
                 initial_db_size = Path(self.db_path).stat().st_size
                 self.conn.execute("VACUUM;")
                 self.conn.commit()
                 final_db_size = Path(self.db_path).stat().st_size
                 freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
                 if freed_space_in_mb > 0:
-                    self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
+                    self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
             except Exception as e:
-                self._logger.error(f"Error cleaning database: {e}")
+                self.logger.error(f"Error cleaning database: {e}")
                 raise
invokeai/app/services/shared/sqlite/sqlite_util.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from logging import Logger

from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileStorageBase) -> SqliteDatabase:
    """
    Initializes the SQLite database.

    :param config: The app config
    :param logger: The logger
    :param image_files: The image files service (used by migration 2)

    This function:
    - Instantiates a :class:`SqliteDatabase`
    - Instantiates a :class:`SqliteMigrator` and registers all migrations
    - Runs all migrations
    """
    db_path = None if config.use_memory_db else config.db_path
    db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)

    migrator = SqliteMigrator(db=db)
    migrator.register_migration(build_migration_1())
    migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
    migrator.run_migrations()

    return db
invokeai/app/services/shared/sqlite_migrator/migrations/migration_1.py (new file)
@@ -0,0 +1,372 @@
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration1Callback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        """Migration callback for database version 1."""

        self._create_board_images(cursor)
        self._create_boards(cursor)
        self._create_images(cursor)
        self._create_model_config(cursor)
        self._create_session_queue(cursor)
        self._create_workflow_images(cursor)
        self._create_workflows(cursor)

    def _create_board_images(self, cursor: sqlite3.Cursor) -> None:
        """Creates the `board_images` table, indices and triggers."""
        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS board_images (
                board_id TEXT NOT NULL,
                image_name TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between boards and images using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (image_name),
                FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
                FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
            );
            """
        ]

        indices = [
            "CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);",
            "CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);",
        ]

        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
            AFTER UPDATE
            ON board_images FOR EACH ROW
            BEGIN
                UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE board_id = old.board_id AND image_name = old.image_name;
            END;
            """
        ]

        for stmt in tables + indices + triggers:
            cursor.execute(stmt)

    def _create_boards(self, cursor: sqlite3.Cursor) -> None:
        """Creates the `boards` table, indices and triggers."""
        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS boards (
                board_id TEXT NOT NULL PRIMARY KEY,
                board_name TEXT NOT NULL,
                cover_image_name TEXT,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
            );
            """
        ]

        indices = ["CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);"]

        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
            AFTER UPDATE
            ON boards FOR EACH ROW
            BEGIN
                UPDATE boards SET updated_at = current_timestamp
                    WHERE board_id = old.board_id;
            END;
            """
        ]

        for stmt in tables + indices + triggers:
            cursor.execute(stmt)

    def _create_images(self, cursor: sqlite3.Cursor) -> None:
        """Creates the `images` table, indices and triggers. Adds the `starred` column."""

        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS images (
                image_name TEXT NOT NULL PRIMARY KEY,
                -- This is an enum in python, unrestricted string here for flexibility
                image_origin TEXT NOT NULL,
                -- This is an enum in python, unrestricted string here for flexibility
                image_category TEXT NOT NULL,
                width INTEGER NOT NULL,
                height INTEGER NOT NULL,
                session_id TEXT,
                node_id TEXT,
                metadata TEXT,
                is_intermediate BOOLEAN DEFAULT FALSE,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME
            );
            """
        ]

        indices = [
            "CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);",
            "CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);",
            "CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);",
            "CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);",
        ]

        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
            AFTER UPDATE
            ON images FOR EACH ROW
            BEGIN
                UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE image_name = old.image_name;
            END;
            """
        ]

        # Add the 'starred' column to `images` if it doesn't exist
        cursor.execute("PRAGMA table_info(images)")
        columns = [column[1] for column in cursor.fetchall()]

        if "starred" not in columns:
            tables.append("ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;")
            indices.append("CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);")

        for stmt in tables + indices + triggers:
            cursor.execute(stmt)

    def _create_model_config(self, cursor: sqlite3.Cursor) -> None:
        """Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""

        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS model_config (
                id TEXT NOT NULL PRIMARY KEY,
                -- The next 3 fields are enums in python, unrestricted string here
                base TEXT NOT NULL,
                type TEXT NOT NULL,
                name TEXT NOT NULL,
                path TEXT NOT NULL,
                original_hash TEXT, -- could be null
                -- Serialized JSON representation of the whole config object,
                -- which will contain additional fields from subclasses
                config TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- unique constraint on combo of name, base and type
                UNIQUE(name, base, type)
            );
            """,
            """--sql
            CREATE TABLE IF NOT EXISTS model_manager_metadata (
                metadata_key TEXT NOT NULL PRIMARY KEY,
                metadata_value TEXT NOT NULL
            );
            """,
        ]

        # Add trigger for `updated_at`.
        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS model_config_updated_at
            AFTER UPDATE
            ON model_config FOR EACH ROW
            BEGIN
                UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE id = old.id;
            END;
            """
        ]

        # Add indexes for searchable fields
        indices = [
            "CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
            "CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
            "CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
            "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
        ]

        for stmt in tables + indices + triggers:
            cursor.execute(stmt)

    def _create_session_queue(self, cursor: sqlite3.Cursor) -> None:
        """Creates the `session_queue` table, indices and triggers."""
        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS session_queue (
                item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
                batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
                queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
                session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
                field_values TEXT, -- NULL if no values are associated with this queue item
                session TEXT NOT NULL, -- the session to be executed
                status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
                priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
                error TEXT, -- any errors associated with this queue item
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
                started_at DATETIME, -- updated via trigger
                completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
                -- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
                -- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
            );
            """
        ]

        indices = [
            "CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);",
            "CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);",
            "CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);",
            "CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);",
            "CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);",
        ]

        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
            AFTER UPDATE OF status ON session_queue
            FOR EACH ROW
            WHEN
                NEW.status = 'completed'
                OR NEW.status = 'failed'
                OR NEW.status = 'canceled'
            BEGIN
                UPDATE session_queue
                SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                WHERE item_id = NEW.item_id;
            END;
            """,
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
            AFTER UPDATE OF status ON session_queue
            FOR EACH ROW
            WHEN
                NEW.status = 'in_progress'
            BEGIN
                UPDATE session_queue
                SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                WHERE item_id = NEW.item_id;
            END;
            """,
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
            AFTER UPDATE
            ON session_queue FOR EACH ROW
            BEGIN
                UPDATE session_queue
                SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                WHERE item_id = old.item_id;
            END;
            """,
        ]

        for stmt in tables + indices + triggers:
            cursor.execute(stmt)

    def _create_workflow_images(self, cursor: sqlite3.Cursor) -> None:
        """Creates the `workflow_images` table, indices and triggers."""
        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS workflow_images (
                workflow_id TEXT NOT NULL,
                image_name TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between workflows and images using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (image_name),
                FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
                FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
            );
            """
        ]

        indices = [
            "CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);",
            "CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);",
        ]

        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
            AFTER UPDATE
            ON workflow_images FOR EACH ROW
            BEGIN
                UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
            END;
            """
        ]

        for stmt in tables + indices + triggers:
            cursor.execute(stmt)

    def _create_workflows(self, cursor: sqlite3.Cursor) -> None:
        """Creates the `workflows` table and triggers."""
        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS workflows (
                workflow TEXT NOT NULL,
                workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
            );
            """
        ]

        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
            AFTER UPDATE
            ON workflows FOR EACH ROW
            BEGIN
                UPDATE workflows
                SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                WHERE workflow_id = old.workflow_id;
            END;
            """
        ]

        for stmt in tables + triggers:
            cursor.execute(stmt)


def build_migration_1() -> Migration:
    """
    Builds the migration from database version 0 (init) to 1.

    This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
    version to not use migrations to manage the database.

    As such, this migration does include some ALTER statements, and the SQL statements are written
    to be idempotent.

    - Create `board_images` junction table
    - Create `boards` table
    - Create `images` table, add `starred` column
    - Create `model_config` table
    - Create `session_queue` table
    - Create `workflow_images` junction table
    - Create `workflows` table
    """

    migration_1 = Migration(
        from_version=0,
        to_version=1,
        callback=Migration1Callback(),
    )

    return migration_1
invokeai/app/services/shared/sqlite_migrator/migrations/migration_2.py (new file)
@@ -0,0 +1,184 @@
import sqlite3
from logging import Logger

from tqdm import tqdm

from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration2Callback:
    def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
        self._image_files = image_files
        self._logger = logger

    def __call__(self, cursor: sqlite3.Cursor):
        """Migration callback for database version 2."""
        self._add_images_has_workflow(cursor)
        self._add_session_queue_workflow(cursor)
        self._drop_old_workflow_tables(cursor)
        self._add_workflow_library(cursor)
        self._drop_model_manager_metadata(cursor)
        self._recreate_model_config(cursor)
        self._migrate_embedded_workflows(cursor)

    def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
        """Add the `has_workflow` column to `images` table."""
        cursor.execute("PRAGMA table_info(images)")
        columns = [column[1] for column in cursor.fetchall()]

        if "has_workflow" not in columns:
            cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")

    def _add_session_queue_workflow(self, cursor: sqlite3.Cursor) -> None:
        """Add the `workflow` column to `session_queue` table."""

        cursor.execute("PRAGMA table_info(session_queue)")
        columns = [column[1] for column in cursor.fetchall()]

        if "workflow" not in columns:
            cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")

    def _drop_old_workflow_tables(self, cursor: sqlite3.Cursor) -> None:
        """Drops the `workflows` and `workflow_images` tables."""
        cursor.execute("DROP TABLE IF EXISTS workflow_images;")
        cursor.execute("DROP TABLE IF EXISTS workflows;")

    def _add_workflow_library(self, cursor: sqlite3.Cursor) -> None:
        """Adds the `workflow_library` table, indices and triggers."""
        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS workflow_library (
                workflow_id TEXT NOT NULL PRIMARY KEY,
                workflow TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated manually when retrieving workflow
                opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Generated columns, needed for indexing and searching
                category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
                name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
                description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
            );
            """,
        ]

        indices = [
            "CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);",
            "CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);",
            "CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);",
            "CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);",
            "CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);",
            "CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);",
        ]

        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
            AFTER UPDATE
            ON workflow_library FOR EACH ROW
            BEGIN
                UPDATE workflow_library
                SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                WHERE workflow_id = old.workflow_id;
            END;
            """
        ]

        for stmt in tables + indices + triggers:
            cursor.execute(stmt)

    def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
        """Drops the `model_manager_metadata` table."""
        cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")

    def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
        """
        Drops the `model_config` table, recreating it.

        In 3.4.0, this table used explicit columns, but was changed to use `json_extract` in 3.5.0.

        Because this table is not used in production, we are able to simply drop it and recreate it.
        """

        cursor.execute("DROP TABLE IF EXISTS model_config;")

        cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS model_config (
                id TEXT NOT NULL PRIMARY KEY,
                -- The next 3 fields are enums in python, unrestricted string here
                base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
                type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
                name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
                path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
                format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
                original_hash TEXT, -- could be null
                -- Serialized JSON representation of the whole config object,
                -- which will contain additional fields from subclasses
                config TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- unique constraint on combo of name, base and type
                UNIQUE(name, base, type)
            );
            """
        )

    def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
        """
        In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
        the database now has a `has_workflow` column, indicating if an image has a workflow embedded.

        This migrate callback checks each image for the presence of an embedded workflow, then updates its entry
        in the database accordingly.
        """
        # Get the list of image names; each image file is then checked for an embedded workflow
        cursor.execute("SELECT image_name FROM images")
        image_names: list[str] = [image[0] for image in cursor.fetchall()]
        total_image_names = len(image_names)

        if not total_image_names:
            return

        self._logger.info(f"Migrating workflows for {total_image_names} images")

        # Migrate the images
        to_migrate: list[tuple[bool, str]] = []
        pbar = tqdm(image_names)
        for idx, image_name in enumerate(pbar):
            pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
            pil_image = self._image_files.get(image_name)
            if "invokeai_workflow" in pil_image.info:
                to_migrate.append((True, image_name))

        self._logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
        cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)


def build_migration_2(image_files: ImageFileStorageBase, logger: Logger) -> Migration:
    """
    Builds the migration from database version 1 to 2.

    Introduced in v3.5.0 for the new workflow library.

    :param image_files: The image files service, used to check for embedded workflows
    :param logger: The logger, used to log progress during embedded workflows handling

    This migration does the following:
    - Adds the `has_workflow` column to the `images` table
    - Adds the `workflow` column to the `session_queue` table
    - Drops the `workflows` and `workflow_images` tables
    - Adds the `workflow_library` table
    - Drops the `model_manager_metadata` table
    - Drops the `model_config` table, recreating it (at this point, there is no user data in this table)
    - Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
    """
    migration_2 = Migration(
        from_version=1,
        to_version=2,
        callback=Migration2Callback(image_files=image_files, logger=logger),
    )

    return migration_2
@ -0,0 +1,164 @@
|
||||
import sqlite3
|
||||
from typing import Optional, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MigrateCallback(Protocol):
|
||||
"""
|
||||
A callback that performs a migration.
|
||||
|
||||
Migrate callbacks are provided an open cursor to the database. They should not commit their
|
||||
transaction; this is handled by the migrator.
|
||||
|
||||
If the callback needs to access additional dependencies, will be provided to the callback at runtime.
|
||||
|
||||
See :class:`Migration` for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
"""Raised when a migration fails."""
|
||||
|
||||
|
||||
class MigrationVersionError(ValueError):
|
||||
"""Raised when a migration version is invalid."""
|
||||
|
||||
|
||||
class Migration(BaseModel):
|
||||
"""
|
||||
Represents a migration for a SQLite database.
|
||||
|
||||
:param from_version: The database version on which this migration may be run
|
||||
:param to_version: The database version that results from this migration
|
||||
:param migrate_callback: The callback to run to perform the migration
|
||||
|
||||
Migration callbacks will be provided an open cursor to the database. They should not commit their
|
||||
transaction; this is handled by the migrator.
|
||||
|
||||
It is suggested to use a class to define the migration callback and a builder function to create
|
||||
the :class:`Migration`. This allows the callback to be provided with additional dependencies and
|
||||
keeps things tidy, as all migration logic is self-contained.
|
||||
|
||||
Example:
|
||||
```py
|
||||
# Define the migration callback class
|
||||
class Migration1Callback:
|
||||
# This migration needs a logger, so we define a class that accepts a logger in its constructor.
|
||||
def __init__(self, image_files: ImageFileStorageBase) -> None:
|
||||
self._image_files = ImageFileStorageBase
|
||||
|
||||
# This dunder method allows the instance of the class to be called like a function.
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._add_with_banana_column(cursor)
|
||||
self._do_something_with_images(cursor)
|
||||
|
||||
def _add_with_banana_column(self, cursor: sqlite3.Cursor) -> None:
|
||||
\"""Adds the with_banana column to the sushi table.\"""
|
||||
# Execute SQL using the cursor, taking care to *not commit* a transaction
|
||||
cursor.execute('ALTER TABLE sushi ADD COLUMN with_banana BOOLEAN DEFAULT TRUE;')
|
||||
|
||||
def _do_something_with_images(self, cursor: sqlite3.Cursor) -> None:
|
||||
\"""Does something with the image files service.\"""
|
||||
self._image_files.get(...)
|
||||
|
||||
# Define the migration builder function. This function creates an instance of the migration callback
|
||||
# class and returns a Migration.
|
||||
def build_migration_1(image_files: ImageFileStorageBase) -> Migration:
|
||||
\"""Builds the migration from database version 0 to 1.
|
||||
Requires the image files service to...
|
||||
\"""
|
||||
|
||||
migration_1 = Migration(
|
||||
from_version=0,
|
||||
to_version=1,
|
||||
migrate_callback=Migration1Callback(image_files=image_files),
|
||||
)
|
||||
|
||||
return migration_1
|
||||
|
||||
# Register the migration after all dependencies have been initialized
|
||||
db = SqliteDatabase(db_path, logger)
|
||||
migrator = SqliteMigrator(db)
|
||||
migrator.register_migration(build_migration_1(image_files))
|
||||
migrator.run_migrations()
|
||||
```
|
||||
"""

    from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
    to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
    callback: MigrateCallback = Field(description="The callback to run to perform the migration")

    @model_validator(mode="after")
    def validate_to_version(self) -> "Migration":
        """Validates that to_version is one greater than from_version."""
        if self.to_version != self.from_version + 1:
            raise MigrationVersionError("to_version must be one greater than from_version")
        return self

    def __hash__(self) -> int:
        # Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
        return hash((self.from_version, self.to_version))

    model_config = ConfigDict(arbitrary_types_allowed=True)


class MigrationSet:
    """
    A set of Migrations. Performs validation during migration registration and provides utility methods.

    Migrations should be registered with `register()`. Once all are registered, `validate_migration_chain()`
    should be called to ensure that the migrations form a single chain of migrations from version 0 to the latest version.
    """

    def __init__(self) -> None:
        self._migrations: set[Migration] = set()

    def register(self, migration: Migration) -> None:
        """Registers a migration."""
        migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
        migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
        if migration_from_already_registered or migration_to_already_registered:
            raise MigrationVersionError("Migration with from_version or to_version already registered")
        self._migrations.add(migration)

    def get(self, from_version: int) -> Optional[Migration]:
        """Gets the migration that may be run on the given database version."""
        # register() ensures that there is only one migration with a given from_version, so this is safe.
        return next((m for m in self._migrations if m.from_version == from_version), None)

    def validate_migration_chain(self) -> None:
        """
        Validates that the migrations form a single chain of migrations from version 0 to the latest version.
        Raises a MigrationError if there is a problem.
        """
        if self.count == 0:
            return
        if self.latest_version == 0:
            return
        next_migration = self.get(from_version=0)
        if next_migration is None:
            raise MigrationError("Migration chain is fragmented")
        touched_count = 1
        while next_migration is not None:
            next_migration = self.get(next_migration.to_version)
            if next_migration is not None:
                touched_count += 1
        if touched_count != self.count:
            raise MigrationError("Migration chain is fragmented")

    @property
    def count(self) -> int:
        """The count of registered migrations."""
        return len(self._migrations)

    @property
    def latest_version(self) -> int:
        """Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
        if self.count == 0:
            return 0
        return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
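
For orientation, here is a minimal sketch of how `MigrationSet` behaves on its own; the table-creating callback body is an assumption for illustration, while the `Migration`/`MigrationSet` API follows the code above:

```py
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationSet


def migrate_0_to_1(cursor: sqlite3.Cursor, **kwargs) -> None:
    # Hypothetical migration body, assumed for illustration only
    cursor.execute("CREATE TABLE example (id INTEGER PRIMARY KEY);")


migration_set = MigrationSet()
migration_set.register(Migration(from_version=0, to_version=1, callback=migrate_0_to_1))
migration_set.validate_migration_chain()  # raises MigrationError if the chain is fragmented
assert migration_set.latest_version == 1
assert migration_set.get(from_version=0) is not None
```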
@@ -0,0 +1,130 @@
import sqlite3
from pathlib import Path
from typing import Optional

from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet


class SqliteMigrator:
    """
    Manages migrations for a SQLite database.

    :param db: The instance of :class:`SqliteDatabase` to migrate.

    Migrations should be registered with :meth:`register_migration`.

    Each migration is run in a transaction. If a migration fails, the transaction is rolled back.

    Example Usage:
    ```py
    db = SqliteDatabase(db_path="my_db.db", logger=logger)
    migrator = SqliteMigrator(db=db)
    migrator.register_migration(build_migration_1())
    migrator.register_migration(build_migration_2())
    migrator.run_migrations()
    ```
    """

    backup_path: Optional[Path] = None

    def __init__(self, db: SqliteDatabase) -> None:
        self._db = db
        self._logger = db.logger
        self._migration_set = MigrationSet()

    def register_migration(self, migration: Migration) -> None:
        """Registers a migration."""
        self._migration_set.register(migration)
        self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")

    def run_migrations(self) -> bool:
        """Migrates the database to the latest version."""
        with self._db.lock:
            # This throws if there is a problem.
            self._migration_set.validate_migration_chain()
            cursor = self._db.conn.cursor()
            self._create_migrations_table(cursor=cursor)

            if self._migration_set.count == 0:
                self._logger.debug("No migrations registered")
                return False

            if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
                self._logger.debug("Database is up to date, no migrations to run")
                return False

            self._logger.info("Database update needed")
            next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
            while next_migration is not None:
                self._run_migration(next_migration)
                next_migration = self._migration_set.get(self._get_current_version(cursor))
            self._logger.info("Database updated successfully")
            return True

    def _run_migration(self, migration: Migration) -> None:
        """Runs a single migration."""
        try:
            # Using sqlite3.Connection as a context manager commits the transaction on exit, or rolls it back if an
            # exception is raised.
            with self._db.lock, self._db.conn as conn:
                cursor = conn.cursor()
                if self._get_current_version(cursor) != migration.from_version:
                    raise MigrationError(
                        f"Database is at version {self._get_current_version(cursor)}, expected {migration.from_version}"
                    )
                self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")

                # Run the actual migration
                migration.callback(cursor)

                # Update the version
                cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))

                self._logger.debug(
                    f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
                )
        # We want to catch *any* error, mirroring the behaviour of the sqlite3 module.
        except Exception as e:
            # The connection context manager has already rolled back the migration, so we don't need to do anything.
            msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
            self._logger.error(msg)
            raise MigrationError(msg) from e

    def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
        """Creates the migrations table for the database, if one does not already exist."""
        with self._db.lock:
            try:
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
                if cursor.fetchone() is not None:
                    return
                cursor.execute(
                    """--sql
                    CREATE TABLE migrations (
                        version INTEGER PRIMARY KEY,
                        migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
                    );
                    """
                )
                cursor.execute("INSERT INTO migrations (version) VALUES (0);")
                cursor.connection.commit()
                self._logger.debug("Created migrations table")
            except sqlite3.Error as e:
                msg = f"Problem creating migrations table: {e}"
                self._logger.error(msg)
                cursor.connection.rollback()
                raise MigrationError(msg) from e

    @classmethod
    def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:
        """Gets the current version of the database, or 0 if the migrations table does not exist."""
        try:
            cursor.execute("SELECT MAX(version) FROM migrations;")
            version: int = cursor.fetchone()[0]
            if version is None:
                return 0
            return version
        except sqlite3.OperationalError as e:
            if "no such table" in str(e):
                return 0
            raise
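
To make the lifecycle concrete, here is a minimal end-to-end sketch against an in-memory database; the table-creating callback is an assumption for illustration, while the constructor and return values follow the implementation above:

```py
import sqlite3

from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
from invokeai.backend.util.logging import InvokeAILogger


def migrate_0_to_1(cursor: sqlite3.Cursor, **kwargs) -> None:
    # Hypothetical migration, assumed for illustration
    cursor.execute("CREATE TABLE example (id INTEGER PRIMARY KEY);")


logger = InvokeAILogger.get_logger()
db = SqliteDatabase(db_path=None, logger=logger, verbose=False)  # db_path=None -> in-memory
migrator = SqliteMigrator(db=db)
migrator.register_migration(Migration(from_version=0, to_version=1, callback=migrate_0_to_1))
assert migrator.run_migrations() is True  # version 0 -> 1
assert migrator.run_migrations() is False  # already at the latest version; no-op
```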
@@ -26,7 +26,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
         self._lock = db.lock
         self._conn = db.conn
         self._cursor = self._conn.cursor()
-        self._create_tables()

     def start(self, invoker: Invoker) -> None:
         self._invoker = invoker
@@ -233,87 +232,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
             raise
         finally:
             self._lock.release()
-
-    def _create_tables(self) -> None:
-        try:
-            self._lock.acquire()
-            self._cursor.execute(
-                """--sql
-                CREATE TABLE IF NOT EXISTS workflow_library (
-                    workflow_id TEXT NOT NULL PRIMARY KEY,
-                    workflow TEXT NOT NULL,
-                    created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                    -- updated via trigger
-                    updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                    -- updated manually when retrieving workflow
-                    opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                    -- Generated columns, needed for indexing and searching
-                    category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
-                    name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
-                    description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
-                );
-                """
-            )
-
-            self._cursor.execute(
-                """--sql
-                CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
-                AFTER UPDATE
-                ON workflow_library FOR EACH ROW
-                BEGIN
-                    UPDATE workflow_library
-                    SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE workflow_id = old.workflow_id;
-                END;
-                """
-            )
-
-            self._cursor.execute(
-                """--sql
-                CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);
-                """
-            )
-            self._cursor.execute(
-                """--sql
-                CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);
-                """
-            )
-            self._cursor.execute(
-                """--sql
-                CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);
-                """
-            )
-            self._cursor.execute(
-                """--sql
-                CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);
-                """
-            )
-            self._cursor.execute(
-                """--sql
-                CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);
-                """
-            )
-            self._cursor.execute(
-                """--sql
-                CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);
-                """
-            )
-
-            # We do not need the original `workflows` table or `workflow_images` junction table.
-            self._cursor.execute(
-                """--sql
-                DROP TABLE IF EXISTS workflow_images;
-                """
-            )
-            self._cursor.execute(
-                """--sql
-                DROP TABLE IF EXISTS workflows;
-                """
-            )
-
-            self._conn.commit()
-        except Exception:
-            self._conn.rollback()
-            raise
-        finally:
-            self._lock.release()
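
The generated `category`, `name`, and `description` columns in the schema above (now created by the migration rather than by this service) exist so workflows can be filtered and sorted without parsing the workflow JSON at query time. A minimal sketch of the kind of query they enable; the helper name and query shape are assumptions for illustration, not the service's actual code:

```py
import sqlite3


def search_workflows(cursor: sqlite3.Cursor, category: str, term: str) -> list[tuple[str, str, str]]:
    # Hypothetical helper: the generated columns let SQLite filter, sort and
    # index on values extracted from the workflow JSON.
    cursor.execute(
        """--sql
        SELECT workflow_id, name, description
        FROM workflow_library
        WHERE category = ? AND name LIKE ?
        ORDER BY opened_at DESC;
        """,
        (category, f"%{term}%"),
    )
    return cursor.fetchall()
```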
@@ -49,7 +49,8 @@ class MigrateModelYamlToDb:

     def get_db(self) -> ModelRecordServiceSQL:
         """Fetch the sqlite3 database for this installation."""
-        db = SqliteDatabase(self.config, self.logger)
+        db_path = None if self.config.use_memory_db else self.config.db_path
+        db = SqliteDatabase(db_path=db_path, logger=self.logger, verbose=self.config.log_sql)
         return ModelRecordServiceSQL(db)

     def get_yaml(self) -> DictConfig:
@@ -28,8 +28,8 @@ from invokeai.app.services.shared.graph import (
     IterateInvocation,
     LibraryGraph,
 )
-from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 from invokeai.backend.util.logging import InvokeAILogger
+from tests.fixtures.sqlite_database import create_mock_sqlite_database

 from .test_invoker import create_edge
@@ -49,7 +49,8 @@ def simple_graph():
 @pytest.fixture
 def mock_services() -> InvocationServices:
     configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
-    db = SqliteDatabase(configuration, InvokeAILogger.get_logger())
+    logger = InvokeAILogger.get_logger()
+    db = create_mock_sqlite_database(configuration, logger)
     # NOTE: none of these are actually called by the test invocations
     graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
     return InvocationServices(
@@ -4,6 +4,7 @@ import pytest

 from invokeai.app.services.config.config_default import InvokeAIAppConfig
 from invokeai.backend.util.logging import InvokeAILogger
+from tests.fixtures.sqlite_database import create_mock_sqlite_database

 # This import must happen before other invoke imports or test in other files(!!) break
 from .test_nodes import (  # isort: split
@@ -24,7 +25,6 @@ from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
 from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
 from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
-from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase


 @pytest.fixture
@@ -52,8 +52,9 @@ def graph_with_subgraph():
 # the test invocations.
 @pytest.fixture
 def mock_services() -> InvocationServices:
-    db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
+    configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
+    logger = InvokeAILogger.get_logger()
+    db = create_mock_sqlite_database(configuration, logger)

     # NOTE: none of these are actually called by the test invocations
     graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
@@ -15,8 +15,11 @@ class TestModel(BaseModel):

 @pytest.fixture
 def db() -> SqliteItemStorage[TestModel]:
-    sqlite_db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
-    sqlite_item_storage = SqliteItemStorage[TestModel](db=sqlite_db, table_name="test", id_field="id")
+    config = InvokeAIAppConfig(use_memory_db=True)
+    logger = InvokeAILogger.get_logger()
+    db_path = None if config.use_memory_db else config.db_path
+    db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
+    sqlite_item_storage = SqliteItemStorage[TestModel](db=db, table_name="test", id_field="id")
     return sqlite_item_storage
@@ -18,9 +18,9 @@ from invokeai.app.services.model_install import (
     ModelInstallServiceBase,
 )
 from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
-from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 from invokeai.backend.model_manager.config import BaseModelType, ModelType
 from invokeai.backend.util.logging import InvokeAILogger
+from tests.fixtures.sqlite_database import create_mock_sqlite_database


 @pytest.fixture
@@ -37,9 +37,12 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:


 @pytest.fixture
-def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
-    database = SqliteDatabase(app_config, InvokeAILogger.get_logger(config=app_config))
-    store: ModelRecordServiceBase = ModelRecordServiceSQL(database)
+def store(
+    app_config: InvokeAIAppConfig,
+) -> ModelRecordServiceBase:
+    logger = InvokeAILogger.get_logger(config=app_config)
+    db = create_mock_sqlite_database(app_config, logger)
+    store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
     return store
@@ -3,6 +3,7 @@ Test the refactored model config classes.
 """

 from hashlib import sha256
+from typing import Any

 import pytest
@@ -13,7 +14,6 @@ from invokeai.app.services.model_records import (
     ModelRecordServiceSQL,
     UnknownModelException,
 )
-from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 from invokeai.backend.model_manager.config import (
     BaseModelType,
     MainCheckpointConfig,
@@ -23,13 +23,16 @@ from invokeai.backend.model_manager.config import (
     VaeDiffusersConfig,
 )
 from invokeai.backend.util.logging import InvokeAILogger
+from tests.fixtures.sqlite_database import create_mock_sqlite_database


 @pytest.fixture
-def store(datadir) -> ModelRecordServiceBase:
+def store(
+    datadir: Any,
+) -> ModelRecordServiceBase:
     config = InvokeAIAppConfig(root=datadir)
     logger = InvokeAILogger.get_logger(config=config)
-    db = SqliteDatabase(config, logger)
+    db = create_mock_sqlite_database(config, logger)
     return ModelRecordServiceSQL(db)
tests/fixtures/__init__.py (new file, 0 lines, vendored)
tests/fixtures/sqlite_database.py (new file, 13 lines, vendored)

@@ -0,0 +1,13 @@
from logging import Logger
from unittest import mock

from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_util import init_db


def create_mock_sqlite_database(config: InvokeAIAppConfig, logger: Logger) -> SqliteDatabase:
    image_files = mock.Mock(spec=ImageFileStorageBase)
    db = init_db(config=config, logger=logger, image_files=image_files)
    return db
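
A sketch of how a test might consume this helper; the test body is an assumption for illustration (it relies on `init_db` running all registered migrations, as the database initialization described in this PR does), not one of the tests added here:

```py
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database


def test_mock_db_is_migrated() -> None:
    config = InvokeAIAppConfig(use_memory_db=True)
    logger = InvokeAILogger.get_logger()
    db = create_mock_sqlite_database(config, logger)
    # If init_db ran the migrations, the migrations table exists and has a version row
    cursor = db.conn.cursor()
    cursor.execute("SELECT MAX(version) FROM migrations;")
    assert cursor.fetchone()[0] >= 1
```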
tests/test_sqlite_migrator.py (new file, 272 lines)

@@ -0,0 +1,272 @@
import sqlite3
from contextlib import closing
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory

import pytest
from pydantic import ValidationError

from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
    MigrateCallback,
    Migration,
    MigrationError,
    MigrationSet,
    MigrationVersionError,
)
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import (
    SqliteMigrator,
)


@pytest.fixture
def logger() -> Logger:
    return Logger("test_sqlite_migrator")


@pytest.fixture
def memory_db_conn() -> sqlite3.Connection:
    return sqlite3.connect(":memory:")


@pytest.fixture
def memory_db_cursor(memory_db_conn: sqlite3.Connection) -> sqlite3.Cursor:
    return memory_db_conn.cursor()


@pytest.fixture
def migrator(logger: Logger) -> SqliteMigrator:
    db = SqliteDatabase(db_path=None, logger=logger, verbose=False)
    return SqliteMigrator(db=db)


@pytest.fixture
def no_op_migrate_callback() -> MigrateCallback:
    def no_op_migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
        pass

    return no_op_migrate


@pytest.fixture
def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
    return Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)


@pytest.fixture
def migrate_callback_create_table_of_name() -> MigrateCallback:
    def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
        table_name = kwargs["table_name"]
        cursor.execute(f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY);")

    return migrate


@pytest.fixture
def migrate_callback_create_test_table() -> MigrateCallback:
    def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
        cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")

    return migrate


@pytest.fixture
def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration:
    return Migration(from_version=0, to_version=1, callback=migrate_callback_create_test_table)


@pytest.fixture
def failing_migration() -> Migration:
    def failing_migration(cursor: sqlite3.Cursor, **kwargs) -> None:
        raise Exception("Bad migration")

    return Migration(from_version=0, to_version=1, callback=failing_migration)


@pytest.fixture
def failing_migrate_callback() -> MigrateCallback:
    def failing_migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
        raise Exception("Bad migration")

    return failing_migrate


def create_migrate(i: int) -> MigrateCallback:
    def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
        cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")

    return migrate


def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
    with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
        Migration(from_version=0, to_version=2, callback=no_op_migrate_callback)
    # not raising is sufficient
    Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)


def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None:
    migration = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
    assert hash(migration) == hash((0, 1))


def test_migration_set_add_migration(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
    migration = migration_no_op
    migrator._migration_set.register(migration)
    assert migration in migrator._migration_set._migrations


def test_migration_set_may_not_register_dupes(
    migrator: SqliteMigrator, no_op_migrate_callback: MigrateCallback
) -> None:
    migrate_0_to_1_a = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
    migrate_0_to_1_b = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
    migrator._migration_set.register(migrate_0_to_1_a)
    with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
        migrator._migration_set.register(migrate_0_to_1_b)
    migrate_1_to_2_a = Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
    migrate_1_to_2_b = Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
    migrator._migration_set.register(migrate_1_to_2_a)
    with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
        migrator._migration_set.register(migrate_1_to_2_b)


def test_migration_set_gets_migration(migration_no_op: Migration) -> None:
    migration_set = MigrationSet()
    migration_set.register(migration_no_op)
    assert migration_set.get(0) == migration_no_op
    assert migration_set.get(1) is None


def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None:
    migration_set = MigrationSet()
    migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
    with pytest.raises(MigrationError, match="Migration chain is fragmented"):
        # no migration from 0 to 1
        migration_set.validate_migration_chain()
    migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
    migration_set.validate_migration_chain()
    migration_set.register(Migration(from_version=2, to_version=3, callback=no_op_migrate_callback))
    migration_set.validate_migration_chain()
    migration_set.register(Migration(from_version=4, to_version=5, callback=no_op_migrate_callback))
    with pytest.raises(MigrationError, match="Migration chain is fragmented"):
        # no migration from 3 to 4
        migration_set.validate_migration_chain()


def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None:
    migration_set = MigrationSet()
    assert migration_set.count == 0
    migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
    assert migration_set.count == 1
    migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
    assert migration_set.count == 2


def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None:
    migration_set = MigrationSet()
    assert migration_set.latest_version == 0
    migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
    assert migration_set.latest_version == 2
    migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
    assert migration_set.latest_version == 2


def test_migration_runs(memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback) -> None:
    migration = Migration(
        from_version=0,
        to_version=1,
        callback=migrate_callback_create_test_table,
    )
    migration.callback(memory_db_cursor)
    memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
    assert memory_db_cursor.fetchone() is not None


def test_migrator_registers_migration(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
    migration = migration_no_op
    migrator.register_migration(migration)
    assert migration in migrator._migration_set._migrations


def test_migrator_creates_migrations_table(migrator: SqliteMigrator) -> None:
    cursor = migrator._db.conn.cursor()
    migrator._create_migrations_table(cursor)
    cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';")
    assert cursor.fetchone() is not None


def test_migrator_migration_sets_version(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
    cursor = migrator._db.conn.cursor()
    migrator._create_migrations_table(cursor)
    migrator.register_migration(migration_no_op)
    migrator.run_migrations()
    cursor.execute("SELECT MAX(version) FROM migrations;")
    assert cursor.fetchone()[0] == 1


def test_migrator_gets_current_version(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
    cursor = migrator._db.conn.cursor()
    assert migrator._get_current_version(cursor) == 0
    migrator._create_migrations_table(cursor)
    assert migrator._get_current_version(cursor) == 0
    migrator.register_migration(migration_no_op)
    migrator.run_migrations()
    assert migrator._get_current_version(cursor) == 1


def test_migrator_runs_single_migration(migrator: SqliteMigrator, migration_create_test_table: Migration) -> None:
    cursor = migrator._db.conn.cursor()
    migrator._create_migrations_table(cursor)
    migrator._run_migration(migration_create_test_table)
    assert migrator._get_current_version(cursor) == 1
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
    assert cursor.fetchone() is not None


def test_migrator_runs_all_migrations_in_memory(migrator: SqliteMigrator) -> None:
    cursor = migrator._db.conn.cursor()
    migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
    for migration in migrations:
        migrator.register_migration(migration)
    migrator.run_migrations()
    assert migrator._get_current_version(cursor) == 3


def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
    with TemporaryDirectory() as tempdir:
        original_db_path = Path(tempdir) / "invokeai.db"
        db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
        migrator = SqliteMigrator(db=db)
        migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
        for migration in migrations:
            migrator.register_migration(migration)
        migrator.run_migrations()
        with closing(sqlite3.connect(original_db_path)) as original_db_conn:
            original_db_cursor = original_db_conn.cursor()
            assert SqliteMigrator._get_current_version(original_db_cursor) == 3
        # Must manually close else we get an error on Windows
        db.conn.close()


def test_migrator_makes_no_changes_on_failed_migration(
    migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None:
    cursor = migrator._db.conn.cursor()
    migrator.register_migration(migration_no_op)
    migrator.run_migrations()
    assert migrator._get_current_version(cursor) == 1
    migrator.register_migration(Migration(from_version=1, to_version=2, callback=failing_migrate_callback))
    with pytest.raises(MigrationError, match="Bad migration"):
        migrator.run_migrations()
    assert migrator._get_current_version(cursor) == 1


def test_idempotent_migrations(migrator: SqliteMigrator, migration_create_test_table: Migration) -> None:
    cursor = migrator._db.conn.cursor()
    migrator.register_migration(migration_create_test_table)
    migrator.run_migrations()
    # not throwing is sufficient
    migrator.run_migrations()
    assert migrator._get_current_version(cursor) == 1
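
The new suite can be run on its own with `pytest tests/test_sqlite_migrator.py`.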