[nodes] Removed InvokerServices, simplying service model

This commit is contained in:
Kyle Schouviller 2023-02-24 20:11:28 -08:00
parent 34e3aa1f88
commit cd98d88fe7
8 changed files with 81 additions and 89 deletions

View File

@ -13,7 +13,7 @@ from ...globals import Globals
from ..services.image_storage import DiskImageStorage from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker, InvokerServices from ..services.invoker import Invoker
from ..services.generate_initializer import get_generate from ..services.generate_initializer import get_generate
from .events import FastAPIEventService from .events import FastAPIEventService
@ -60,22 +60,19 @@ class ApiDependencies:
images = DiskImageStorage(output_folder) images = DiskImageStorage(output_folder)
services = InvocationServices(
generate = generate,
events = events,
images = images
)
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db') db_location = os.path.join(output_folder, 'invokeai.db')
invoker_services = InvokerServices( services = InvocationServices(
queue = MemoryInvocationQueue(), generate = generate,
events = events,
images = images,
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor() processor = DefaultInvocationProcessor()
) )
ApiDependencies.invoker = Invoker(services, invoker_services) ApiDependencies.invoker = Invoker(services)
@staticmethod @staticmethod
def shutdown(): def shutdown():

View File

@ -44,9 +44,9 @@ async def list_sessions(
) -> PaginatedResults[GraphExecutionState]: ) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching""" """Gets a list of sessions, optionally searching"""
if filter == '': if filter == '':
result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page) result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
else: else:
result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page) result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
return result return result
@ -60,7 +60,7 @@ async def get_session(
session_id: str = Path(description = "The id of the session to get") session_id: str = Path(description = "The id of the session to get")
) -> GraphExecutionState: ) -> GraphExecutionState:
"""Gets a session""" """Gets a session"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: if session is None:
return Response(status_code = 404) return Response(status_code = 404)
else: else:
@ -80,13 +80,13 @@ async def add_node(
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add") node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add")
) -> str: ) -> str:
"""Adds a node to the graph""" """Adds a node to the graph"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: if session is None:
return Response(status_code = 404) return Response(status_code = 404)
try: try:
session.add_node(node) session.add_node(node)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session.id return session.id
except NodeAlreadyExecutedError: except NodeAlreadyExecutedError:
return Response(status_code = 400) return Response(status_code = 400)
@ -108,13 +108,13 @@ async def update_node(
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node") node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node")
) -> GraphExecutionState: ) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges""" """Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: if session is None:
return Response(status_code = 404) return Response(status_code = 404)
try: try:
session.update_node(node_path, node) session.update_node(node_path, node)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session return session
except NodeAlreadyExecutedError: except NodeAlreadyExecutedError:
return Response(status_code = 400) return Response(status_code = 400)
@ -135,13 +135,13 @@ async def delete_node(
node_path: str = Path(description = "The path to the node to delete") node_path: str = Path(description = "The path to the node to delete")
) -> GraphExecutionState: ) -> GraphExecutionState:
"""Deletes a node in the graph and removes all linked edges""" """Deletes a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: if session is None:
return Response(status_code = 404) return Response(status_code = 404)
try: try:
session.delete_node(node_path) session.delete_node(node_path)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session return session
except NodeAlreadyExecutedError: except NodeAlreadyExecutedError:
return Response(status_code = 400) return Response(status_code = 400)
@ -162,13 +162,13 @@ async def add_edge(
edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add") edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add")
) -> GraphExecutionState: ) -> GraphExecutionState:
"""Adds an edge to the graph""" """Adds an edge to the graph"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: if session is None:
return Response(status_code = 404) return Response(status_code = 404)
try: try:
session.add_edge(edge) session.add_edge(edge)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session return session
except NodeAlreadyExecutedError: except NodeAlreadyExecutedError:
return Response(status_code = 400) return Response(status_code = 400)
@ -193,14 +193,14 @@ async def delete_edge(
to_field: str = Path(description = "The field of the node the edge is going to") to_field: str = Path(description = "The field of the node the edge is going to")
) -> GraphExecutionState: ) -> GraphExecutionState:
"""Deletes an edge from the graph""" """Deletes an edge from the graph"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: if session is None:
return Response(status_code = 404) return Response(status_code = 404)
try: try:
edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field)) edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field))
session.delete_edge(edge) session.delete_edge(edge)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session return session
except NodeAlreadyExecutedError: except NodeAlreadyExecutedError:
return Response(status_code = 400) return Response(status_code = 400)
@ -221,7 +221,7 @@ async def invoke_session(
all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations") all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations")
) -> None: ) -> None:
"""Invokes a session""" """Invokes a session"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: if session is None:
return Response(status_code = 404) return Response(status_code = 404)

View File

@ -20,7 +20,7 @@ from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_queue import MemoryInvocationQueue
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.invocation_services import InvocationServices from .services.invocation_services import InvocationServices
from .services.invoker import Invoker, InvokerServices from .services.invoker import Invoker
from .invocations import * from .invocations import *
from ..args import Args from ..args import Args
from .services.events import EventServiceBase from .services.events import EventServiceBase
@ -171,28 +171,25 @@ def invoke_cli():
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs')) output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))
services = InvocationServices(
generate = generate,
events = events,
images = DiskImageStorage(output_folder)
)
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db') db_location = os.path.join(output_folder, 'invokeai.db')
invoker_services = InvokerServices( services = InvocationServices(
queue = MemoryInvocationQueue(), generate = generate,
events = events,
images = DiskImageStorage(output_folder),
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor() processor = DefaultInvocationProcessor()
) )
invoker = Invoker(services, invoker_services) invoker = Invoker(services)
session = invoker.create_execution_state() session = invoker.create_execution_state()
parser = get_invocation_parser() parser = get_invocation_parser()
# Uncomment to print out previous sessions at startup # Uncomment to print out previous sessions at startup
# print(invoker_services.session_manager.list()) # print(services.session_manager.list())
# Defaults storage # Defaults storage
defaults: Dict[str, Any] = dict() defaults: Dict[str, Any] = dict()
@ -213,7 +210,7 @@ def invoke_cli():
try: try:
# Refresh the state of the session # Refresh the state of the session
session = invoker.invoker_services.graph_execution_manager.get(session.id) session = invoker.services.graph_execution_manager.get(session.id)
history = list(get_graph_execution_history(session)) history = list(get_graph_execution_history(session))
# Split the command for piping # Split the command for piping
@ -289,7 +286,7 @@ def invoke_cli():
invoker.invoke(session, invoke_all = True) invoker.invoke(session, invoke_all = True)
while not session.is_complete(): while not session.is_complete():
# Wait some time # Wait some time
session = invoker.invoker_services.graph_execution_manager.get(session.id) session = invoker.services.graph_execution_manager.get(session.id)
time.sleep(0.1) time.sleep(0.1)
except InvalidArgs: except InvalidArgs:

View File

@ -1,4 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .image_storage import ImageStorageBase from .image_storage import ImageStorageBase
from .events import EventServiceBase from .events import EventServiceBase
from ....generate import Generate from ....generate import Generate
@ -9,12 +11,23 @@ class InvocationServices():
generate: Generate # TODO: wrap Generate, or split it up from model? generate: Generate # TODO: wrap Generate, or split it up from model?
events: EventServiceBase events: EventServiceBase
images: ImageStorageBase images: ImageStorageBase
queue: InvocationQueueABC
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_execution_manager: ItemStorageABC['GraphExecutionState']
processor: 'InvocationProcessorABC'
def __init__(self, def __init__(self,
generate: Generate, generate: Generate,
events: EventServiceBase, events: EventServiceBase,
images: ImageStorageBase images: ImageStorageBase,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC['GraphExecutionState'],
processor: 'InvocationProcessorABC'
): ):
self.generate = generate self.generate = generate
self.events = events self.events = events
self.images = images self.images = images
self.queue = queue
self.graph_execution_manager = graph_execution_manager
self.processor = processor

View File

@ -9,34 +9,15 @@ from .invocation_services import InvocationServices
from .invocation_queue import InvocationQueueABC, InvocationQueueItem from .invocation_queue import InvocationQueueABC, InvocationQueueItem
class InvokerServices:
"""Services used by the Invoker for execution"""
queue: InvocationQueueABC
graph_execution_manager: ItemStorageABC[GraphExecutionState]
processor: 'InvocationProcessorABC'
def __init__(self,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC[GraphExecutionState],
processor: 'InvocationProcessorABC'):
self.queue = queue
self.graph_execution_manager = graph_execution_manager
self.processor = processor
class Invoker: class Invoker:
"""The invoker, used to execute invocations""" """The invoker, used to execute invocations"""
services: InvocationServices services: InvocationServices
invoker_services: InvokerServices
def __init__(self, def __init__(self,
services: InvocationServices, # Services used by nodes to perform invocations services: InvocationServices
invoker_services: InvokerServices # Services used by the invoker for orchestration
): ):
self.services = services self.services = services
self.invoker_services = invoker_services
self._start() self._start()
@ -49,11 +30,11 @@ class Invoker:
return None return None
# Save the execution state # Save the execution state
self.invoker_services.graph_execution_manager.set(graph_execution_state) self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation # Queue the invocation
print(f'queueing item {invocation.id}') print(f'queueing item {invocation.id}')
self.invoker_services.queue.put(InvocationQueueItem( self.services.queue.put(InvocationQueueItem(
#session_id = session.id, #session_id = session.id,
graph_execution_state_id = graph_execution_state.id, graph_execution_state_id = graph_execution_state.id,
invocation_id = invocation.id, invocation_id = invocation.id,
@ -66,7 +47,7 @@ class Invoker:
def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState: def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph""" """Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph = Graph() if graph is None else graph) new_state = GraphExecutionState(graph = Graph() if graph is None else graph)
self.invoker_services.graph_execution_manager.set(new_state) self.services.graph_execution_manager.set(new_state)
return new_state return new_state
@ -86,8 +67,8 @@ class Invoker:
def _start(self) -> None: def _start(self) -> None:
"""Starts the invoker. This is called automatically when the invoker is created.""" """Starts the invoker. This is called automatically when the invoker is created."""
for service in vars(self.invoker_services): for service in vars(self.services):
self.__start_service(getattr(self.invoker_services, service)) self.__start_service(getattr(self.services, service))
for service in vars(self.services): for service in vars(self.services):
self.__start_service(getattr(self.services, service)) self.__start_service(getattr(self.services, service))
@ -99,10 +80,10 @@ class Invoker:
for service in vars(self.services): for service in vars(self.services):
self.__stop_service(getattr(self.services, service)) self.__stop_service(getattr(self.services, service))
for service in vars(self.invoker_services): for service in vars(self.services):
self.__stop_service(getattr(self.invoker_services, service)) self.__stop_service(getattr(self.services, service))
self.invoker_services.queue.put(None) self.services.queue.put(None)
class InvocationProcessorABC(ABC): class InvocationProcessorABC(ABC):

View File

@ -28,11 +28,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event): def __process(self, stop_event: Event):
try: try:
while not stop_event.is_set(): while not stop_event.is_set():
queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get() queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
if not queue_item: # Probably stopping if not queue_item: # Probably stopping
continue continue
graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id) graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id)
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
# Send starting event # Send starting event
@ -52,7 +52,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_execution_state.complete(invocation.id, outputs) graph_execution_state.complete(invocation.id, outputs)
# Save the state changes # Save the state changes
self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state) self.__invoker.services.graph_execution_manager.set(graph_execution_state)
# Send complete event # Send complete event
self.__invoker.services.events.emit_invocation_complete( self.__invoker.services.events.emit_invocation_complete(

View File

@ -1,10 +1,11 @@
from .test_invoker import create_edge from .test_invoker import create_edge
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invocation_services import InvocationServices from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest import pytest
@ -19,7 +20,14 @@ def simple_graph():
@pytest.fixture @pytest.fixture
def mock_services(): def mock_services():
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
return InvocationServices(generate = None, events = None, images = None) return InvocationServices(
generate = None,
events = None,
images = None,
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next() n = g.next()

View File

@ -2,12 +2,10 @@ from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTe
from ldm.invoke.app.services.processor import DefaultInvocationProcessor from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invoker import Invoker, InvokerServices from ldm.invoke.app.services.invoker import Invoker
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.invocation_services import InvocationServices from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest import pytest
@ -22,21 +20,19 @@ def simple_graph():
@pytest.fixture @pytest.fixture
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
return InvocationServices(generate = None, events = TestEventService(), images = None) return InvocationServices(
generate = None,
@pytest.fixture() events = TestEventService(),
def mock_invoker_services() -> InvokerServices: images = None,
return InvokerServices(
queue = MemoryInvocationQueue(), queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor() processor = DefaultInvocationProcessor()
) )
@pytest.fixture() @pytest.fixture()
def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker: def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker( return Invoker(
services = mock_services, services = mock_services
invoker_services = mock_invoker_services
) )
def test_can_create_graph_state(mock_invoker: Invoker): def test_can_create_graph_state(mock_invoker: Invoker):
@ -60,13 +56,13 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
assert invocation_id is not None assert invocation_id is not None
def has_executed_any(g: GraphExecutionState): def has_executed_any(g: GraphExecutionState):
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
return len(g.executed) > 0 return len(g.executed) > 0
wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1) wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1)
mock_invoker.stop() mock_invoker.stop()
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0 assert len(g.executed) > 0
def test_can_invoke_all(mock_invoker: Invoker, simple_graph): def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
@ -75,11 +71,11 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
assert invocation_id is not None assert invocation_id is not None
def has_executed_all(g: GraphExecutionState): def has_executed_all(g: GraphExecutionState):
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
return g.is_complete() return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
mock_invoker.stop() mock_invoker.stop()
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete() assert g.is_complete()