fix(tests): add sqlite migrator to test fixtures

This commit is contained in:
psychedelicious 2023-12-11 14:21:14 +11:00
parent 4f3c32a2ee
commit 26ab917021
3 changed files with 48 additions and 2 deletions

View File

@ -1,4 +1,5 @@
import logging
from pathlib import Path
import pytest
@ -28,7 +29,10 @@ from invokeai.app.services.shared.graph import (
IterateInvocation,
LibraryGraph,
)
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import SQLiteMigrator
from invokeai.backend.util.logging import InvokeAILogger
from .test_invoker import create_edge
@ -49,7 +53,18 @@ def simple_graph():
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
db = SqliteDatabase(configuration, InvokeAILogger.get_logger())
logger = InvokeAILogger.get_logger()
db = SqliteDatabase(configuration, logger)
migrator = SQLiteMigrator(
db_path=db.database if isinstance(db.database, Path) else None,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=configuration.log_sql,
)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
return InvocationServices(

View File

@ -1,8 +1,10 @@
import logging
from pathlib import Path
import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite.sqlite_migrator import SQLiteMigrator
from invokeai.backend.util.logging import InvokeAILogger
# This import must happen before other invoke imports or test in other files(!!) break
@ -24,6 +26,8 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@ -52,8 +56,19 @@ def graph_with_subgraph():
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = SqliteDatabase(configuration, logger)
migrator = SQLiteMigrator(
db_path=db.database if isinstance(db.database, Path) else None,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=configuration.log_sql,
)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")

View File

@ -3,6 +3,7 @@ Test the refactored model config classes.
"""
from hashlib import sha256
from pathlib import Path
import pytest
@ -13,7 +14,10 @@ from invokeai.app.services.model_records import (
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import SQLiteMigrator
from invokeai.backend.model_manager.config import (
BaseModelType,
MainCheckpointConfig,
@ -30,6 +34,18 @@ def store(datadir) -> ModelRecordServiceBase:
config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger)
migrator = SQLiteMigrator(
db_path=db.database if isinstance(db.database, Path) else None,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=config.log_sql,
)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
# this test uses a file database, so we need to reinitialize it after migrations
db.reinitialize()
return ModelRecordServiceSQL(db)