From cef1585dfbbd12cd94c0f3c1a0a6efd019911362 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 19:05:49 +1000 Subject: [PATCH] feat(app): support multiple processor lifecycle callbacks --- invokeai/app/api/dependencies.py | 12 ++--- .../session_processor_default.py | 54 ++++++++++--------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 87df06d569..d9cefb0acf 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -130,13 +130,13 @@ class ApiDependencies: session_processor = DefaultSessionProcessor( DefaultSessionRunner( - on_before_run_session=on_before_run_session, - on_before_run_node=on_before_run_node, - on_after_run_node=on_after_run_node, - on_node_error=on_node_error, - on_after_run_session=on_after_run_session, + on_before_run_session_callbacks=[on_before_run_session], + on_before_run_node_callbacks=[on_before_run_node], + on_after_run_node_callbacks=[on_after_run_node], + on_node_error_callbacks=[on_node_error], + on_after_run_session_callbacks=[on_after_run_session], ), - on_non_fatal_processor_error, + on_non_fatal_processor_error_callbacks=[on_non_fatal_processor_error], ) session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 6b4b84e099..74e7dd3deb 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -38,17 +38,17 @@ class DefaultSessionRunner(SessionRunnerBase): def __init__( self, - on_before_run_session: Optional[OnBeforeRunSession] = None, - on_before_run_node: Optional[OnBeforeRunNode] = None, - on_after_run_node: Optional[OnAfterRunNode] = None, - on_node_error: Optional[OnNodeError] = None, - on_after_run_session: Optional[OnAfterRunSession] = None, + on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None, + on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None, + on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None, + on_node_error_callbacks: Optional[list[OnNodeError]] = None, + on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None, ): - self.on_before_run_session = on_before_run_session - self.on_before_run_node = on_before_run_node - self.on_after_run_node = on_after_run_node - self.on_node_error = on_node_error - self.on_after_run_session = on_after_run_session + self._on_before_run_session_callbacks = on_before_run_session_callbacks or [] + self._on_before_run_node_callbacks = on_before_run_node_callbacks or [] + self._on_after_run_node_callbacks = on_after_run_node_callbacks or [] + self._on_node_error_callbacks = on_node_error_callbacks or [] + self._on_after_run_session_callbacks = on_after_run_session_callbacks or [] def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): """Start the session runner""" @@ -77,8 +77,8 @@ class DefaultSessionRunner(SessionRunnerBase): if self._profiler is not None: self._profiler.start(profile_id=queue_item.session_id) - if self.on_before_run_session: - self.on_before_run_session(queue_item=queue_item) + for callback in self._on_before_run_session_callbacks: + callback(queue_item=queue_item) def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: # If we are profiling, stop the profiler and dump the profile & stats @@ -103,8 +103,8 @@ class DefaultSessionRunner(SessionRunnerBase): self._services.performance_statistics.log_stats(queue_item.session.id) self._services.performance_statistics.reset_stats() - if self.on_after_run_session: - self.on_after_run_session(queue_item) + for callback in self._on_after_run_session_callbacks: + callback(queue_item) def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run before a node is executed""" @@ -117,9 +117,9 @@ class DefaultSessionRunner(SessionRunnerBase): node=invocation.model_dump(), source_node_id=queue_item.session.prepared_source_mapping[invocation.id], ) - # And run lifecycle callbacks - if self.on_before_run_node is not None: - self.on_before_run_node(invocation, queue_item) + + for callback in self._on_before_run_node_callbacks: + callback(invocation, queue_item) def _on_after_run_node( self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput @@ -135,9 +135,9 @@ class DefaultSessionRunner(SessionRunnerBase): source_node_id=queue_item.session.prepared_source_mapping[invocation.id], result=outputs.model_dump(), ) - # And run lifecycle callbacks - if self.on_after_run_node is not None: - self.on_after_run_node(invocation, queue_item, outputs) + + for callback in self._on_after_run_node_callbacks: + callback(invocation, queue_item, outputs) def _on_node_error( self, @@ -169,8 +169,8 @@ class DefaultSessionRunner(SessionRunnerBase): project_id=None, ) - if self.on_node_error is not None: - self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback) + for callback in self._on_node_error_callbacks: + callback(invocation, queue_item, exc_type, exc_value, exc_traceback) def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run a single node in the graph""" @@ -213,6 +213,7 @@ class DefaultSessionRunner(SessionRunnerBase): # loop go to its next iteration, and the cancel event will be handled correctly. pass except Exception as e: + # Must extract the exception traceback here to not lose its stacktrace when we change scope exc_type = type(e) exc_value = e exc_traceback = e.__traceback__ @@ -224,14 +225,14 @@ class DefaultSessionProcessor(SessionProcessorBase): def __init__( self, session_runner: Optional[SessionRunnerBase] = None, - on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None, + on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None, thread_limit: int = 1, polling_interval: int = 1, ) -> None: super().__init__() self.session_runner = session_runner if session_runner else DefaultSessionRunner() - self.on_non_fatal_processor_error = on_non_fatal_processor_error + self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or [] self._thread_limit = thread_limit self._polling_interval = polling_interval @@ -250,8 +251,8 @@ class DefaultSessionProcessor(SessionProcessorBase): self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) - if self.on_non_fatal_processor_error: - self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item) + for callback in self._on_non_fatal_processor_error_callbacks: + callback(exc_type, exc_value, exc_traceback, queue_item) def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker @@ -377,6 +378,7 @@ class DefaultSessionProcessor(SessionProcessorBase): self.session_runner.run(queue_item=self._queue_item) except Exception as e: + # Must extract the exception traceback here to not lose its stacktrace when we change scope exc_type = type(e) exc_value = e exc_traceback = e.__traceback__