mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix tests
This commit is contained in:
parent
ef8dc2e8c5
commit
d6a5c2dbe3
@ -16,7 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.graph import Graph
|
from invokeai.app.services.graph import Graph
|
||||||
from invokeai.app.models.image import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat
|
from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ from invokeai.app.services.graph import (
|
|||||||
LibraryGraph,
|
LibraryGraph,
|
||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def simple_graph():
|
def simple_graph():
|
||||||
@ -42,8 +42,9 @@ def simple_graph():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
|
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
filename=sqlite_memory, table_name="graph_executions"
|
conn=db_conn, table_name="graph_executions"
|
||||||
)
|
)
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
model_manager=None, # type: ignore
|
model_manager=None, # type: ignore
|
||||||
@ -55,7 +56,7 @@ def mock_services() -> InvocationServices:
|
|||||||
batch_manager=None, # type: ignore
|
batch_manager=None, # type: ignore
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
|
@ -23,6 +23,7 @@ from invokeai.app.services.graph import (
|
|||||||
LibraryGraph,
|
LibraryGraph,
|
||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -87,10 +88,11 @@ def simple_batches():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
|
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
filename=sqlite_memory, table_name="graph_executions"
|
conn=db_conn, table_name="graph_executions"
|
||||||
)
|
)
|
||||||
batch_manager_storage = SqliteBatchProcessStorage(sqlite_memory)
|
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
model_manager=None, # type: ignore
|
model_manager=None, # type: ignore
|
||||||
events=TestEventService(),
|
events=TestEventService(),
|
||||||
@ -101,7 +103,7 @@ def mock_services() -> InvocationServices:
|
|||||||
boards=None, # type: ignore
|
boards=None, # type: ignore
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
|
@ -1,20 +1,25 @@
|
|||||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
class TestModel(BaseModel):
|
class TestModel(BaseModel):
|
||||||
id: str = Field(description="ID")
|
id: str = Field(description="ID")
|
||||||
name: str = Field(description="Name")
|
name: str = Field(description="Name")
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_create_and_get():
|
@pytest.fixture
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
def db() -> SqliteItemStorage[TestModel]:
|
||||||
|
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
|
return SqliteItemStorage[TestModel](db_conn, "test", "id")
|
||||||
|
|
||||||
|
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
assert db.get("1") == TestModel(id="1", name="Test")
|
assert db.get("1") == TestModel(id="1", name="Test")
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_list():
|
def test_sqlite_service_can_list(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
@ -30,15 +35,13 @@ def test_sqlite_service_can_list():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_delete():
|
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.delete("1")
|
db.delete("1")
|
||||||
assert db.get("1") is None
|
assert db.get("1") is None
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_calls_set_callback():
|
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
called = False
|
called = False
|
||||||
|
|
||||||
def on_changed(item: TestModel):
|
def on_changed(item: TestModel):
|
||||||
@ -50,8 +53,7 @@ def test_sqlite_service_calls_set_callback():
|
|||||||
assert called
|
assert called
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_calls_delete_callback():
|
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
called = False
|
called = False
|
||||||
|
|
||||||
def on_deleted(item_id: str):
|
def on_deleted(item_id: str):
|
||||||
@ -64,8 +66,7 @@ def test_sqlite_service_calls_delete_callback():
|
|||||||
assert called
|
assert called
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_list_with_pagination():
|
def test_sqlite_service_can_list_with_pagination(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
@ -77,8 +78,7 @@ def test_sqlite_service_can_list_with_pagination():
|
|||||||
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_list_with_pagination_and_offset():
|
def test_sqlite_service_can_list_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
@ -90,8 +90,7 @@ def test_sqlite_service_can_list_with_pagination_and_offset():
|
|||||||
assert results.items == [TestModel(id="3", name="Test")]
|
assert results.items == [TestModel(id="3", name="Test")]
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_search():
|
def test_sqlite_service_can_search(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
@ -107,8 +106,7 @@ def test_sqlite_service_can_search():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_search_with_pagination():
|
def test_sqlite_service_can_search_with_pagination(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
@ -120,8 +118,7 @@ def test_sqlite_service_can_search_with_pagination():
|
|||||||
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_search_with_pagination_and_offset():
|
def test_sqlite_service_can_search_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
|
Loading…
Reference in New Issue
Block a user