Fix tests

This commit is contained in:
Brandon Rising 2023-08-16 14:35:49 -04:00
parent ef8dc2e8c5
commit d6a5c2dbe3
4 changed files with 27 additions and 27 deletions

View File

@ -16,7 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
) )
from invokeai.app.services.graph import Graph from invokeai.app.services.graph import Graph
from invokeai.app.models.image import ImageField from invokeai.app.invocations.primitives import ImageField
from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat

View File

@ -25,7 +25,7 @@ from invokeai.app.services.graph import (
LibraryGraph, LibraryGraph,
) )
import pytest import pytest
import sqlite3
@pytest.fixture @pytest.fixture
def simple_graph(): def simple_graph():
@ -42,8 +42,9 @@ def simple_graph():
@pytest.fixture @pytest.fixture
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=sqlite_memory, table_name="graph_executions" conn=db_conn, table_name="graph_executions"
) )
return InvocationServices( return InvocationServices(
model_manager=None, # type: ignore model_manager=None, # type: ignore
@ -55,7 +56,7 @@ def mock_services() -> InvocationServices:
batch_manager=None, # type: ignore batch_manager=None, # type: ignore
board_images=None, # type: ignore board_images=None, # type: ignore
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
performance_statistics=InvocationStatsService(graph_execution_manager), performance_statistics=InvocationStatsService(graph_execution_manager),
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),

View File

@ -23,6 +23,7 @@ from invokeai.app.services.graph import (
LibraryGraph, LibraryGraph,
) )
import pytest import pytest
import sqlite3
@pytest.fixture @pytest.fixture
@ -87,10 +88,11 @@ def simple_batches():
@pytest.fixture @pytest.fixture
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=sqlite_memory, table_name="graph_executions" conn=db_conn, table_name="graph_executions"
) )
batch_manager_storage = SqliteBatchProcessStorage(sqlite_memory) batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
return InvocationServices( return InvocationServices(
model_manager=None, # type: ignore model_manager=None, # type: ignore
events=TestEventService(), events=TestEventService(),
@ -101,7 +103,7 @@ def mock_services() -> InvocationServices:
boards=None, # type: ignore boards=None, # type: ignore
board_images=None, # type: ignore board_images=None, # type: ignore
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager), performance_statistics=InvocationStatsService(graph_execution_manager),

View File

@ -1,20 +1,25 @@
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import pytest
import sqlite3
class TestModel(BaseModel): class TestModel(BaseModel):
id: str = Field(description="ID") id: str = Field(description="ID")
name: str = Field(description="Name") name: str = Field(description="Name")
def test_sqlite_service_can_create_and_get(): @pytest.fixture
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") def db() -> SqliteItemStorage[TestModel]:
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
return SqliteItemStorage[TestModel](db_conn, "test", "id")
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="1", name="Test"))
assert db.get("1") == TestModel(id="1", name="Test") assert db.get("1") == TestModel(id="1", name="Test")
def test_sqlite_service_can_list(): def test_sqlite_service_can_list(db: SqliteItemStorage[TestModel]):
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test")) db.set(TestModel(id="3", name="Test"))
@ -30,15 +35,13 @@ def test_sqlite_service_can_list():
] ]
def test_sqlite_service_can_delete(): def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="1", name="Test"))
db.delete("1") db.delete("1")
assert db.get("1") is None assert db.get("1") is None
def test_sqlite_service_calls_set_callback(): def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
called = False called = False
def on_changed(item: TestModel): def on_changed(item: TestModel):
@ -50,8 +53,7 @@ def test_sqlite_service_calls_set_callback():
assert called assert called
def test_sqlite_service_calls_delete_callback(): def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
called = False called = False
def on_deleted(item_id: str): def on_deleted(item_id: str):
@ -64,8 +66,7 @@ def test_sqlite_service_calls_delete_callback():
assert called assert called
def test_sqlite_service_can_list_with_pagination(): def test_sqlite_service_can_list_with_pagination(db: SqliteItemStorage[TestModel]):
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test")) db.set(TestModel(id="3", name="Test"))
@ -77,8 +78,7 @@ def test_sqlite_service_can_list_with_pagination():
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")] assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
def test_sqlite_service_can_list_with_pagination_and_offset(): def test_sqlite_service_can_list_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test")) db.set(TestModel(id="3", name="Test"))
@ -90,8 +90,7 @@ def test_sqlite_service_can_list_with_pagination_and_offset():
assert results.items == [TestModel(id="3", name="Test")] assert results.items == [TestModel(id="3", name="Test")]
def test_sqlite_service_can_search(): def test_sqlite_service_can_search(db: SqliteItemStorage[TestModel]):
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test")) db.set(TestModel(id="3", name="Test"))
@ -107,8 +106,7 @@ def test_sqlite_service_can_search():
] ]
def test_sqlite_service_can_search_with_pagination(): def test_sqlite_service_can_search_with_pagination(db: SqliteItemStorage[TestModel]):
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test")) db.set(TestModel(id="3", name="Test"))
@ -120,8 +118,7 @@ def test_sqlite_service_can_search_with_pagination():
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")] assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
def test_sqlite_service_can_search_with_pagination_and_offset(): def test_sqlite_service_can_search_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test")) db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test")) db.set(TestModel(id="3", name="Test"))