diff --git a/ldm/invoke/app/api/dependencies.py b/ldm/invoke/app/api/dependencies.py index 08f362133e..60dd522803 100644 --- a/ldm/invoke/app/api/dependencies.py +++ b/ldm/invoke/app/api/dependencies.py @@ -13,7 +13,7 @@ from ...globals import Globals from ..services.image_storage import DiskImageStorage from ..services.invocation_queue import MemoryInvocationQueue from ..services.invocation_services import InvocationServices -from ..services.invoker import Invoker +from ..services.invoker import Invoker, InvokerServices from ..services.generate_initializer import get_generate from .events import FastAPIEventService @@ -60,19 +60,22 @@ class ApiDependencies: images = DiskImageStorage(output_folder) - # TODO: build a file/path manager? - db_location = os.path.join(output_folder, 'invokeai.db') - services = InvocationServices( generate = generate, events = events, - images = images, - queue = MemoryInvocationQueue(), + images = images + ) + + # TODO: build a file/path manager? + db_location = os.path.join(output_folder, 'invokeai.db') + + invoker_services = InvokerServices( + queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() + processor = DefaultInvocationProcessor() ) - ApiDependencies.invoker = Invoker(services) + ApiDependencies.invoker = Invoker(services, invoker_services) @staticmethod def shutdown(): diff --git a/ldm/invoke/app/api/routers/sessions.py b/ldm/invoke/app/api/routers/sessions.py index beb13736c6..77008ad6e4 100644 --- a/ldm/invoke/app/api/routers/sessions.py +++ b/ldm/invoke/app/api/routers/sessions.py @@ -44,9 +44,9 @@ async def list_sessions( ) -> PaginatedResults[GraphExecutionState]: """Gets a list of sessions, optionally searching""" if filter == '': - result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page) + result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page) else: - result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page) + result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page) return result @@ -60,7 +60,7 @@ async def get_session( session_id: str = Path(description = "The id of the session to get") ) -> GraphExecutionState: """Gets a session""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) else: @@ -80,13 +80,13 @@ async def add_node( node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add") ) -> str: """Adds a node to the graph""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.add_node(node) - ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session.id except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -108,13 +108,13 @@ async def update_node( node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node") ) -> GraphExecutionState: """Updates a node in the graph and removes all linked edges""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.update_node(node_path, node) - ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -135,13 +135,13 @@ async def delete_node( node_path: str = Path(description = "The path to the node to delete") ) -> GraphExecutionState: """Deletes a node in the graph and removes all linked edges""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.delete_node(node_path) - ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -162,13 +162,13 @@ async def add_edge( edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add") ) -> GraphExecutionState: """Adds an edge to the graph""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: session.add_edge(edge) - ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -193,14 +193,14 @@ async def delete_edge( to_field: str = Path(description = "The field of the node the edge is going to") ) -> GraphExecutionState: """Deletes an edge from the graph""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) try: edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field)) session.delete_edge(edge) - ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? return session except NodeAlreadyExecutedError: return Response(status_code = 400) @@ -221,7 +221,7 @@ async def invoke_session( all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations") ) -> None: """Invokes a session""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) if session is None: return Response(status_code = 404) diff --git a/ldm/invoke/app/cli_app.py b/ldm/invoke/app/cli_app.py index 9081f3b083..6071afabb2 100644 --- a/ldm/invoke/app/cli_app.py +++ b/ldm/invoke/app/cli_app.py @@ -20,7 +20,7 @@ from .services.image_storage import DiskImageStorage from .services.invocation_queue import MemoryInvocationQueue from .invocations.baseinvocation import BaseInvocation from .services.invocation_services import InvocationServices -from .services.invoker import Invoker +from .services.invoker import Invoker, InvokerServices from .invocations import * from ..args import Args from .services.events import EventServiceBase @@ -171,25 +171,28 @@ def invoke_cli(): output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs')) + services = InvocationServices( + generate = generate, + events = events, + images = DiskImageStorage(output_folder) + ) + # TODO: build a file/path manager? db_location = os.path.join(output_folder, 'invokeai.db') - services = InvocationServices( - generate = generate, - events = events, - images = DiskImageStorage(output_folder), - queue = MemoryInvocationQueue(), + invoker_services = InvokerServices( + queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() + processor = DefaultInvocationProcessor() ) - invoker = Invoker(services) + invoker = Invoker(services, invoker_services) session = invoker.create_execution_state() parser = get_invocation_parser() # Uncomment to print out previous sessions at startup - # print(services.session_manager.list()) + # print(invoker_services.session_manager.list()) # Defaults storage defaults: Dict[str, Any] = dict() @@ -210,7 +213,7 @@ def invoke_cli(): try: # Refresh the state of the session - session = invoker.services.graph_execution_manager.get(session.id) + session = invoker.invoker_services.graph_execution_manager.get(session.id) history = list(get_graph_execution_history(session)) # Split the command for piping @@ -286,7 +289,7 @@ def invoke_cli(): invoker.invoke(session, invoke_all = True) while not session.is_complete(): # Wait some time - session = invoker.services.graph_execution_manager.get(session.id) + session = invoker.invoker_services.graph_execution_manager.get(session.id) time.sleep(0.1) except InvalidArgs: diff --git a/ldm/invoke/app/services/invocation_services.py b/ldm/invoke/app/services/invocation_services.py index 40a64e64e5..9eb5309d3d 100644 --- a/ldm/invoke/app/services/invocation_services.py +++ b/ldm/invoke/app/services/invocation_services.py @@ -1,6 +1,4 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from .invocation_queue import InvocationQueueABC -from .item_storage import ItemStorageABC from .image_storage import ImageStorageBase from .events import EventServiceBase from ....generate import Generate @@ -11,23 +9,12 @@ class InvocationServices(): generate: Generate # TODO: wrap Generate, or split it up from model? events: EventServiceBase images: ImageStorageBase - queue: InvocationQueueABC - - # NOTE: we must forward-declare any types that include invocations, since invocations can use services - graph_execution_manager: ItemStorageABC['GraphExecutionState'] - processor: 'InvocationProcessorABC' def __init__(self, generate: Generate, events: EventServiceBase, - images: ImageStorageBase, - queue: InvocationQueueABC, - graph_execution_manager: ItemStorageABC['GraphExecutionState'], - processor: 'InvocationProcessorABC' + images: ImageStorageBase ): self.generate = generate self.events = events self.images = images - self.queue = queue - self.graph_execution_manager = graph_execution_manager - self.processor = processor diff --git a/ldm/invoke/app/services/invoker.py b/ldm/invoke/app/services/invoker.py index 4397a75021..796f541781 100644 --- a/ldm/invoke/app/services/invoker.py +++ b/ldm/invoke/app/services/invoker.py @@ -9,15 +9,34 @@ from .invocation_services import InvocationServices from .invocation_queue import InvocationQueueABC, InvocationQueueItem +class InvokerServices: + """Services used by the Invoker for execution""" + + queue: InvocationQueueABC + graph_execution_manager: ItemStorageABC[GraphExecutionState] + processor: 'InvocationProcessorABC' + + def __init__(self, + queue: InvocationQueueABC, + graph_execution_manager: ItemStorageABC[GraphExecutionState], + processor: 'InvocationProcessorABC'): + self.queue = queue + self.graph_execution_manager = graph_execution_manager + self.processor = processor + + class Invoker: """The invoker, used to execute invocations""" services: InvocationServices + invoker_services: InvokerServices def __init__(self, - services: InvocationServices + services: InvocationServices, # Services used by nodes to perform invocations + invoker_services: InvokerServices # Services used by the invoker for orchestration ): self.services = services + self.invoker_services = invoker_services self._start() @@ -30,11 +49,11 @@ class Invoker: return None # Save the execution state - self.services.graph_execution_manager.set(graph_execution_state) + self.invoker_services.graph_execution_manager.set(graph_execution_state) # Queue the invocation print(f'queueing item {invocation.id}') - self.services.queue.put(InvocationQueueItem( + self.invoker_services.queue.put(InvocationQueueItem( #session_id = session.id, graph_execution_state_id = graph_execution_state.id, invocation_id = invocation.id, @@ -47,7 +66,7 @@ class Invoker: def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState: """Creates a new execution state for the given graph""" new_state = GraphExecutionState(graph = Graph() if graph is None else graph) - self.services.graph_execution_manager.set(new_state) + self.invoker_services.graph_execution_manager.set(new_state) return new_state @@ -67,8 +86,8 @@ class Invoker: def _start(self) -> None: """Starts the invoker. This is called automatically when the invoker is created.""" - for service in vars(self.services): - self.__start_service(getattr(self.services, service)) + for service in vars(self.invoker_services): + self.__start_service(getattr(self.invoker_services, service)) for service in vars(self.services): self.__start_service(getattr(self.services, service)) @@ -80,10 +99,10 @@ class Invoker: for service in vars(self.services): self.__stop_service(getattr(self.services, service)) - for service in vars(self.services): - self.__stop_service(getattr(self.services, service)) + for service in vars(self.invoker_services): + self.__stop_service(getattr(self.invoker_services, service)) - self.services.queue.put(None) + self.invoker_services.queue.put(None) class InvocationProcessorABC(ABC): diff --git a/ldm/invoke/app/services/processor.py b/ldm/invoke/app/services/processor.py index 9ea4349bbf..9b51a6bcbc 100644 --- a/ldm/invoke/app/services/processor.py +++ b/ldm/invoke/app/services/processor.py @@ -28,11 +28,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC): def __process(self, stop_event: Event): try: while not stop_event.is_set(): - queue_item: InvocationQueueItem = self.__invoker.services.queue.get() + queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get() if not queue_item: # Probably stopping continue - graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id) + graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id) invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) # Send starting event @@ -52,7 +52,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): graph_execution_state.complete(invocation.id, outputs) # Save the state changes - self.__invoker.services.graph_execution_manager.set(graph_execution_state) + self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state) # Send complete event self.__invoker.services.events.emit_invocation_complete( diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 980c262501..0a5dcc7734 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -1,11 +1,10 @@ from .test_invoker import create_edge from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext -from ldm.invoke.app.services.processor import DefaultInvocationProcessor -from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory -from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue from ldm.invoke.app.services.invocation_services import InvocationServices from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation +from ldm.invoke.app.invocations.upscale import UpscaleInvocation import pytest @@ -20,14 +19,7 @@ def simple_graph(): @pytest.fixture def mock_services(): # NOTE: none of these are actually called by the test invocations - return InvocationServices( - generate = None, - events = None, - images = None, - queue = MemoryInvocationQueue(), - graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() - ) + return InvocationServices(generate = None, events = None, images = None) def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: n = g.next() diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index e9109728d5..a6d96f61c0 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -2,10 +2,12 @@ from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTe from ldm.invoke.app.services.processor import DefaultInvocationProcessor from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue -from ldm.invoke.app.services.invoker import Invoker +from ldm.invoke.app.services.invoker import Invoker, InvokerServices from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from ldm.invoke.app.services.invocation_services import InvocationServices from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation +from ldm.invoke.app.invocations.upscale import UpscaleInvocation import pytest @@ -20,19 +22,21 @@ def simple_graph(): @pytest.fixture def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations - return InvocationServices( - generate = None, - events = TestEventService(), - images = None, + return InvocationServices(generate = None, events = TestEventService(), images = None) + +@pytest.fixture() +def mock_invoker_services() -> InvokerServices: + return InvokerServices( queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor() ) @pytest.fixture() -def mock_invoker(mock_services: InvocationServices) -> Invoker: +def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker: return Invoker( - services = mock_services + services = mock_services, + invoker_services = mock_invoker_services ) def test_can_create_graph_state(mock_invoker: Invoker): @@ -56,13 +60,13 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph): assert invocation_id is not None def has_executed_any(g: GraphExecutionState): - g = mock_invoker.services.graph_execution_manager.get(g.id) + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) return len(g.executed) > 0 wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1) mock_invoker.stop() - g = mock_invoker.services.graph_execution_manager.get(g.id) + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) assert len(g.executed) > 0 def test_can_invoke_all(mock_invoker: Invoker, simple_graph): @@ -71,11 +75,11 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph): assert invocation_id is not None def has_executed_all(g: GraphExecutionState): - g = mock_invoker.services.graph_execution_manager.get(g.id) + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) return g.is_complete() wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) mock_invoker.stop() - g = mock_invoker.services.graph_execution_manager.get(g.id) + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) assert g.is_complete()