[nodes] Removed InvokerServices, simplying service model

This commit is contained in:
Kyle Schouviller
2023-02-24 20:11:28 -08:00
parent 34e3aa1f88
commit cd98d88fe7
8 changed files with 81 additions and 89 deletions

View File

@ -1,10 +1,11 @@
from .test_invoker import create_edge
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest
@ -19,7 +20,14 @@ def simple_graph():
@pytest.fixture
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(generate = None, events = None, images = None)
return InvocationServices(
generate = None,
events = None,
images = None,
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next()

View File

@ -2,12 +2,10 @@ from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTe
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invoker import Invoker, InvokerServices
from ldm.invoke.app.services.invoker import Invoker
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest
@ -22,21 +20,19 @@ def simple_graph():
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(generate = None, events = TestEventService(), images = None)
@pytest.fixture()
def mock_invoker_services() -> InvokerServices:
return InvokerServices(
return InvocationServices(
generate = None,
events = TestEventService(),
images = None,
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker:
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(
services = mock_services,
invoker_services = mock_invoker_services
services = mock_services
)
def test_can_create_graph_state(mock_invoker: Invoker):
@ -60,13 +56,13 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
assert invocation_id is not None
def has_executed_any(g: GraphExecutionState):
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
g = mock_invoker.services.graph_execution_manager.get(g.id)
return len(g.executed) > 0
wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1)
mock_invoker.stop()
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
@ -75,11 +71,11 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
assert invocation_id is not None
def has_executed_all(g: GraphExecutionState):
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
g = mock_invoker.services.graph_execution_manager.get(g.id)
return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
mock_invoker.stop()
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete()