mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): decouple SqliteDatabase from config object
- Simplify init args to path (None means use memory), logger, and verbose - Add docstrings to SqliteDatabase (it had almost none) - Update all usages of the class
This commit is contained in:
parent
afe4e55bf9
commit
417db71471
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
|
||||||
@ -75,10 +74,11 @@ class ApiDependencies:
|
|||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
|
||||||
db = SqliteDatabase(config, logger)
|
db_path = None if config.use_memory_db else config.db_path
|
||||||
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
||||||
|
|
||||||
migrator = SQLiteMigrator(
|
migrator = SQLiteMigrator(
|
||||||
db_path=db.database if isinstance(db.database, Path) else None,
|
db_path=db.db_path,
|
||||||
conn=db.conn,
|
conn=db.conn,
|
||||||
lock=db.lock,
|
lock=db.lock,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
@ -89,7 +89,10 @@ class ApiDependencies:
|
|||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
did_migrate = migrator.run_migrations()
|
did_migrate = migrator.run_migrations()
|
||||||
|
|
||||||
if not db.is_memory and did_migrate:
|
# We need to reinitialize the database if we migrated, but only if we are using a file database.
|
||||||
|
# This closes the SqliteDatabase's connection and re-runs its `__init__` logic.
|
||||||
|
# If we do this with a memory database, we wipe the db.
|
||||||
|
if not db.db_path and did_migrate:
|
||||||
db.reinitialize()
|
db.reinitialize()
|
||||||
|
|
||||||
configuration = config
|
configuration = config
|
||||||
|
@ -3,58 +3,83 @@ import threading
|
|||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||||
|
|
||||||
|
|
||||||
class SqliteDatabase:
|
class SqliteDatabase:
|
||||||
database: Path | str # Must declare this here to satisfy type checker
|
"""
|
||||||
|
Manages a connection to an SQLite database.
|
||||||
|
|
||||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger) -> None:
|
This is a light wrapper around the `sqlite3` module, providing a few conveniences:
|
||||||
self.initialize(config, logger)
|
- The database file is written to disk if it does not exist.
|
||||||
|
- Foreign key constraints are enabled by default.
|
||||||
|
- The connection is configured to use the `sqlite3.Row` row factory.
|
||||||
|
- A `conn` attribute is provided to access the connection.
|
||||||
|
- A `lock` attribute is provided to lock the database connection.
|
||||||
|
- A `clean` method to run the VACUUM command and report on the freed space.
|
||||||
|
- A `reinitialize` method to close the connection and re-run the init.
|
||||||
|
- A `close` method to close the connection.
|
||||||
|
|
||||||
def initialize(self, config: InvokeAIAppConfig, logger: Logger) -> None:
|
:param db_path: Path to the database file. If None, an in-memory database is used.
|
||||||
self._logger = logger
|
:param logger: Logger to use for logging.
|
||||||
self._config = config
|
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
|
||||||
self.is_memory = False
|
"""
|
||||||
if self._config.use_memory_db:
|
|
||||||
self.database = sqlite_memory
|
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
||||||
self.is_memory = True
|
self.initialize(db_path=db_path, logger=logger, verbose=verbose)
|
||||||
|
|
||||||
|
def initialize(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
||||||
|
"""Initializes the database. This is used internally by the class constructor."""
|
||||||
|
self.logger = logger
|
||||||
|
self.db_path = db_path
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
if not self.db_path:
|
||||||
logger.info("Initializing in-memory database")
|
logger.info("Initializing in-memory database")
|
||||||
else:
|
else:
|
||||||
self.database = self._config.db_path
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
self.database.parent.mkdir(parents=True, exist_ok=True)
|
self.logger.info(f"Initializing database at {self.db_path}")
|
||||||
self._logger.info(f"Initializing database at {self.database}")
|
|
||||||
|
|
||||||
self.conn = sqlite3.connect(database=self.database, check_same_thread=False)
|
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
|
||||||
self.lock = threading.RLock()
|
self.lock = threading.RLock()
|
||||||
self.conn.row_factory = sqlite3.Row
|
self.conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
if self._config.log_sql:
|
if self.verbose:
|
||||||
self.conn.set_trace_callback(self._logger.debug)
|
self.conn.set_trace_callback(self.logger.debug)
|
||||||
|
|
||||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
|
||||||
def reinitialize(self) -> None:
|
def reinitialize(self) -> None:
|
||||||
"""Reinitializes the database. Needed after migration."""
|
"""
|
||||||
|
Re-initializes the database by closing the connection and re-running the init.
|
||||||
|
Warning: This will wipe the database if it is an in-memory database.
|
||||||
|
"""
|
||||||
self.close()
|
self.close()
|
||||||
self.initialize(self._config, self._logger)
|
self.initialize(db_path=self.db_path, logger=self.logger, verbose=self.verbose)
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
|
"""
|
||||||
|
Closes the connection to the database.
|
||||||
|
Warning: This will wipe the database if it is an in-memory database.
|
||||||
|
"""
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
|
||||||
def clean(self) -> None:
|
def clean(self) -> None:
|
||||||
|
"""
|
||||||
|
Cleans the database by running the VACUUM command, reporting on the freed space.
|
||||||
|
"""
|
||||||
|
# No need to clean in-memory database
|
||||||
|
if not self.db_path:
|
||||||
|
return
|
||||||
with self.lock:
|
with self.lock:
|
||||||
try:
|
try:
|
||||||
if self.database == sqlite_memory:
|
initial_db_size = Path(self.db_path).stat().st_size
|
||||||
return
|
|
||||||
initial_db_size = Path(self.database).stat().st_size
|
|
||||||
self.conn.execute("VACUUM;")
|
self.conn.execute("VACUUM;")
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
final_db_size = Path(self.database).stat().st_size
|
final_db_size = Path(self.db_path).stat().st_size
|
||||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||||
if freed_space_in_mb > 0:
|
if freed_space_in_mb > 0:
|
||||||
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._logger.error(f"Error cleaning database: {e}")
|
self.logger.error(f"Error cleaning database: {e}")
|
||||||
raise
|
raise
|
||||||
|
@ -47,7 +47,8 @@ class MigrateModelYamlToDb:
|
|||||||
|
|
||||||
def get_db(self) -> ModelRecordServiceSQL:
|
def get_db(self) -> ModelRecordServiceSQL:
|
||||||
"""Fetch the sqlite3 database for this installation."""
|
"""Fetch the sqlite3 database for this installation."""
|
||||||
db = SqliteDatabase(self.config, self.logger)
|
db_path = None if self.config.use_memory_db else self.config.db_path
|
||||||
|
db = SqliteDatabase(db_path=db_path, logger=self.logger, verbose=self.config.log_sql)
|
||||||
return ModelRecordServiceSQL(db)
|
return ModelRecordServiceSQL(db)
|
||||||
|
|
||||||
def get_yaml(self) -> DictConfig:
|
def get_yaml(self) -> DictConfig:
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -54,9 +53,10 @@ def simple_graph():
|
|||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||||
logger = InvokeAILogger.get_logger()
|
logger = InvokeAILogger.get_logger()
|
||||||
db = SqliteDatabase(configuration, logger)
|
db_path = None if configuration.use_memory_db else configuration.db_path
|
||||||
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
|
||||||
migrator = SQLiteMigrator(
|
migrator = SQLiteMigrator(
|
||||||
db_path=db.database if isinstance(db.database, Path) else None,
|
db_path=db.db_path,
|
||||||
conn=db.conn,
|
conn=db.conn,
|
||||||
lock=db.lock,
|
lock=db.lock,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -58,9 +57,10 @@ def graph_with_subgraph():
|
|||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||||
logger = InvokeAILogger.get_logger()
|
logger = InvokeAILogger.get_logger()
|
||||||
db = SqliteDatabase(configuration, logger)
|
db_path = None if configuration.use_memory_db else configuration.db_path
|
||||||
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
|
||||||
migrator = SQLiteMigrator(
|
migrator = SQLiteMigrator(
|
||||||
db_path=db.database if isinstance(db.database, Path) else None,
|
db_path=db.db_path,
|
||||||
conn=db.conn,
|
conn=db.conn,
|
||||||
lock=db.lock,
|
lock=db.lock,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
@ -15,8 +15,11 @@ class TestModel(BaseModel):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db() -> SqliteItemStorage[TestModel]:
|
def db() -> SqliteItemStorage[TestModel]:
|
||||||
sqlite_db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
|
config = InvokeAIAppConfig(use_memory_db=True)
|
||||||
sqlite_item_storage = SqliteItemStorage[TestModel](db=sqlite_db, table_name="test", id_field="id")
|
logger = InvokeAILogger.get_logger()
|
||||||
|
db_path = None if config.use_memory_db else config.db_path
|
||||||
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
||||||
|
sqlite_item_storage = SqliteItemStorage[TestModel](db=db, table_name="test", id_field="id")
|
||||||
return sqlite_item_storage
|
return sqlite_item_storage
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,11 +42,12 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||||
logger = InvokeAILogger.get_logger(config=app_config)
|
logger = InvokeAILogger.get_logger(config=app_config)
|
||||||
database = SqliteDatabase(app_config, logger)
|
db_path = None if app_config.use_memory_db else app_config.db_path
|
||||||
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql)
|
||||||
migrator = SQLiteMigrator(
|
migrator = SQLiteMigrator(
|
||||||
db_path=database.database if isinstance(database.database, Path) else None,
|
db_path=db.db_path,
|
||||||
conn=database.conn,
|
conn=db.conn,
|
||||||
lock=database.lock,
|
lock=db.lock,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
log_sql=app_config.log_sql,
|
log_sql=app_config.log_sql,
|
||||||
)
|
)
|
||||||
@ -54,8 +55,8 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
|||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
# this test uses a file database, so we need to reinitialize it after migrations
|
# this test uses a file database, so we need to reinitialize it after migrations
|
||||||
database.reinitialize()
|
db.reinitialize()
|
||||||
store: ModelRecordServiceBase = ModelRecordServiceSQL(database)
|
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||||
return store
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@ Test the refactored model config classes.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -34,9 +33,10 @@ from invokeai.backend.util.logging import InvokeAILogger
|
|||||||
def store(datadir: Any) -> ModelRecordServiceBase:
|
def store(datadir: Any) -> ModelRecordServiceBase:
|
||||||
config = InvokeAIAppConfig(root=datadir)
|
config = InvokeAIAppConfig(root=datadir)
|
||||||
logger = InvokeAILogger.get_logger(config=config)
|
logger = InvokeAILogger.get_logger(config=config)
|
||||||
db = SqliteDatabase(config, logger)
|
db_path = None if config.use_memory_db else config.db_path
|
||||||
|
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
||||||
migrator = SQLiteMigrator(
|
migrator = SQLiteMigrator(
|
||||||
db_path=db.database if isinstance(db.database, Path) else None,
|
db_path=db.db_path,
|
||||||
conn=db.conn,
|
conn=db.conn,
|
||||||
lock=db.lock,
|
lock=db.lock,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
Loading…
Reference in New Issue
Block a user