feat(app): support multiple processor lifecycle callbacks

This commit is contained in:
psychedelicious 2024-05-22 19:05:49 +10:00
parent cb8e9e1c7b
commit cef1585dfb
2 changed files with 34 additions and 32 deletions

View File

@ -130,13 +130,13 @@ class ApiDependencies:
session_processor = DefaultSessionProcessor( session_processor = DefaultSessionProcessor(
DefaultSessionRunner( DefaultSessionRunner(
on_before_run_session=on_before_run_session, on_before_run_session_callbacks=[on_before_run_session],
on_before_run_node=on_before_run_node, on_before_run_node_callbacks=[on_before_run_node],
on_after_run_node=on_after_run_node, on_after_run_node_callbacks=[on_after_run_node],
on_node_error=on_node_error, on_node_error_callbacks=[on_node_error],
on_after_run_session=on_after_run_session, on_after_run_session_callbacks=[on_after_run_session],
), ),
on_non_fatal_processor_error, on_non_fatal_processor_error_callbacks=[on_non_fatal_processor_error],
) )
session_queue = SqliteSessionQueue(db=db) session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService() urls = LocalUrlService()

View File

@ -38,17 +38,17 @@ class DefaultSessionRunner(SessionRunnerBase):
def __init__( def __init__(
self, self,
on_before_run_session: Optional[OnBeforeRunSession] = None, on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None,
on_before_run_node: Optional[OnBeforeRunNode] = None, on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None,
on_after_run_node: Optional[OnAfterRunNode] = None, on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
on_node_error: Optional[OnNodeError] = None, on_node_error_callbacks: Optional[list[OnNodeError]] = None,
on_after_run_session: Optional[OnAfterRunSession] = None, on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
): ):
self.on_before_run_session = on_before_run_session self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
self.on_before_run_node = on_before_run_node self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
self.on_after_run_node = on_after_run_node self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self.on_node_error = on_node_error self._on_node_error_callbacks = on_node_error_callbacks or []
self.on_after_run_session = on_after_run_session self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
"""Start the session runner""" """Start the session runner"""
@ -77,8 +77,8 @@ class DefaultSessionRunner(SessionRunnerBase):
if self._profiler is not None: if self._profiler is not None:
self._profiler.start(profile_id=queue_item.session_id) self._profiler.start(profile_id=queue_item.session_id)
if self.on_before_run_session: for callback in self._on_before_run_session_callbacks:
self.on_before_run_session(queue_item=queue_item) callback(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
# If we are profiling, stop the profiler and dump the profile & stats # If we are profiling, stop the profiler and dump the profile & stats
@ -103,8 +103,8 @@ class DefaultSessionRunner(SessionRunnerBase):
self._services.performance_statistics.log_stats(queue_item.session.id) self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats() self._services.performance_statistics.reset_stats()
if self.on_after_run_session: for callback in self._on_after_run_session_callbacks:
self.on_after_run_session(queue_item) callback(queue_item)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed""" """Run before a node is executed"""
@ -117,9 +117,9 @@ class DefaultSessionRunner(SessionRunnerBase):
node=invocation.model_dump(), node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id], source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
) )
# And run lifecycle callbacks
if self.on_before_run_node is not None: for callback in self._on_before_run_node_callbacks:
self.on_before_run_node(invocation, queue_item) callback(invocation, queue_item)
def _on_after_run_node( def _on_after_run_node(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput
@ -135,9 +135,9 @@ class DefaultSessionRunner(SessionRunnerBase):
source_node_id=queue_item.session.prepared_source_mapping[invocation.id], source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
result=outputs.model_dump(), result=outputs.model_dump(),
) )
# And run lifecycle callbacks
if self.on_after_run_node is not None: for callback in self._on_after_run_node_callbacks:
self.on_after_run_node(invocation, queue_item, outputs) callback(invocation, queue_item, outputs)
def _on_node_error( def _on_node_error(
self, self,
@ -169,8 +169,8 @@ class DefaultSessionRunner(SessionRunnerBase):
project_id=None, project_id=None,
) )
if self.on_node_error is not None: for callback in self._on_node_error_callbacks:
self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback) callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run a single node in the graph""" """Run a single node in the graph"""
@ -213,6 +213,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# loop go to its next iteration, and the cancel event will be handled correctly. # loop go to its next iteration, and the cancel event will be handled correctly.
pass pass
except Exception as e: except Exception as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope
exc_type = type(e) exc_type = type(e)
exc_value = e exc_value = e
exc_traceback = e.__traceback__ exc_traceback = e.__traceback__
@ -224,14 +225,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
def __init__( def __init__(
self, self,
session_runner: Optional[SessionRunnerBase] = None, session_runner: Optional[SessionRunnerBase] = None,
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None, on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None,
thread_limit: int = 1, thread_limit: int = 1,
polling_interval: int = 1, polling_interval: int = 1,
) -> None: ) -> None:
super().__init__() super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner() self.session_runner = session_runner if session_runner else DefaultSessionRunner()
self.on_non_fatal_processor_error = on_non_fatal_processor_error self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
self._thread_limit = thread_limit self._thread_limit = thread_limit
self._polling_interval = polling_interval self._polling_interval = polling_interval
@ -250,8 +251,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
if self.on_non_fatal_processor_error: for callback in self._on_non_fatal_processor_error_callbacks:
self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item) callback(exc_type, exc_value, exc_traceback, queue_item)
def start(self, invoker: Invoker) -> None: def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker self._invoker: Invoker = invoker
@ -377,6 +378,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.session_runner.run(queue_item=self._queue_item) self.session_runner.run(queue_item=self._queue_item)
except Exception as e: except Exception as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope
exc_type = type(e) exc_type = type(e)
exc_value = e exc_value = e
exc_traceback = e.__traceback__ exc_traceback = e.__traceback__