From d6a5c2dbe33431bafe3e2aa8fb6c43e139604fb2 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 16 Aug 2023 14:35:49 -0400 Subject: [PATCH] Fix tests --- .../app/services/batch_manager_storage.py | 2 +- tests/nodes/test_graph_execution_state.py | 7 ++-- tests/nodes/test_invoker.py | 8 ++-- tests/nodes/test_sqlite.py | 37 +++++++++---------- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py index 80411c1e5c..b307688d1d 100644 --- a/invokeai/app/services/batch_manager_storage.py +++ b/invokeai/app/services/batch_manager_storage.py @@ -16,7 +16,7 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, ) from invokeai.app.services.graph import Graph -from invokeai.app.models.image import ImageField +from invokeai.app.invocations.primitives import ImageField from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 276dee7c98..8c9a5bd19d 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -25,7 +25,7 @@ from invokeai.app.services.graph import ( LibraryGraph, ) import pytest - +import sqlite3 @pytest.fixture def simple_graph(): @@ -42,8 +42,9 @@ def simple_graph(): @pytest.fixture def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations + db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False) graph_execution_manager = SqliteItemStorage[GraphExecutionState]( - filename=sqlite_memory, table_name="graph_executions" + conn=db_conn, table_name="graph_executions" ) return InvocationServices( model_manager=None, # type: ignore @@ -55,7 +56,7 @@ def mock_services() -> InvocationServices: batch_manager=None, # type: ignore board_images=None, # type: ignore queue=MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), + graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"), graph_execution_manager=graph_execution_manager, performance_statistics=InvocationStatsService(graph_execution_manager), processor=DefaultInvocationProcessor(), diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 7e7a226023..275abc820c 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -23,6 +23,7 @@ from invokeai.app.services.graph import ( LibraryGraph, ) import pytest +import sqlite3 @pytest.fixture @@ -87,10 +88,11 @@ def simple_batches(): @pytest.fixture def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations + db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False) graph_execution_manager = SqliteItemStorage[GraphExecutionState]( - filename=sqlite_memory, table_name="graph_executions" + conn=db_conn, table_name="graph_executions" ) - batch_manager_storage = SqliteBatchProcessStorage(sqlite_memory) + batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn) return InvocationServices( model_manager=None, # type: ignore events=TestEventService(), @@ -101,7 +103,7 @@ def mock_services() -> InvocationServices: boards=None, # type: ignore board_images=None, # type: ignore queue=MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), + graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"), graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), performance_statistics=InvocationStatsService(graph_execution_manager), diff --git a/tests/nodes/test_sqlite.py b/tests/nodes/test_sqlite.py index a9eb542e44..5ea33674df 100644 --- a/tests/nodes/test_sqlite.py +++ b/tests/nodes/test_sqlite.py @@ -1,20 +1,25 @@ from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from pydantic import BaseModel, Field +import pytest +import sqlite3 class TestModel(BaseModel): id: str = Field(description="ID") name: str = Field(description="Name") -def test_sqlite_service_can_create_and_get(): - db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") +@pytest.fixture +def db() -> SqliteItemStorage[TestModel]: + db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False) + return SqliteItemStorage[TestModel](db_conn, "test", "id") + +def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]): db.set(TestModel(id="1", name="Test")) assert db.get("1") == TestModel(id="1", name="Test") -def test_sqlite_service_can_list(): - db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") +def test_sqlite_service_can_list(db: SqliteItemStorage[TestModel]): db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="3", name="Test")) @@ -30,15 +35,13 @@ def test_sqlite_service_can_list(): ] -def test_sqlite_service_can_delete(): - db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") +def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]): db.set(TestModel(id="1", name="Test")) db.delete("1") assert db.get("1") is None -def test_sqlite_service_calls_set_callback(): - db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") +def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]): called = False def on_changed(item: TestModel): @@ -50,8 +53,7 @@ def test_sqlite_service_calls_set_callback(): assert called -def test_sqlite_service_calls_delete_callback(): - db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") +def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]): called = False def on_deleted(item_id: str): @@ -64,8 +66,7 @@ def test_sqlite_service_calls_delete_callback(): assert called -def test_sqlite_service_can_list_with_pagination(): - db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") +def test_sqlite_service_can_list_with_pagination(db: SqliteItemStorage[TestModel]): db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="3", name="Test")) @@ -77,8 +78,7 @@ def test_sqlite_service_can_list_with_pagination(): assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")] -def test_sqlite_service_can_list_with_pagination_and_offset(): - db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") +def test_sqlite_service_can_list_with_pagination_and_offset(db: SqliteItemStorage[TestModel]): db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="3", name="Test")) @@ -90,8 +90,7 @@ def test_sqlite_service_can_list_with_pagination_and_offset(): assert results.items == [TestModel(id="3", name="Test")] -def test_sqlite_service_can_search(): - db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") +def test_sqlite_service_can_search(db: SqliteItemStorage[TestModel]): db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="3", name="Test")) @@ -107,8 +106,7 @@ def test_sqlite_service_can_search(): ] -def test_sqlite_service_can_search_with_pagination(): - db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") +def test_sqlite_service_can_search_with_pagination(db: SqliteItemStorage[TestModel]): db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="3", name="Test")) @@ -120,8 +118,7 @@ def test_sqlite_service_can_search_with_pagination(): assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")] -def test_sqlite_service_can_search_with_pagination_and_offset(): - db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") +def test_sqlite_service_can_search_with_pagination_and_offset(db: SqliteItemStorage[TestModel]): db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="3", name="Test"))