From 47722528a31319e53625e498b7c53eab7a10c27d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 18:52:46 +1000 Subject: [PATCH] feat(app): iterate on processor split 2 - Use protocol to define callbacks, this allows them to have kwargs - Shuffle the profiler around a bit - Move `thread_limit` and `polling_interval` to `__init__`; `start` is called programmatically and will never get these args in practice --- invokeai/app/api/dependencies.py | 14 ++--- .../session_processor_base.py | 46 ++++++++++++++- .../session_processor_default.py | 57 ++++++++++--------- 3 files changed, 80 insertions(+), 37 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 1ff45be866..87df06d569 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -112,7 +112,7 @@ class ApiDependencies: print("BEFORE RUN NODE", invocation.id) return True - def on_after_run_node(invocation, queue_item, outputs): + def on_after_run_node(invocation, queue_item, output): print("AFTER RUN NODE", invocation.id) return True @@ -124,17 +124,17 @@ class ApiDependencies: print("AFTER RUN SESSION", queue_item.item_id) return True - def on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback): + def on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item=None): print("NON FATAL PROCESSOR ERROR", exc_value) return True session_processor = DefaultSessionProcessor( DefaultSessionRunner( - on_before_run_session, - on_before_run_node, - on_after_run_node, - on_node_error, - on_after_run_session, + on_before_run_session=on_before_run_session, + on_before_run_node=on_before_run_node, + on_after_run_node=on_after_run_node, + on_node_error=on_node_error, + on_after_run_session=on_after_run_session, ), on_non_fatal_processor_error, ) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index 7140847518..bfae74e5fe 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod from threading import Event +from types import TracebackType +from typing import Optional, Protocol -from invokeai.app.invocations.baseinvocation import BaseInvocation +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.util.profiler import Profiler class SessionRunnerBase(ABC): @@ -13,7 +16,7 @@ class SessionRunnerBase(ABC): """ @abstractmethod - def start(self, services: InvocationServices, cancel_event: Event) -> None: + def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None: """Starts the session runner""" pass @@ -51,3 +54,42 @@ class SessionProcessorBase(ABC): def get_status(self) -> SessionProcessorStatus: """Gets the status of the session processor""" pass + + +class OnBeforeRunNode(Protocol): + def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ... + + +class OnAfterRunNode(Protocol): + def __call__( + self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput + ) -> bool: ... + + +class OnNodeError(Protocol): + def __call__( + self, + invocation: BaseInvocation, + queue_item: SessionQueueItem, + exc_type: type, + exc_value: BaseException, + exc_traceback: TracebackType, + ) -> bool: ... + + +class OnBeforeRunSession(Protocol): + def __call__(self, queue_item: SessionQueueItem) -> bool: ... + + +class OnAfterRunSession(Protocol): + def __call__(self, queue_item: SessionQueueItem) -> bool: ... + + +class OnNonFatalProcessorError(Protocol): + def __call__( + self, + exc_type: type, + exc_value: BaseException, + exc_traceback: TracebackType, + queue_item: Optional[SessionQueueItem] = None, + ) -> bool: ... diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 33274dd97b..4172e45d17 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -3,7 +3,7 @@ from contextlib import suppress from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent from types import TracebackType -from typing import Callable, Optional, TypeAlias +from typing import Optional from fastapi_events.handlers.local import local_handler from fastapi_events.typing import Event as FastAPIEvent @@ -11,6 +11,14 @@ from fastapi_events.typing import Event as FastAPIEvent from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError +from invokeai.app.services.session_processor.session_processor_base import ( + OnAfterRunNode, + OnAfterRunSession, + OnBeforeRunNode, + OnBeforeRunSession, + OnNodeError, + OnNonFatalProcessorError, +) from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context @@ -20,13 +28,6 @@ from ..invoker import Invoker from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase from .session_processor_common import SessionProcessorStatus -OnBeforeRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem], bool] -OnAfterRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, BaseInvocationOutput], bool] -OnNodeError: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, type, BaseException, TracebackType], bool] -OnBeforeRunSession: TypeAlias = Callable[[SessionQueueItem], bool] -OnAfterRunSession: TypeAlias = Callable[[SessionQueueItem], bool] -OnNonFatalProcessorError: TypeAlias = Callable[[Optional[SessionQueueItem], type, BaseException, TracebackType], bool] - def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str: return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) @@ -49,16 +50,18 @@ class DefaultSessionRunner(SessionRunnerBase): self.on_node_error = on_node_error self.on_after_run_session = on_after_run_session - def start(self, services: InvocationServices, cancel_event: ThreadEvent): + def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): """Start the session runner""" self.services = services self.cancel_event = cancel_event + self.profiler = profiler - def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None): + def run(self, queue_item: SessionQueueItem): """Run the graph""" # Loop over invocations until the session is complete or canceled self._on_before_run_session(queue_item=queue_item) + while True: invocation = queue_item.session.next() if invocation is None or self.cancel_event.is_set(): @@ -66,20 +69,21 @@ class DefaultSessionRunner(SessionRunnerBase): self.run_node(invocation, queue_item) if queue_item.session.is_complete() or self.cancel_event.is_set(): break + self._on_after_run_session(queue_item=queue_item) - def _on_before_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None: + def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: # If profiling is enabled, start the profiler - if profiler is not None: - profiler.start(profile_id=queue_item.session_id) + if self.profiler is not None: + self.profiler.start(profile_id=queue_item.session_id) if self.on_before_run_session: - self.on_before_run_session(queue_item) + self.on_before_run_session(queue_item=queue_item) - def _on_after_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None: + def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: # If we are profiling, stop the profiler and dump the profile & stats - if profiler: - profile_path = profiler.stop() + if self.profiler is not None: + profile_path = self.profiler.stop() stats_path = profile_path.with_suffix(".json") self.services.performance_statistics.dump_stats( graph_execution_state_id=queue_item.session.id, output_path=stats_path @@ -221,11 +225,15 @@ class DefaultSessionProcessor(SessionProcessorBase): self, session_runner: Optional[SessionRunnerBase] = None, on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None, + thread_limit: int = 1, + polling_interval: int = 1, ) -> None: super().__init__() self.session_runner = session_runner if session_runner else DefaultSessionRunner() self.on_non_fatal_processor_error = on_non_fatal_processor_error + self._thread_limit = thread_limit + self._polling_interval = polling_interval def _on_non_fatal_processor_error( self, @@ -243,14 +251,9 @@ class DefaultSessionProcessor(SessionProcessorBase): self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) if self.on_non_fatal_processor_error: - self.on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback) + self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item) - def start( - self, - invoker: Invoker, - thread_limit: int = 1, - polling_interval: int = 1, - ) -> None: + def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None self._invocation: Optional[BaseInvocation] = None @@ -262,9 +265,7 @@ class DefaultSessionProcessor(SessionProcessorBase): local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) - self._thread_limit = thread_limit - self._thread_semaphore = BoundedSemaphore(thread_limit) - self._polling_interval = polling_interval + self._thread_semaphore = BoundedSemaphore(self._thread_limit) # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, # the profiler will create a new profile for each session. @@ -278,7 +279,7 @@ class DefaultSessionProcessor(SessionProcessorBase): else None ) - self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event) + self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler) self._thread = Thread( name="session_processor", target=self._process,