From 276a95ae8e910201c1f64a630725380d5489cf12 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 01:41:04 +1100 Subject: [PATCH] refactor(nodes): merge processors Consolidate graph processing logic into the session processor. With graphs as the unit of work, and the session queue distributing graphs, we no longer need the invocation queue or processor. Instead, the session processor dequeues the next session and processes it in a simple loop, greatly simplifying the app. - Remove `graph_execution_manager` service. - Remove `queue` (invocation queue) service. - Remove `processor` (invocation processor) service. - Remove queue-related logic from `Invoker`. It now only starts and stops the services, providing them with access to other services. - Remove the unused `invocation_retrieval_error` and `session_retrieval_error` events. - Clean up the stats service now that it is less coupled to the rest of the app. - Refactor cancellation logic: cancellations now originate from the session queue (i.e. the HTTP cancel endpoint) and are emitted as events. The processor receives these events and sets its internal cancel event; access to this event is provided to the invocation context, e.g. for the step callback. - Remove `sessions` router; it provided access to `graph_executions` but that no longer exists. --- invokeai/app/api/dependencies.py | 10 - invokeai/app/api/routers/sessions.py | 276 ------------------ invokeai/app/api_app.py | 3 - invokeai/app/services/events/events_base.py | 48 +-- .../services/invocation_processor/__init__.py | 0 .../invocation_processor_base.py | 5 - .../invocation_processor_common.py | 15 - .../invocation_processor_default.py | 241 --------------- .../app/services/invocation_queue/__init__.py | 0 .../invocation_queue/invocation_queue_base.py | 26 -- .../invocation_queue_common.py | 23 -- .../invocation_queue_memory.py | 44 --- invokeai/app/services/invocation_services.py | 10 - .../invocation_stats/invocation_stats_base.py | 10 +- .../invocation_stats_default.py | 43 +-- invokeai/app/services/invoker.py | 52 ---- .../services/model_load/model_load_default.py | 3 - .../session_processor_common.py | 14 + .../session_processor_default.py | 249 +++++++++++----- .../session_queue/session_queue_sqlite.py | 5 +- .../app/services/shared/invocation_context.py | 20 +- invokeai/app/util/step_callback.py | 9 +- 22 files changed, 227 insertions(+), 879 deletions(-) delete mode 100644 invokeai/app/api/routers/sessions.py delete mode 100644 invokeai/app/services/invocation_processor/__init__.py delete mode 100644 invokeai/app/services/invocation_processor/invocation_processor_base.py delete mode 100644 invokeai/app/services/invocation_processor/invocation_processor_common.py delete mode 100644 invokeai/app/services/invocation_processor/invocation_processor_default.py delete mode 100644 invokeai/app/services/invocation_queue/__init__.py delete mode 100644 invokeai/app/services/invocation_queue/invocation_queue_base.py delete mode 100644 invokeai/app/services/invocation_queue/invocation_queue_common.py delete mode 100644 invokeai/app/services/invocation_queue/invocation_queue_memory.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 8e79b26e2d..a9132516a8 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,7 +4,6 @@ from logging import Logger import torch -from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory from 
invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db @@ -22,8 +21,6 @@ from ..services.image_files.image_files_disk import DiskImageFileStorage from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.images.images_default import ImageService from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache -from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor -from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue from ..services.invocation_services import InvocationServices from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker @@ -33,7 +30,6 @@ from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue -from ..services.shared.graph import GraphExecutionState from ..services.urls.urls_default import LocalUrlService from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from .events import FastAPIEventService @@ -85,7 +81,6 @@ class ApiDependencies: board_records = SqliteBoardRecordStorage(db=db) boards = BoardService() events = FastAPIEventService(event_handler_id) - graph_execution_manager = ItemStorageMemory[GraphExecutionState]() image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) @@ -105,8 +100,6 @@ class ApiDependencies: ) names = SimpleNameService() performance_statistics = InvocationStatsService() - processor = DefaultInvocationProcessor() - queue = MemoryInvocationQueue() session_processor = DefaultSessionProcessor() session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() @@ -119,7 +112,6 @@ class ApiDependencies: boards=boards, configuration=configuration, events=events, - graph_execution_manager=graph_execution_manager, image_files=image_files, image_records=image_records, images=images, @@ -129,8 +121,6 @@ class ApiDependencies: download_queue=download_queue_service, names=names, performance_statistics=performance_statistics, - processor=processor, - queue=queue, session_processor=session_processor, session_queue=session_queue, urls=urls, diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py deleted file mode 100644 index fb850d0b2b..0000000000 --- a/invokeai/app/api/routers/sessions.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - - -from fastapi import HTTPException, Path -from fastapi.routing import APIRouter - -from ...services.shared.graph import GraphExecutionState -from ..dependencies import ApiDependencies - -session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"]) - - -# @session_router.post( -# "/", -# operation_id="create_session", -# responses={ -# 200: {"model": GraphExecutionState}, -# 400: {"description": "Invalid json"}, -# }, -# deprecated=True, -# ) -# async def create_session( -# queue_id: str = Query(default="", description="The id of the queue to associate the 
session with"), -# graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"), -# ) -> GraphExecutionState: -# """Creates a new session, optionally initializing it with an invocation graph""" -# session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph) -# return session - - -# @session_router.get( -# "/", -# operation_id="list_sessions", -# responses={200: {"model": PaginatedResults[GraphExecutionState]}}, -# deprecated=True, -# ) -# async def list_sessions( -# page: int = Query(default=0, description="The page of results to get"), -# per_page: int = Query(default=10, description="The number of results per page"), -# query: str = Query(default="", description="The query string to search for"), -# ) -> PaginatedResults[GraphExecutionState]: -# """Gets a list of sessions, optionally searching""" -# if query == "": -# result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page) -# else: -# result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page) -# return result - - -@session_router.get( - "/{session_id}", - operation_id="get_session", - responses={ - 200: {"model": GraphExecutionState}, - 404: {"description": "Session not found"}, - }, -) -async def get_session( - session_id: str = Path(description="The id of the session to get"), -) -> GraphExecutionState: - """Gets a session""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) - if session is None: - raise HTTPException(status_code=404) - else: - return session - - -# @session_router.post( -# "/{session_id}/nodes", -# operation_id="add_node", -# responses={ -# 200: {"model": str}, -# 400: {"description": "Invalid node or link"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def add_node( -# session_id: str = Path(description="The id of the session"), -# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore -# description="The node to add" -# ), -# ) -> str: -# """Adds a node to the graph""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# try: -# session.add_node(node) -# ApiDependencies.invoker.services.graph_execution_manager.set( -# session -# ) # TODO: can this be done automatically, or add node through an API? 
-# return session.id -# except NodeAlreadyExecutedError: -# raise HTTPException(status_code=400) -# except IndexError: -# raise HTTPException(status_code=400) - - -# @session_router.put( -# "/{session_id}/nodes/{node_path}", -# operation_id="update_node", -# responses={ -# 200: {"model": GraphExecutionState}, -# 400: {"description": "Invalid node or link"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def update_node( -# session_id: str = Path(description="The id of the session"), -# node_path: str = Path(description="The path to the node in the graph"), -# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore -# description="The new node" -# ), -# ) -> GraphExecutionState: -# """Updates a node in the graph and removes all linked edges""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# try: -# session.update_node(node_path, node) -# ApiDependencies.invoker.services.graph_execution_manager.set( -# session -# ) # TODO: can this be done automatically, or add node through an API? -# return session -# except NodeAlreadyExecutedError: -# raise HTTPException(status_code=400) -# except IndexError: -# raise HTTPException(status_code=400) - - -# @session_router.delete( -# "/{session_id}/nodes/{node_path}", -# operation_id="delete_node", -# responses={ -# 200: {"model": GraphExecutionState}, -# 400: {"description": "Invalid node or link"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def delete_node( -# session_id: str = Path(description="The id of the session"), -# node_path: str = Path(description="The path to the node to delete"), -# ) -> GraphExecutionState: -# """Deletes a node in the graph and removes all linked edges""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# try: -# session.delete_node(node_path) -# ApiDependencies.invoker.services.graph_execution_manager.set( -# session -# ) # TODO: can this be done automatically, or add node through an API? -# return session -# except NodeAlreadyExecutedError: -# raise HTTPException(status_code=400) -# except IndexError: -# raise HTTPException(status_code=400) - - -# @session_router.post( -# "/{session_id}/edges", -# operation_id="add_edge", -# responses={ -# 200: {"model": GraphExecutionState}, -# 400: {"description": "Invalid node or link"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def add_edge( -# session_id: str = Path(description="The id of the session"), -# edge: Edge = Body(description="The edge to add"), -# ) -> GraphExecutionState: -# """Adds an edge to the graph""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# try: -# session.add_edge(edge) -# ApiDependencies.invoker.services.graph_execution_manager.set( -# session -# ) # TODO: can this be done automatically, or add node through an API? 
-# return session -# except NodeAlreadyExecutedError: -# raise HTTPException(status_code=400) -# except IndexError: -# raise HTTPException(status_code=400) - - -# # TODO: the edge being in the path here is really ugly, find a better solution -# @session_router.delete( -# "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}", -# operation_id="delete_edge", -# responses={ -# 200: {"model": GraphExecutionState}, -# 400: {"description": "Invalid node or link"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def delete_edge( -# session_id: str = Path(description="The id of the session"), -# from_node_id: str = Path(description="The id of the node the edge is coming from"), -# from_field: str = Path(description="The field of the node the edge is coming from"), -# to_node_id: str = Path(description="The id of the node the edge is going to"), -# to_field: str = Path(description="The field of the node the edge is going to"), -# ) -> GraphExecutionState: -# """Deletes an edge from the graph""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# try: -# edge = Edge( -# source=EdgeConnection(node_id=from_node_id, field=from_field), -# destination=EdgeConnection(node_id=to_node_id, field=to_field), -# ) -# session.delete_edge(edge) -# ApiDependencies.invoker.services.graph_execution_manager.set( -# session -# ) # TODO: can this be done automatically, or add node through an API? -# return session -# except NodeAlreadyExecutedError: -# raise HTTPException(status_code=400) -# except IndexError: -# raise HTTPException(status_code=400) - - -# @session_router.put( -# "/{session_id}/invoke", -# operation_id="invoke_session", -# responses={ -# 200: {"model": None}, -# 202: {"description": "The invocation is queued"}, -# 400: {"description": "The session has no invocations ready to invoke"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def invoke_session( -# queue_id: str = Query(description="The id of the queue to associate the session with"), -# session_id: str = Path(description="The id of the session to invoke"), -# all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"), -# ) -> Response: -# """Invokes a session""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# if session.is_complete(): -# raise HTTPException(status_code=400) - -# ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all) -# return Response(status_code=202) - - -# @session_router.delete( -# "/{session_id}/invoke", -# operation_id="cancel_session_invoke", -# responses={202: {"description": "The invocation is canceled"}}, -# deprecated=True, -# ) -# async def cancel_session_invoke( -# session_id: str = Path(description="The id of the session to cancel"), -# ) -> Response: -# """Invokes a session""" -# ApiDependencies.invoker.cancel(session_id) -# return Response(status_code=202) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 65607c436a..f6b08ddba6 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -50,7 +50,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c images, model_manager, session_queue, - sessions, utilities, workflows, ) @@ -110,8 +109,6 @@ async def shutdown_event() -> None: # Include all routers 
-app.include_router(sessions.session_router, prefix="/api") - app.include_router(utilities.utilities_router, prefix="/api") app.include_router(model_manager.model_manager_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api") diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 90d9068b88..5355fe2298 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union -from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage +from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( BatchStatus, EnqueueBatchResult, @@ -204,52 +204,6 @@ class EventServiceBase: }, ) - def emit_session_retrieval_error( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - error_type: str, - error: str, - ) -> None: - """Emitted when session retrieval fails""" - self.__emit_queue_event( - event_name="session_retrieval_error", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "error_type": error_type, - "error": error, - }, - ) - - def emit_invocation_retrieval_error( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node_id: str, - error_type: str, - error: str, - ) -> None: - """Emitted when invocation retrieval fails""" - self.__emit_queue_event( - event_name="invocation_retrieval_error", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node_id": node_id, - "error_type": error_type, - "error": error, - }, - ) - def emit_session_canceled( self, queue_id: str, diff --git a/invokeai/app/services/invocation_processor/__init__.py b/invokeai/app/services/invocation_processor/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/invokeai/app/services/invocation_processor/invocation_processor_base.py b/invokeai/app/services/invocation_processor/invocation_processor_base.py deleted file mode 100644 index 7947a201dd..0000000000 --- a/invokeai/app/services/invocation_processor/invocation_processor_base.py +++ /dev/null @@ -1,5 +0,0 @@ -from abc import ABC - - -class InvocationProcessorABC(ABC): # noqa: B024 - pass diff --git a/invokeai/app/services/invocation_processor/invocation_processor_common.py b/invokeai/app/services/invocation_processor/invocation_processor_common.py deleted file mode 100644 index 347f6c7323..0000000000 --- a/invokeai/app/services/invocation_processor/invocation_processor_common.py +++ /dev/null @@ -1,15 +0,0 @@ -from pydantic import BaseModel, Field - - -class ProgressImage(BaseModel): - """The progress image sent intermittently during processing""" - - width: int = Field(description="The effective width of the image in pixels") - height: int = Field(description="The effective height of the image in pixels") - dataURL: str = Field(description="The image data as a b64 data URL") - - -class CanceledException(Exception): - """Execution canceled by user.""" - - pass diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py 
b/invokeai/app/services/invocation_processor/invocation_processor_default.py deleted file mode 100644 index d2ebe235e6..0000000000 --- a/invokeai/app/services/invocation_processor/invocation_processor_default.py +++ /dev/null @@ -1,241 +0,0 @@ -import time -import traceback -from contextlib import suppress -from threading import BoundedSemaphore, Event, Thread -from typing import Optional - -import invokeai.backend.util.logging as logger -from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem -from invokeai.app.services.invocation_stats.invocation_stats_common import ( - GESStatsNotFoundError, -) -from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context -from invokeai.app.util.profiler import Profiler - -from ..invoker import Invoker -from .invocation_processor_base import InvocationProcessorABC -from .invocation_processor_common import CanceledException - - -class DefaultInvocationProcessor(InvocationProcessorABC): - __invoker_thread: Thread - __stop_event: Event - __invoker: Invoker - __threadLimit: BoundedSemaphore - - def start(self, invoker: Invoker) -> None: - # if we do want multithreading at some point, we could make this configurable - self.__threadLimit = BoundedSemaphore(1) - self.__invoker = invoker - self.__stop_event = Event() - self.__invoker_thread = Thread( - name="invoker_processor", - target=self.__process, - kwargs={"stop_event": self.__stop_event}, - ) - self.__invoker_thread.daemon = True # TODO: make async and do not use threads - self.__invoker_thread.start() - - def stop(self, *args, **kwargs) -> None: - self.__stop_event.set() - - def __process(self, stop_event: Event): - try: - self.__threadLimit.acquire() - queue_item: Optional[InvocationQueueItem] = None - - profiler = ( - Profiler( - logger=self.__invoker.services.logger, - output_dir=self.__invoker.services.configuration.profiles_path, - prefix=self.__invoker.services.configuration.profile_prefix, - ) - if self.__invoker.services.configuration.profile_graphs - else None - ) - - def stats_cleanup(graph_execution_state_id: str) -> None: - if profiler: - profile_path = profiler.stop() - stats_path = profile_path.with_suffix(".json") - self.__invoker.services.performance_statistics.dump_stats( - graph_execution_state_id=graph_execution_state_id, output_path=stats_path - ) - with suppress(GESStatsNotFoundError): - self.__invoker.services.performance_statistics.log_stats(graph_execution_state_id) - self.__invoker.services.performance_statistics.reset_stats(graph_execution_state_id) - - while not stop_event.is_set(): - try: - queue_item = self.__invoker.services.queue.get() - except Exception as e: - self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e) - - if not queue_item: # Probably stopping - # do not hammer the queue - time.sleep(0.5) - continue - - if profiler and profiler.profile_id != queue_item.graph_execution_state_id: - profiler.start(profile_id=queue_item.graph_execution_state_id) - - try: - graph_execution_state = self.__invoker.services.graph_execution_manager.get( - queue_item.graph_execution_state_id - ) - except Exception as e: - self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e) - self.__invoker.services.events.emit_session_retrieval_error( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=queue_item.graph_execution_state_id, - 
error_type=e.__class__.__name__, - error=traceback.format_exc(), - ) - continue - - try: - invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) - except Exception as e: - self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e) - self.__invoker.services.events.emit_invocation_retrieval_error( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=queue_item.graph_execution_state_id, - node_id=queue_item.invocation_id, - error_type=e.__class__.__name__, - error=traceback.format_exc(), - ) - continue - - # get the source node id to provide to clients (the prepared node id is not as useful) - source_node_id = graph_execution_state.prepared_source_mapping[invocation.id] - - # Send starting event - self.__invoker.services.events.emit_invocation_started( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=graph_execution_state.id, - node=invocation.model_dump(), - source_node_id=source_node_id, - ) - - # Invoke - try: - graph_id = graph_execution_state.id - with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id): - # use the internal invoke_internal(), which wraps the node's invoke() method, - # which handles a few things: - # - nodes that require a value, but get it only from a connection - # - referencing the invocation cache instead of executing the node - context_data = InvocationContextData( - invocation=invocation, - session_id=graph_id, - workflow=queue_item.workflow, - source_node_id=source_node_id, - queue_id=queue_item.session_queue_id, - queue_item_id=queue_item.session_queue_item_id, - batch_id=queue_item.session_queue_batch_id, - ) - context = build_invocation_context( - services=self.__invoker.services, - context_data=context_data, - ) - outputs = invocation.invoke_internal(context=context, services=self.__invoker.services) - - # Check queue to see if this is canceled, and skip if so - if self.__invoker.services.queue.is_canceled(graph_execution_state.id): - continue - - # Save outputs and history - graph_execution_state.complete(invocation.id, outputs) - - # Save the state changes - self.__invoker.services.graph_execution_manager.set(graph_execution_state) - - # Send complete event - self.__invoker.services.events.emit_invocation_complete( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=graph_execution_state.id, - node=invocation.model_dump(), - source_node_id=source_node_id, - result=outputs.model_dump(), - ) - - except KeyboardInterrupt: - pass - - except CanceledException: - stats_cleanup(graph_execution_state.id) - pass - - except Exception as e: - error = traceback.format_exc() - logger.error(error) - - # Save error - graph_execution_state.set_node_error(invocation.id, error) - - # Save the state changes - self.__invoker.services.graph_execution_manager.set(graph_execution_state) - - self.__invoker.services.logger.error("Error while invoking:\n%s" % e) - # Send error event - self.__invoker.services.events.emit_invocation_error( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=graph_execution_state.id, - 
node=invocation.model_dump(), - source_node_id=source_node_id, - error_type=e.__class__.__name__, - error=error, - ) - pass - - # Check queue to see if this is canceled, and skip if so - if self.__invoker.services.queue.is_canceled(graph_execution_state.id): - continue - - # Queue any further commands if invoking all - is_complete = graph_execution_state.is_complete() - if queue_item.invoke_all and not is_complete: - try: - self.__invoker.invoke( - session_queue_batch_id=queue_item.session_queue_batch_id, - session_queue_item_id=queue_item.session_queue_item_id, - session_queue_id=queue_item.session_queue_id, - graph_execution_state=graph_execution_state, - workflow=queue_item.workflow, - invoke_all=True, - ) - except Exception as e: - self.__invoker.services.logger.error("Error while invoking:\n%s" % e) - self.__invoker.services.events.emit_invocation_error( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=graph_execution_state.id, - node=invocation.model_dump(), - source_node_id=source_node_id, - error_type=e.__class__.__name__, - error=traceback.format_exc(), - ) - elif is_complete: - self.__invoker.services.events.emit_graph_execution_complete( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=graph_execution_state.id, - ) - stats_cleanup(graph_execution_state.id) - - except KeyboardInterrupt: - pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor - finally: - self.__threadLimit.release() diff --git a/invokeai/app/services/invocation_queue/__init__.py b/invokeai/app/services/invocation_queue/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/invokeai/app/services/invocation_queue/invocation_queue_base.py b/invokeai/app/services/invocation_queue/invocation_queue_base.py deleted file mode 100644 index 09f4875c5f..0000000000 --- a/invokeai/app/services/invocation_queue/invocation_queue_base.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from abc import ABC, abstractmethod -from typing import Optional - -from .invocation_queue_common import InvocationQueueItem - - -class InvocationQueueABC(ABC): - """Abstract base class for all invocation queues""" - - @abstractmethod - def get(self) -> InvocationQueueItem: - pass - - @abstractmethod - def put(self, item: Optional[InvocationQueueItem]) -> None: - pass - - @abstractmethod - def cancel(self, graph_execution_state_id: str) -> None: - pass - - @abstractmethod - def is_canceled(self, graph_execution_state_id: str) -> bool: - pass diff --git a/invokeai/app/services/invocation_queue/invocation_queue_common.py b/invokeai/app/services/invocation_queue/invocation_queue_common.py deleted file mode 100644 index 696f6a981d..0000000000 --- a/invokeai/app/services/invocation_queue/invocation_queue_common.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import time -from typing import Optional - -from pydantic import BaseModel, Field - -from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID - - -class InvocationQueueItem(BaseModel): - graph_execution_state_id: str = Field(description="The ID of the graph execution state") - invocation_id: str = Field(description="The ID of the node being invoked") - 
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came") - session_queue_item_id: int = Field( - description="The ID of session queue item from which this invocation queue item came" - ) - session_queue_batch_id: str = Field( - description="The ID of the session batch from which this invocation queue item came" - ) - workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item") - invoke_all: bool = Field(default=False) - timestamp: float = Field(default_factory=time.time) diff --git a/invokeai/app/services/invocation_queue/invocation_queue_memory.py b/invokeai/app/services/invocation_queue/invocation_queue_memory.py deleted file mode 100644 index 8d6fff7052..0000000000 --- a/invokeai/app/services/invocation_queue/invocation_queue_memory.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import time -from queue import Queue -from typing import Optional - -from .invocation_queue_base import InvocationQueueABC -from .invocation_queue_common import InvocationQueueItem - - -class MemoryInvocationQueue(InvocationQueueABC): - __queue: Queue - __cancellations: dict[str, float] - - def __init__(self): - self.__queue = Queue() - self.__cancellations = {} - - def get(self) -> InvocationQueueItem: - item = self.__queue.get() - - while ( - isinstance(item, InvocationQueueItem) - and item.graph_execution_state_id in self.__cancellations - and self.__cancellations[item.graph_execution_state_id] > item.timestamp - ): - item = self.__queue.get() - - # Clear old items - for graph_execution_state_id in list(self.__cancellations.keys()): - if self.__cancellations[graph_execution_state_id] < item.timestamp: - del self.__cancellations[graph_execution_state_id] - - return item - - def put(self, item: Optional[InvocationQueueItem]) -> None: - self.__queue.put(item) - - def cancel(self, graph_execution_state_id: str) -> None: - if graph_execution_state_id not in self.__cancellations: - self.__cancellations[graph_execution_state_id] = time.time() - - def is_canceled(self, graph_execution_state_id: str) -> bool: - return graph_execution_state_id in self.__cancellations diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 0a1fa1e922..04fe71a3eb 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -23,15 +23,11 @@ if TYPE_CHECKING: from .image_records.image_records_base import ImageRecordStorageBase from .images.images_base import ImageServiceABC from .invocation_cache.invocation_cache_base import InvocationCacheBase - from .invocation_processor.invocation_processor_base import InvocationProcessorABC - from .invocation_queue.invocation_queue_base import InvocationQueueABC from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase - from .item_storage.item_storage_base import ItemStorageABC from .model_manager.model_manager_base import ModelManagerServiceBase from .names.names_base import NameServiceBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase - from .shared.graph import GraphExecutionState from .urls.urls_base import UrlServiceBase from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase @@ -47,16 +43,13 @@ class InvocationServices: board_records: "BoardRecordStorageBase", configuration: "InvokeAIAppConfig", 
events: "EventServiceBase", - graph_execution_manager: "ItemStorageABC[GraphExecutionState]", images: "ImageServiceABC", image_files: "ImageFileStorageBase", image_records: "ImageRecordStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", download_queue: "DownloadQueueServiceBase", - processor: "InvocationProcessorABC", performance_statistics: "InvocationStatsServiceBase", - queue: "InvocationQueueABC", session_queue: "SessionQueueBase", session_processor: "SessionProcessorBase", invocation_cache: "InvocationCacheBase", @@ -72,16 +65,13 @@ class InvocationServices: self.board_records = board_records self.configuration = configuration self.events = events - self.graph_execution_manager = graph_execution_manager self.images = images self.image_files = image_files self.image_records = image_records self.logger = logger self.model_manager = model_manager self.download_queue = download_queue - self.processor = processor self.performance_statistics = performance_statistics - self.queue = queue self.session_queue = session_queue self.session_processor = session_processor self.invocation_cache = invocation_cache diff --git a/invokeai/app/services/invocation_stats/invocation_stats_base.py b/invokeai/app/services/invocation_stats/invocation_stats_base.py index ec8a453323..b28220e74c 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_base.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_base.py @@ -3,7 +3,7 @@ Usage: -statistics = InvocationStatsService(graph_execution_manager) +statistics = InvocationStatsService() with statistics.collect_stats(invocation, graph_execution_state.id): ... execute graphs... statistics.log_stats() @@ -60,12 +60,8 @@ class InvocationStatsServiceBase(ABC): pass @abstractmethod - def reset_stats(self, graph_execution_state_id: str) -> None: - """ - Reset all statistics for the indicated graph. - :param graph_execution_state_id: The id of the session whose stats to reset. - :raises GESStatsNotFoundError: if the graph isn't tracked in the stats. - """ + def reset_stats(self): + """Reset all stored statistics.""" pass @abstractmethod diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 486a1ca5b3..06a5b675c3 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -10,7 +10,6 @@ import torch import invokeai.backend.util.logging as logger from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invoker import Invoker -from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError from invokeai.backend.model_manager.load.model_cache import CacheStats from .invocation_stats_base import InvocationStatsServiceBase @@ -51,9 +50,6 @@ class InvocationStatsService(InvocationStatsServiceBase): self._stats[graph_execution_state_id] = GraphExecutionStats() self._cache_stats[graph_execution_state_id] = CacheStats() - # Prune stale stats. There should be none since we're starting a new graph, but just in case. - self._prune_stale_stats() - # Record state before the invocation. 
start_time = time.time() start_ram = psutil.Process().memory_info().rss @@ -78,42 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase): ) self._stats[graph_execution_state_id].add_node_execution_stats(node_stats) - def _prune_stale_stats(self) -> None: - """Check all graphs being tracked and prune any that have completed/errored. - - This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so - for now we call this function periodically to prevent them from accumulating. - """ - to_prune: list[str] = [] - for graph_execution_state_id in self._stats: - try: - graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id) - except ItemNotFoundError: - # TODO(ryand): What would cause this? Should this exception just be allowed to propagate? - logger.warning(f"Failed to get graph state for {graph_execution_state_id}.") - continue - - if not graph_execution_state.is_complete(): - # The graph is still running, don't prune it. - continue - - to_prune.append(graph_execution_state_id) - - for graph_execution_state_id in to_prune: - del self._stats[graph_execution_state_id] - del self._cache_stats[graph_execution_state_id] - - if len(to_prune) > 0: - logger.info(f"Pruned stale graph stats for {to_prune}.") - - def reset_stats(self, graph_execution_state_id: str): - try: - del self._stats[graph_execution_state_id] - del self._cache_stats[graph_execution_state_id] - except KeyError as e: - raise GESStatsNotFoundError( - f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}." - ) from e + def reset_stats(self): + self._stats = {} + self._cache_stats = {} def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary: graph_stats_summary = self._get_graph_summary(graph_execution_state_id) diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py index a04c6f2059..527afb37f4 100644 --- a/invokeai/app/services/invoker.py +++ b/invokeai/app/services/invoker.py @@ -1,12 +1,7 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Optional -from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID - -from .invocation_queue.invocation_queue_common import InvocationQueueItem from .invocation_services import InvocationServices -from .shared.graph import Graph, GraphExecutionState class Invoker: @@ -18,51 +13,6 @@ class Invoker: self.services = services self._start() - def invoke( - self, - session_queue_id: str, - session_queue_item_id: int, - session_queue_batch_id: str, - graph_execution_state: GraphExecutionState, - workflow: Optional[WorkflowWithoutID] = None, - invoke_all: bool = False, - ) -> Optional[str]: - """Determines the next node to invoke and enqueues it, preparing if needed. 
- Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" - - # Get the next invocation - invocation = graph_execution_state.next() - if not invocation: - return None - - # Save the execution state - self.services.graph_execution_manager.set(graph_execution_state) - - # Queue the invocation - self.services.queue.put( - InvocationQueueItem( - session_queue_id=session_queue_id, - session_queue_item_id=session_queue_item_id, - session_queue_batch_id=session_queue_batch_id, - graph_execution_state_id=graph_execution_state.id, - invocation_id=invocation.id, - workflow=workflow, - invoke_all=invoke_all, - ) - ) - - return invocation.id - - def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState: - """Creates a new execution state for the given graph""" - new_state = GraphExecutionState(graph=Graph() if graph is None else graph) - self.services.graph_execution_manager.set(new_state) - return new_state - - def cancel(self, graph_execution_state_id: str) -> None: - """Cancels the given execution state""" - self.services.queue.cancel(graph_execution_state_id) - def __start_service(self, service) -> None: # Call start() method on any services that have it start_op = getattr(service, "start", None) @@ -85,5 +35,3 @@ class Invoker: # First stop all services for service in vars(self.services): self.__stop_service(getattr(self.services, service)) - - self.services.queue.put(None) diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 15c6283d8a..24ab10b427 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -4,7 +4,6 @@ from typing import Optional, Type from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType @@ -95,8 +94,6 @@ class ModelLoadService(ModelLoadServiceBase): ) -> None: if not self._invoker: return - if self._invoker.services.queue.is_canceled(context_data.session_id): - raise CanceledException() if not loaded: self._invoker.services.events.emit_model_load_started( diff --git a/invokeai/app/services/session_processor/session_processor_common.py b/invokeai/app/services/session_processor/session_processor_common.py index 00195a773f..0ca51de517 100644 --- a/invokeai/app/services/session_processor/session_processor_common.py +++ b/invokeai/app/services/session_processor/session_processor_common.py @@ -4,3 +4,17 @@ from pydantic import BaseModel, Field class SessionProcessorStatus(BaseModel): is_started: bool = Field(description="Whether the session processor is started") is_processing: bool = Field(description="Whether a session is being processed") + + +class CanceledException(Exception): + """Execution canceled by user.""" + + pass + + +class ProgressImage(BaseModel): + """The progress image sent intermittently during processing""" + + width: int = Field(description="The effective width of the image in pixels") + height: int = Field(description="The effective height of the image in pixels") + dataURL: str = Field(description="The image data as a b64 data URL") diff --git a/invokeai/app/services/session_processor/session_processor_default.py 
b/invokeai/app/services/session_processor/session_processor_default.py index 32e94a305d..dd34c78252 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -1,4 +1,5 @@ import traceback +from contextlib import suppress from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent from typing import Optional @@ -7,7 +8,11 @@ from fastapi_events.handlers.local import local_handler from fastapi_events.typing import Event as FastAPIEvent from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError +from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context +from invokeai.app.util.profiler import Profiler from ..invoker import Invoker from .session_processor_base import SessionProcessorBase @@ -19,123 +24,237 @@ THREAD_LIMIT = 1 class DefaultSessionProcessor(SessionProcessorBase): def start(self, invoker: Invoker) -> None: - self.__invoker: Invoker = invoker - self.__queue_item: Optional[SessionQueueItem] = None + self._invoker: Invoker = invoker + self._queue_item: Optional[SessionQueueItem] = None - self.__resume_event = ThreadEvent() - self.__stop_event = ThreadEvent() - self.__poll_now_event = ThreadEvent() + self._resume_event = ThreadEvent() + self._stop_event = ThreadEvent() + self._poll_now_event = ThreadEvent() + self._cancel_event = ThreadEvent() local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) - self.__threadLimit = BoundedSemaphore(THREAD_LIMIT) - self.__thread = Thread( + self._thread_limit = BoundedSemaphore(THREAD_LIMIT) + self._thread = Thread( name="session_processor", - target=self.__process, + target=self._process, kwargs={ - "stop_event": self.__stop_event, - "poll_now_event": self.__poll_now_event, - "resume_event": self.__resume_event, + "stop_event": self._stop_event, + "poll_now_event": self._poll_now_event, + "resume_event": self._resume_event, + "cancel_event": self._cancel_event, }, ) - self.__thread.start() + self._thread.start() def stop(self, *args, **kwargs) -> None: - self.__stop_event.set() + self._stop_event.set() def _poll_now(self) -> None: - self.__poll_now_event.set() + self._poll_now_event.set() async def _on_queue_event(self, event: FastAPIEvent) -> None: event_name = event[1]["event"] - # This was a match statement, but match is not supported on python 3.9 - if event_name in [ - "graph_execution_state_complete", - "invocation_error", - "session_retrieval_error", - "invocation_retrieval_error", - ]: - self.__queue_item = None - self._poll_now() - elif ( - event_name == "session_canceled" - and self.__queue_item is not None - and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"] - ): - self.__queue_item = None + if event_name == "session_canceled" or event_name == "queue_cleared": + # These both mean we should cancel the current session. 
+ self._cancel_event.set() self._poll_now() elif event_name == "batch_enqueued": self._poll_now() - elif event_name == "queue_cleared": - self.__queue_item = None - self._poll_now() def resume(self) -> SessionProcessorStatus: - if not self.__resume_event.is_set(): - self.__resume_event.set() + if not self._resume_event.is_set(): + self._resume_event.set() return self.get_status() def pause(self) -> SessionProcessorStatus: - if self.__resume_event.is_set(): - self.__resume_event.clear() + if self._resume_event.is_set(): + self._resume_event.clear() return self.get_status() def get_status(self) -> SessionProcessorStatus: return SessionProcessorStatus( - is_started=self.__resume_event.is_set(), - is_processing=self.__queue_item is not None, + is_started=self._resume_event.is_set(), + is_processing=self._queue_item is not None, ) - def __process( + def _process( self, stop_event: ThreadEvent, poll_now_event: ThreadEvent, resume_event: ThreadEvent, + cancel_event: ThreadEvent, ): + # Outermost processor try block; any unhandled exception is a fatal processor error try: + self._thread_limit.acquire() stop_event.clear() resume_event.set() - self.__threadLimit.acquire() - queue_item: Optional[SessionQueueItem] = None + cancel_event.clear() + + # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, + # the profiler will create a new profile for each session. + profiler = ( + Profiler( + logger=self._invoker.services.logger, + output_dir=self._invoker.services.configuration.profiles_path, + prefix=self._invoker.services.configuration.profile_prefix, + ) + if self._invoker.services.configuration.profile_graphs + else None + ) + + # Helper function to stop the profiler and save the stats + def stats_cleanup(graph_execution_state_id: str) -> None: + if profiler: + profile_path = profiler.stop() + stats_path = profile_path.with_suffix(".json") + self._invoker.services.performance_statistics.dump_stats( + graph_execution_state_id=graph_execution_state_id, output_path=stats_path + ) + # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor + # we don't care about that - suppress the error. 
+ with suppress(GESStatsNotFoundError): + self._invoker.services.performance_statistics.log_stats(graph_execution_state_id) + self._invoker.services.performance_statistics.reset_stats() + while not stop_event.is_set(): poll_now_event.clear() + # Middle processor try block; any unhandled exception is a non-fatal processor error try: - # do not dequeue if there is already a session running - if self.__queue_item is None and resume_event.is_set(): - queue_item = self.__invoker.services.session_queue.dequeue() + # Get the next session to process + self._queue_item = self._invoker.services.session_queue.dequeue() + if self._queue_item is not None and resume_event.is_set(): + self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") + cancel_event.clear() - if queue_item is not None: - self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}") - self.__queue_item = queue_item - self.__invoker.services.graph_execution_manager.set(queue_item.session) - self.__invoker.invoke( - session_queue_batch_id=queue_item.batch_id, - session_queue_id=queue_item.queue_id, - session_queue_item_id=queue_item.item_id, - graph_execution_state=queue_item.session, - workflow=queue_item.workflow, - invoke_all=True, + # If profiling is enabled, start the profiler + if profiler is not None: + profiler.start(profile_id=self._queue_item.session_id) + + # Prepare invocations and take the first + invocation = self._queue_item.session.next() + + # Loop over invocations until the session is complete or canceled + while invocation is not None and not cancel_event.is_set(): + # get the source node id to provide to clients (the prepared node id is not as useful) + source_node_id = self._queue_item.session.prepared_source_mapping[invocation.id] + + # Send starting event + self._invoker.services.events.emit_invocation_started( + queue_batch_id=self._queue_item.batch_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session_id, + node=invocation.model_dump(), + source_node_id=source_node_id, ) - queue_item = None - if queue_item is None: - self.__invoker.services.logger.debug("Waiting for next polling interval or event") + # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph + try: + with self._invoker.services.performance_statistics.collect_stats( + invocation, self._queue_item.session.id + ): + # Build invocation context (the node-facing API) + context_data = InvocationContextData( + invocation=invocation, + source_node_id=source_node_id, + session_id=self._queue_item.session.id, + workflow=self._queue_item.workflow, + queue_id=self._queue_item.queue_id, + queue_item_id=self._queue_item.item_id, + batch_id=self._queue_item.batch_id, + ) + context = build_invocation_context( + context_data=context_data, + services=self._invoker.services, + cancel_event=self._cancel_event, + ) + + # Invoke the node + outputs = invocation.invoke_internal( + context=context, services=self._invoker.services + ) + + # Save outputs and history + self._queue_item.session.complete(invocation.id, outputs) + + # Send complete event + self._invoker.services.events.emit_invocation_complete( + queue_batch_id=self._queue_item.batch_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session.id, + node=invocation.model_dump(), + source_node_id=source_node_id, + result=outputs.model_dump(), + ) + + except 
KeyboardInterrupt: + pass + + except CanceledException: + pass + + except Exception as e: + error = traceback.format_exc() + + # Save error + self._queue_item.session.set_node_error(invocation.id, error) + self._invoker.services.logger.error("Error while invoking:\n%s" % e) + + # Send error event + self._invoker.services.events.emit_invocation_error( + queue_batch_id=self._queue_item.batch_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session.id, + node=invocation.model_dump(), + source_node_id=source_node_id, + error_type=e.__class__.__name__, + error=error, + ) + pass + + if self._queue_item.session.is_complete() or cancel_event.is_set(): + # Send complete event + self._invoker.services.events.emit_graph_execution_complete( + queue_batch_id=self._queue_item.batch_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session.id, + ) + # Save the stats and stop the profiler if it's running + stats_cleanup(self._queue_item.session.id) + invocation = None + else: + # Prepare the next invocation + invocation = self._queue_item.session.next() + + # The session is complete; immediately poll for the next session + self._queue_item = None + poll_now_event.set() + else: + # The queue was empty; wait for the next polling interval or event to try again + self._invoker.services.logger.debug("Waiting for next polling interval or event") poll_now_event.wait(POLLING_INTERVAL) continue except Exception as e: - self.__invoker.services.logger.error(f"Error in session processor: {e}") - if queue_item is not None: - self.__invoker.services.session_queue.cancel_queue_item( - queue_item.item_id, error=traceback.format_exc() + # Non-fatal error in processor; cancel the queue item and wait for the next polling interval or event + self._invoker.services.logger.error(f"Error in session processor: {e}") + if self._queue_item is not None: + self._invoker.services.session_queue.cancel_queue_item( + self._queue_item.item_id, error=traceback.format_exc() ) poll_now_event.wait(POLLING_INTERVAL) continue except Exception as e: - self.__invoker.services.logger.error(f"Fatal Error in session processor: {e}") + # Fatal error in processor; log and pass - we're done here + self._invoker.services.logger.error(f"Fatal Error in session processor: {e}") pass finally: stop_event.clear() poll_now_event.clear() - self.__queue_item = None - self.__threadLimit.release() + self._queue_item = None + self._thread_limit.release() diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 64642690e9..7af9f0e08c 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -60,7 +60,7 @@ class SqliteSessionQueue(SessionQueueBase): # This was a match statement, but match is not supported on python 3.9 if event_name == "graph_execution_state_complete": await self._handle_complete_event(event) - elif event_name in ["invocation_error", "session_retrieval_error", "invocation_retrieval_error"]: + elif event_name == "invocation_error": await self._handle_error_event(event) elif event_name == "session_canceled": await self._handle_cancel_event(event) @@ -429,7 +429,6 @@ class SqliteSessionQueue(SessionQueueBase): if queue_item.status not in ["canceled", "failed", "completed"]: status = "failed" if error is not None else "canceled" queue_item = 
self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here - self.__invoker.services.queue.cancel(queue_item.session_id) self.__invoker.services.events.emit_session_canceled( queue_item_id=queue_item.item_id, queue_id=queue_item.queue_id, @@ -471,7 +470,6 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - self.__invoker.services.queue.cancel(current_queue_item.session_id) self.__invoker.services.events.emit_session_canceled( queue_item_id=current_queue_item.item_id, queue_id=current_queue_item.queue_id, @@ -523,7 +521,6 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.queue_id == queue_id: - self.__invoker.services.queue.cancel(current_queue_item.session_id) self.__invoker.services.events.emit_session_canceled( queue_item_id=current_queue_item.item_id, queue_id=current_queue_item.queue_id, diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 43ecb2c543..4606bd9e03 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,6 +1,7 @@ +import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional from PIL.Image import Image from torch import Tensor @@ -370,6 +371,12 @@ class ConfigInterface(InvocationContextInterface): class UtilInterface(InvocationContextInterface): + def __init__( + self, services: InvocationServices, context_data: InvocationContextData, is_canceled: Callable[[], bool] + ) -> None: + super().__init__(services, context_data) + self._is_canceled = is_canceled + def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: """ The step callback emits a progress event with the current step, the total number of @@ -390,8 +397,8 @@ class UtilInterface(InvocationContextInterface): context_data=self._context_data, intermediate_state=intermediate_state, base_model=base_model, - invocation_queue=self._services.queue, events=self._services.events, + is_canceled=self._is_canceled, ) @@ -412,6 +419,7 @@ class InvocationContext: boards: BoardsInterface, context_data: InvocationContextData, services: InvocationServices, + is_canceled: Callable[[], bool], ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" @@ -433,11 +441,13 @@ class InvocationContext: """Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" self._services = services """Provides access to the full application services. This is an internal API and may change without warning.""" + self._is_canceled = is_canceled def build_invocation_context( services: InvocationServices, context_data: InvocationContextData, + cancel_event: threading.Event, ) -> InvocationContext: """ Builds the invocation context for a specific invocation execution. @@ -446,12 +456,15 @@ def build_invocation_context( :param invocation_context_data: The invocation context data. 
""" + def is_canceled() -> bool: + return cancel_event.is_set() + logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) tensors = TensorsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) config = ConfigInterface(services=services, context_data=context_data) - util = UtilInterface(services=services, context_data=context_data) + util = UtilInterface(services=services, context_data=context_data, is_canceled=is_canceled) conditioning = ConditioningInterface(services=services, context_data=context_data) boards = BoardsInterface(services=services, context_data=context_data) @@ -466,6 +479,7 @@ def build_invocation_context( conditioning=conditioning, services=services, boards=boards, + is_canceled=is_canceled, ) return ctx diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 33d00ca366..9c9f5254a4 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,9 +1,9 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import torch from PIL import Image -from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage +from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage from invokeai.backend.model_manager.config import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState @@ -11,7 +11,6 @@ from ...backend.util.util import image_to_dataURL if TYPE_CHECKING: from invokeai.app.services.events.events_base import EventServiceBase - from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC from invokeai.app.services.shared.invocation_context import InvocationContextData @@ -34,10 +33,10 @@ def stable_diffusion_step_callback( context_data: "InvocationContextData", intermediate_state: PipelineIntermediateState, base_model: BaseModelType, - invocation_queue: "InvocationQueueABC", events: "EventServiceBase", + is_canceled: Callable[[], bool], ) -> None: - if invocation_queue.is_canceled(context_data.session_id): + if is_canceled(): raise CanceledException # Some schedulers report not only the noisy latents at the current timestep,