diff --git a/ldm/invoke/app/api/dependencies.py b/ldm/invoke/app/api/dependencies.py index 60dd522803..08f362133e 100644 --- a/ldm/invoke/app/api/dependencies.py +++ b/ldm/invoke/app/api/dependencies.py @@ -13,7 +13,7 @@ from ...globals import Globals from ..services.image_storage import DiskImageStorage from ..services.invocation_queue import MemoryInvocationQueue from ..services.invocation_services import InvocationServices -from ..services.invoker import Invoker, InvokerServices +from ..services.invoker import Invoker from ..services.generate_initializer import get_generate from .events import FastAPIEventService @@ -60,22 +60,19 @@ class ApiDependencies: images = DiskImageStorage(output_folder) - services = InvocationServices( - generate = generate, - events = events, - images = images - ) - # TODO: build a file/path manager? db_location = os.path.join(output_folder, 'invokeai.db') - invoker_services = InvokerServices( - queue = MemoryInvocationQueue(), + services = InvocationServices( + generate = generate, + events = events, + images = images, + queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() + processor = DefaultInvocationProcessor() ) - ApiDependencies.invoker = Invoker(services, invoker_services) + ApiDependencies.invoker = Invoker(services) @staticmethod def shutdown(): diff --git a/ldm/invoke/app/api/routers/sessions.py b/ldm/invoke/app/api/routers/sessions.py index 77008ad6e4..beb13736c6 100644 --- a/ldm/invoke/app/api/routers/sessions.py +++ b/ldm/invoke/app/api/routers/sessions.py @@ -44,9 +44,9 @@ async def list_sessions( ) -> PaginatedResults[GraphExecutionState]: """Gets a list of sessions, optionally searching""" if filter == '': - result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page) + result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page) else: - result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page) + result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page) return result @@ -60,7 +60,7 @@ async def get_session( session_id: str = Path(description = "The id of the session to get") ) -> GraphExecutionState: """Gets a session""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) else: @@ -80,13 +80,13 @@ async def add_node( node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add") ) -> str: """Adds a node to the graph""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.add_node(node) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session.id except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -108,13 +108,13 @@ async def update_node( node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node") ) -> GraphExecutionState: """Updates a node in the graph and removes all linked edges""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.update_node(node_path, node) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -135,13 +135,13 @@ async def delete_node( node_path: str = Path(description = "The path to the node to delete") ) -> GraphExecutionState: """Deletes a node in the graph and removes all linked edges""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.delete_node(node_path) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -162,13 +162,13 @@ async def add_edge( edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add") ) -> GraphExecutionState: """Adds an edge to the graph""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.add_edge(edge) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -193,14 +193,14 @@ async def delete_edge( to_field: str = Path(description = "The field of the node the edge is going to") ) -> GraphExecutionState: """Deletes an edge from the graph""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field)) session.delete_edge(edge) - ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -221,7 +221,7 @@ async def invoke_session( all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations") ) -> None: """Invokes a session""" - session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) diff --git a/ldm/invoke/app/cli_app.py b/ldm/invoke/app/cli_app.py index 6071afabb2..9081f3b083 100644 --- a/ldm/invoke/app/cli_app.py +++ b/ldm/invoke/app/cli_app.py @@ -20,7 +20,7 @@ from .services.image_storage import DiskImageStorage from .services.invocation_queue import MemoryInvocationQueue from .invocations.baseinvocation import BaseInvocation from .services.invocation_services import InvocationServices -from .services.invoker import Invoker, InvokerServices +from .services.invoker import Invoker from .invocations import * from ..args import Args from .services.events import EventServiceBase @@ -171,28 +171,25 @@ def invoke_cli(): output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs')) - services = InvocationServices( - generate = generate, - events = events, - images = DiskImageStorage(output_folder) - ) - # TODO: build a file/path manager? db_location = os.path.join(output_folder, 'invokeai.db') - invoker_services = InvokerServices( - queue = MemoryInvocationQueue(), + services = InvocationServices( + generate = generate, + events = events, + images = DiskImageStorage(output_folder), + queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() + processor = DefaultInvocationProcessor() ) - invoker = Invoker(services, invoker_services) + invoker = Invoker(services) session = invoker.create_execution_state() parser = get_invocation_parser() # Uncomment to print out previous sessions at startup - # print(invoker_services.session_manager.list()) + # print(services.session_manager.list()) # Defaults storage defaults: Dict[str, Any] = dict() @@ -213,7 +210,7 @@ def invoke_cli(): try: # Refresh the state of the session - session = invoker.invoker_services.graph_execution_manager.get(session.id) + session = invoker.services.graph_execution_manager.get(session.id) history = list(get_graph_execution_history(session)) # Split the command for piping @@ -289,7 +286,7 @@ def invoke_cli(): invoker.invoke(session, invoke_all = True) while not session.is_complete(): # Wait some time - session = invoker.invoker_services.graph_execution_manager.get(session.id) + session = invoker.services.graph_execution_manager.get(session.id) time.sleep(0.1) except InvalidArgs: diff --git a/ldm/invoke/app/services/invocation_services.py b/ldm/invoke/app/services/invocation_services.py index 9eb5309d3d..40a64e64e5 100644 --- a/ldm/invoke/app/services/invocation_services.py +++ b/ldm/invoke/app/services/invocation_services.py @@ -1,4 +1,6 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from .invocation_queue import InvocationQueueABC +from .item_storage import ItemStorageABC from .image_storage import ImageStorageBase from .events import EventServiceBase from ....generate import Generate @@ -9,12 +11,23 @@ class InvocationServices(): generate: Generate # TODO: wrap Generate, or split it up from model? events: EventServiceBase images: ImageStorageBase + queue: InvocationQueueABC + + # NOTE: we must forward-declare any types that include invocations, since invocations can use services + graph_execution_manager: ItemStorageABC['GraphExecutionState'] + processor: 'InvocationProcessorABC' def __init__(self, generate: Generate, events: EventServiceBase, - images: ImageStorageBase + images: ImageStorageBase, + queue: InvocationQueueABC, + graph_execution_manager: ItemStorageABC['GraphExecutionState'], + processor: 'InvocationProcessorABC' ): self.generate = generate self.events = events self.images = images + self.queue = queue + self.graph_execution_manager = graph_execution_manager + self.processor = processor diff --git a/ldm/invoke/app/services/invoker.py b/ldm/invoke/app/services/invoker.py index 796f541781..4397a75021 100644 --- a/ldm/invoke/app/services/invoker.py +++ b/ldm/invoke/app/services/invoker.py @@ -9,34 +9,15 @@ from .invocation_services import InvocationServices from .invocation_queue import InvocationQueueABC, InvocationQueueItem -class InvokerServices: - """Services used by the Invoker for execution""" - - queue: InvocationQueueABC - graph_execution_manager: ItemStorageABC[GraphExecutionState] - processor: 'InvocationProcessorABC' - - def __init__(self, - queue: InvocationQueueABC, - graph_execution_manager: ItemStorageABC[GraphExecutionState], - processor: 'InvocationProcessorABC'): - self.queue = queue - self.graph_execution_manager = graph_execution_manager - self.processor = processor - - class Invoker: """The invoker, used to execute invocations""" services: InvocationServices - invoker_services: InvokerServices def __init__(self, - services: InvocationServices, # Services used by nodes to perform invocations - invoker_services: InvokerServices # Services used by the invoker for orchestration + services: InvocationServices ): self.services = services - self.invoker_services = invoker_services self._start() @@ -49,11 +30,11 @@ class Invoker: return None # Save the execution state - self.invoker_services.graph_execution_manager.set(graph_execution_state) + self.services.graph_execution_manager.set(graph_execution_state) # Queue the invocation print(f'queueing item {invocation.id}') - self.invoker_services.queue.put(InvocationQueueItem( + self.services.queue.put(InvocationQueueItem( #session_id = session.id, graph_execution_state_id = graph_execution_state.id, invocation_id = invocation.id, @@ -66,7 +47,7 @@ class Invoker: def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState: """Creates a new execution state for the given graph""" new_state = GraphExecutionState(graph = Graph() if graph is None else graph) - self.invoker_services.graph_execution_manager.set(new_state) + self.services.graph_execution_manager.set(new_state) return new_state @@ -86,8 +67,8 @@ class Invoker: def _start(self) -> None: """Starts the invoker. This is called automatically when the invoker is created.""" - for service in vars(self.invoker_services): - self.__start_service(getattr(self.invoker_services, service)) + for service in vars(self.services): + self.__start_service(getattr(self.services, service)) for service in vars(self.services): self.__start_service(getattr(self.services, service)) @@ -99,10 +80,10 @@ class Invoker: for service in vars(self.services): self.__stop_service(getattr(self.services, service)) - for service in vars(self.invoker_services): - self.__stop_service(getattr(self.invoker_services, service)) + for service in vars(self.services): + self.__stop_service(getattr(self.services, service)) - self.invoker_services.queue.put(None) + self.services.queue.put(None) class InvocationProcessorABC(ABC): diff --git a/ldm/invoke/app/services/processor.py b/ldm/invoke/app/services/processor.py index 9b51a6bcbc..9ea4349bbf 100644 --- a/ldm/invoke/app/services/processor.py +++ b/ldm/invoke/app/services/processor.py @@ -28,11 +28,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC): def __process(self, stop_event: Event): try: while not stop_event.is_set(): - queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get() + queue_item: InvocationQueueItem = self.__invoker.services.queue.get() if not queue_item: # Probably stopping continue - graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id) + graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id) invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) # Send starting event @@ -52,7 +52,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): graph_execution_state.complete(invocation.id, outputs) # Save the state changes - self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state) + self.__invoker.services.graph_execution_manager.set(graph_execution_state) # Send complete event self.__invoker.services.events.emit_invocation_complete( diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 0a5dcc7734..980c262501 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -1,10 +1,11 @@ from .test_invoker import create_edge from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ldm.invoke.app.services.processor import DefaultInvocationProcessor +from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory +from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue from ldm.invoke.app.services.invocation_services import InvocationServices from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState -from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation -from ldm.invoke.app.invocations.upscale import UpscaleInvocation import pytest @@ -19,7 +20,14 @@ def simple_graph(): @pytest.fixture def mock_services(): # NOTE: none of these are actually called by the test invocations - return InvocationServices(generate = None, events = None, images = None) + return InvocationServices( + generate = None, + events = None, + images = None, + queue = MemoryInvocationQueue(), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor() + ) def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: n = g.next() diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index a6d96f61c0..e9109728d5 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -2,12 +2,10 @@ from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTe from ldm.invoke.app.services.processor import DefaultInvocationProcessor from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue -from ldm.invoke.app.services.invoker import Invoker, InvokerServices +from ldm.invoke.app.services.invoker import Invoker from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from ldm.invoke.app.services.invocation_services import InvocationServices from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState -from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation -from ldm.invoke.app.invocations.upscale import UpscaleInvocation import pytest @@ -22,21 +20,19 @@ def simple_graph(): @pytest.fixture def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations - return InvocationServices(generate = None, events = TestEventService(), images = None) - -@pytest.fixture() -def mock_invoker_services() -> InvokerServices: - return InvokerServices( + return InvocationServices( + generate = None, + events = TestEventService(), + images = None, queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor() ) @pytest.fixture() -def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker: +def mock_invoker(mock_services: InvocationServices) -> Invoker: return Invoker( - services = mock_services, - invoker_services = mock_invoker_services + services = mock_services ) def test_can_create_graph_state(mock_invoker: Invoker): @@ -60,13 +56,13 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph): assert invocation_id is not None def has_executed_any(g: GraphExecutionState): - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + g = mock_invoker.services.graph_execution_manager.get(g.id) return len(g.executed) > 0 wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1) mock_invoker.stop() - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + g = mock_invoker.services.graph_execution_manager.get(g.id) assert len(g.executed) > 0 def test_can_invoke_all(mock_invoker: Invoker, simple_graph): @@ -75,11 +71,11 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph): assert invocation_id is not None def has_executed_all(g: GraphExecutionState): - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + g = mock_invoker.services.graph_execution_manager.get(g.id) return g.is_complete() wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) mock_invoker.stop() - g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + g = mock_invoker.services.graph_execution_manager.get(g.id) assert g.is_complete()