feat(app): support multiple processor lifecycle callbacks

This commit is contained in:
psychedelicious 2024-05-22 19:05:49 +10:00
parent cb8e9e1c7b
commit cef1585dfb
2 changed files with 34 additions and 32 deletions

View File

@ -130,13 +130,13 @@ class ApiDependencies:
session_processor = DefaultSessionProcessor(
DefaultSessionRunner(
on_before_run_session=on_before_run_session,
on_before_run_node=on_before_run_node,
on_after_run_node=on_after_run_node,
on_node_error=on_node_error,
on_after_run_session=on_after_run_session,
on_before_run_session_callbacks=[on_before_run_session],
on_before_run_node_callbacks=[on_before_run_node],
on_after_run_node_callbacks=[on_after_run_node],
on_node_error_callbacks=[on_node_error],
on_after_run_session_callbacks=[on_after_run_session],
),
on_non_fatal_processor_error,
on_non_fatal_processor_error_callbacks=[on_non_fatal_processor_error],
)
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()

View File

@ -38,17 +38,17 @@ class DefaultSessionRunner(SessionRunnerBase):
def __init__(
self,
on_before_run_session: Optional[OnBeforeRunSession] = None,
on_before_run_node: Optional[OnBeforeRunNode] = None,
on_after_run_node: Optional[OnAfterRunNode] = None,
on_node_error: Optional[OnNodeError] = None,
on_after_run_session: Optional[OnAfterRunSession] = None,
on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None,
on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None,
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
):
self.on_before_run_session = on_before_run_session
self.on_before_run_node = on_before_run_node
self.on_after_run_node = on_after_run_node
self.on_node_error = on_node_error
self.on_after_run_session = on_after_run_session
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self._on_node_error_callbacks = on_node_error_callbacks or []
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
"""Start the session runner"""
@ -77,8 +77,8 @@ class DefaultSessionRunner(SessionRunnerBase):
if self._profiler is not None:
self._profiler.start(profile_id=queue_item.session_id)
if self.on_before_run_session:
self.on_before_run_session(queue_item=queue_item)
for callback in self._on_before_run_session_callbacks:
callback(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
# If we are profiling, stop the profiler and dump the profile & stats
@ -103,8 +103,8 @@ class DefaultSessionRunner(SessionRunnerBase):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
if self.on_after_run_session:
self.on_after_run_session(queue_item)
for callback in self._on_after_run_session_callbacks:
callback(queue_item)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
@ -117,9 +117,9 @@ class DefaultSessionRunner(SessionRunnerBase):
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
# And run lifecycle callbacks
if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item)
for callback in self._on_before_run_node_callbacks:
callback(invocation, queue_item)
def _on_after_run_node(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput
@ -135,9 +135,9 @@ class DefaultSessionRunner(SessionRunnerBase):
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
result=outputs.model_dump(),
)
# And run lifecycle callbacks
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item, outputs)
for callback in self._on_after_run_node_callbacks:
callback(invocation, queue_item, outputs)
def _on_node_error(
self,
@ -169,8 +169,8 @@ class DefaultSessionRunner(SessionRunnerBase):
project_id=None,
)
if self.on_node_error is not None:
self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
for callback in self._on_node_error_callbacks:
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
@ -213,6 +213,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope
exc_type = type(e)
exc_value = e
exc_traceback = e.__traceback__
@ -224,14 +225,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
def __init__(
self,
session_runner: Optional[SessionRunnerBase] = None,
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None,
thread_limit: int = 1,
polling_interval: int = 1,
) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
self.on_non_fatal_processor_error = on_non_fatal_processor_error
self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
self._thread_limit = thread_limit
self._polling_interval = polling_interval
@ -250,8 +251,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
if self.on_non_fatal_processor_error:
self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item)
for callback in self._on_non_fatal_processor_error_callbacks:
callback(exc_type, exc_value, exc_traceback, queue_item)
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker
@ -377,6 +378,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.session_runner.run(queue_item=self._queue_item)
except Exception as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope
exc_type = type(e)
exc_value = e
exc_traceback = e.__traceback__