feat(db): decouple SqliteDatabase from config object

- Simplify init args to path (None means use memory), logger, and verbose
- Add docstrings to SqliteDatabase (it had almost none)
- Update all usages of the class
This commit is contained in:
psychedelicious 2023-12-12 10:29:46 +11:00
parent afe4e55bf9
commit 417db71471
8 changed files with 80 additions and 47 deletions

View File

@ -2,7 +2,6 @@
from functools import partial from functools import partial
from logging import Logger from logging import Logger
from pathlib import Path
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
@ -75,10 +74,11 @@ class ApiDependencies:
output_folder = config.output_path output_folder = config.output_path
image_files = DiskImageFileStorage(f"{output_folder}/images") image_files = DiskImageFileStorage(f"{output_folder}/images")
db = SqliteDatabase(config, logger) db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SQLiteMigrator( migrator = SQLiteMigrator(
db_path=db.database if isinstance(db.database, Path) else None, db_path=db.db_path,
conn=db.conn, conn=db.conn,
lock=db.lock, lock=db.lock,
logger=logger, logger=logger,
@ -89,7 +89,10 @@ class ApiDependencies:
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
did_migrate = migrator.run_migrations() did_migrate = migrator.run_migrations()
if not db.is_memory and did_migrate: # We need to reinitialize the database if we migrated, but only if we are using a file database.
# This closes the SqliteDatabase's connection and re-runs its `__init__` logic.
# If we do this with a memory database, we wipe the db.
if not db.db_path and did_migrate:
db.reinitialize() db.reinitialize()
configuration = config configuration = config

View File

@ -3,58 +3,83 @@ import threading
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
class SqliteDatabase: class SqliteDatabase:
database: Path | str # Must declare this here to satisfy type checker """
Manages a connection to an SQLite database.
def __init__(self, config: InvokeAIAppConfig, logger: Logger) -> None: This is a light wrapper around the `sqlite3` module, providing a few conveniences:
self.initialize(config, logger) - The database file is written to disk if it does not exist.
- Foreign key constraints are enabled by default.
- The connection is configured to use the `sqlite3.Row` row factory.
- A `conn` attribute is provided to access the connection.
- A `lock` attribute is provided to lock the database connection.
- A `clean` method to run the VACUUM command and report on the freed space.
- A `reinitialize` method to close the connection and re-run the init.
- A `close` method to close the connection.
def initialize(self, config: InvokeAIAppConfig, logger: Logger) -> None: :param db_path: Path to the database file. If None, an in-memory database is used.
self._logger = logger :param logger: Logger to use for logging.
self._config = config :param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
self.is_memory = False """
if self._config.use_memory_db:
self.database = sqlite_memory def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
self.is_memory = True self.initialize(db_path=db_path, logger=logger, verbose=verbose)
def initialize(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
self.logger = logger
self.db_path = db_path
self.verbose = verbose
if not self.db_path:
logger.info("Initializing in-memory database") logger.info("Initializing in-memory database")
else: else:
self.database = self._config.db_path self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.database.parent.mkdir(parents=True, exist_ok=True) self.logger.info(f"Initializing database at {self.db_path}")
self._logger.info(f"Initializing database at {self.database}")
self.conn = sqlite3.connect(database=self.database, check_same_thread=False) self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.lock = threading.RLock() self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row self.conn.row_factory = sqlite3.Row
if self._config.log_sql: if self.verbose:
self.conn.set_trace_callback(self._logger.debug) self.conn.set_trace_callback(self.logger.debug)
self.conn.execute("PRAGMA foreign_keys = ON;") self.conn.execute("PRAGMA foreign_keys = ON;")
def reinitialize(self) -> None: def reinitialize(self) -> None:
"""Reinitializes the database. Needed after migration.""" """
Re-initializes the database by closing the connection and re-running the init.
Warning: This will wipe the database if it is an in-memory database.
"""
self.close() self.close()
self.initialize(self._config, self._logger) self.initialize(db_path=self.db_path, logger=self.logger, verbose=self.verbose)
def close(self) -> None: def close(self) -> None:
"""
Closes the connection to the database.
Warning: This will wipe the database if it is an in-memory database.
"""
self.conn.close() self.conn.close()
def clean(self) -> None: def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
"""
# No need to clean in-memory database
if not self.db_path:
return
with self.lock: with self.lock:
try: try:
if self.database == sqlite_memory: initial_db_size = Path(self.db_path).stat().st_size
return
initial_db_size = Path(self.database).stat().st_size
self.conn.execute("VACUUM;") self.conn.execute("VACUUM;")
self.conn.commit() self.conn.commit()
final_db_size = Path(self.database).stat().st_size final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2) freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0: if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)") self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e: except Exception as e:
self._logger.error(f"Error cleaning database: {e}") self.logger.error(f"Error cleaning database: {e}")
raise raise

View File

@ -47,7 +47,8 @@ class MigrateModelYamlToDb:
def get_db(self) -> ModelRecordServiceSQL: def get_db(self) -> ModelRecordServiceSQL:
"""Fetch the sqlite3 database for this installation.""" """Fetch the sqlite3 database for this installation."""
db = SqliteDatabase(self.config, self.logger) db_path = None if self.config.use_memory_db else self.config.db_path
db = SqliteDatabase(db_path=db_path, logger=self.logger, verbose=self.config.log_sql)
return ModelRecordServiceSQL(db) return ModelRecordServiceSQL(db)
def get_yaml(self) -> DictConfig: def get_yaml(self) -> DictConfig:

View File

@ -1,5 +1,4 @@
import logging import logging
from pathlib import Path
import pytest import pytest
@ -54,9 +53,10 @@ def simple_graph():
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = SqliteDatabase(configuration, logger) db_path = None if configuration.use_memory_db else configuration.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
migrator = SQLiteMigrator( migrator = SQLiteMigrator(
db_path=db.database if isinstance(db.database, Path) else None, db_path=db.db_path,
conn=db.conn, conn=db.conn,
lock=db.lock, lock=db.lock,
logger=logger, logger=logger,

View File

@ -1,5 +1,4 @@
import logging import logging
from pathlib import Path
import pytest import pytest
@ -58,9 +57,10 @@ def graph_with_subgraph():
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = SqliteDatabase(configuration, logger) db_path = None if configuration.use_memory_db else configuration.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
migrator = SQLiteMigrator( migrator = SQLiteMigrator(
db_path=db.database if isinstance(db.database, Path) else None, db_path=db.db_path,
conn=db.conn, conn=db.conn,
lock=db.lock, lock=db.lock,
logger=logger, logger=logger,

View File

@ -15,8 +15,11 @@ class TestModel(BaseModel):
@pytest.fixture @pytest.fixture
def db() -> SqliteItemStorage[TestModel]: def db() -> SqliteItemStorage[TestModel]:
sqlite_db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger()) config = InvokeAIAppConfig(use_memory_db=True)
sqlite_item_storage = SqliteItemStorage[TestModel](db=sqlite_db, table_name="test", id_field="id") logger = InvokeAILogger.get_logger()
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
sqlite_item_storage = SqliteItemStorage[TestModel](db=db, table_name="test", id_field="id")
return sqlite_item_storage return sqlite_item_storage

View File

@ -42,11 +42,12 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
@pytest.fixture @pytest.fixture
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=app_config) logger = InvokeAILogger.get_logger(config=app_config)
database = SqliteDatabase(app_config, logger) db_path = None if app_config.use_memory_db else app_config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql)
migrator = SQLiteMigrator( migrator = SQLiteMigrator(
db_path=database.database if isinstance(database.database, Path) else None, db_path=db.db_path,
conn=database.conn, conn=db.conn,
lock=database.lock, lock=db.lock,
logger=logger, logger=logger,
log_sql=app_config.log_sql, log_sql=app_config.log_sql,
) )
@ -54,8 +55,8 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()
# this test uses a file database, so we need to reinitialize it after migrations # this test uses a file database, so we need to reinitialize it after migrations
database.reinitialize() db.reinitialize()
store: ModelRecordServiceBase = ModelRecordServiceSQL(database) store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store return store

View File

@ -3,7 +3,6 @@ Test the refactored model config classes.
""" """
from hashlib import sha256 from hashlib import sha256
from pathlib import Path
from typing import Any from typing import Any
import pytest import pytest
@ -34,9 +33,10 @@ from invokeai.backend.util.logging import InvokeAILogger
def store(datadir: Any) -> ModelRecordServiceBase: def store(datadir: Any) -> ModelRecordServiceBase:
config = InvokeAIAppConfig(root=datadir) config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger) db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SQLiteMigrator( migrator = SQLiteMigrator(
db_path=db.database if isinstance(db.database, Path) else None, db_path=db.db_path,
conn=db.conn, conn=db.conn,
lock=db.lock, lock=db.lock,
logger=logger, logger=logger,