mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(app): support multiple processor lifecycle callbacks
This commit is contained in:
parent
cb8e9e1c7b
commit
cef1585dfb
@ -130,13 +130,13 @@ class ApiDependencies:
|
||||
|
||||
session_processor = DefaultSessionProcessor(
|
||||
DefaultSessionRunner(
|
||||
on_before_run_session=on_before_run_session,
|
||||
on_before_run_node=on_before_run_node,
|
||||
on_after_run_node=on_after_run_node,
|
||||
on_node_error=on_node_error,
|
||||
on_after_run_session=on_after_run_session,
|
||||
on_before_run_session_callbacks=[on_before_run_session],
|
||||
on_before_run_node_callbacks=[on_before_run_node],
|
||||
on_after_run_node_callbacks=[on_after_run_node],
|
||||
on_node_error_callbacks=[on_node_error],
|
||||
on_after_run_session_callbacks=[on_after_run_session],
|
||||
),
|
||||
on_non_fatal_processor_error,
|
||||
on_non_fatal_processor_error_callbacks=[on_non_fatal_processor_error],
|
||||
)
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
|
@ -38,17 +38,17 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_before_run_session: Optional[OnBeforeRunSession] = None,
|
||||
on_before_run_node: Optional[OnBeforeRunNode] = None,
|
||||
on_after_run_node: Optional[OnAfterRunNode] = None,
|
||||
on_node_error: Optional[OnNodeError] = None,
|
||||
on_after_run_session: Optional[OnAfterRunSession] = None,
|
||||
on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None,
|
||||
on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None,
|
||||
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
|
||||
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
|
||||
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
|
||||
):
|
||||
self.on_before_run_session = on_before_run_session
|
||||
self.on_before_run_node = on_before_run_node
|
||||
self.on_after_run_node = on_after_run_node
|
||||
self.on_node_error = on_node_error
|
||||
self.on_after_run_session = on_after_run_session
|
||||
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
|
||||
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
|
||||
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
|
||||
self._on_node_error_callbacks = on_node_error_callbacks or []
|
||||
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
|
||||
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||
"""Start the session runner"""
|
||||
@ -77,8 +77,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=queue_item.session_id)
|
||||
|
||||
if self.on_before_run_session:
|
||||
self.on_before_run_session(queue_item=queue_item)
|
||||
for callback in self._on_before_run_session_callbacks:
|
||||
callback(queue_item=queue_item)
|
||||
|
||||
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
@ -103,8 +103,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self._services.performance_statistics.log_stats(queue_item.session.id)
|
||||
self._services.performance_statistics.reset_stats()
|
||||
|
||||
if self.on_after_run_session:
|
||||
self.on_after_run_session(queue_item)
|
||||
for callback in self._on_after_run_session_callbacks:
|
||||
callback(queue_item)
|
||||
|
||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run before a node is executed"""
|
||||
@ -117,9 +117,9 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
)
|
||||
# And run lifecycle callbacks
|
||||
if self.on_before_run_node is not None:
|
||||
self.on_before_run_node(invocation, queue_item)
|
||||
|
||||
for callback in self._on_before_run_node_callbacks:
|
||||
callback(invocation, queue_item)
|
||||
|
||||
def _on_after_run_node(
|
||||
self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput
|
||||
@ -135,9 +135,9 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
# And run lifecycle callbacks
|
||||
if self.on_after_run_node is not None:
|
||||
self.on_after_run_node(invocation, queue_item, outputs)
|
||||
|
||||
for callback in self._on_after_run_node_callbacks:
|
||||
callback(invocation, queue_item, outputs)
|
||||
|
||||
def _on_node_error(
|
||||
self,
|
||||
@ -169,8 +169,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
project_id=None,
|
||||
)
|
||||
|
||||
if self.on_node_error is not None:
|
||||
self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||
for callback in self._on_node_error_callbacks:
|
||||
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run a single node in the graph"""
|
||||
@ -213,6 +213,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
||||
exc_type = type(e)
|
||||
exc_value = e
|
||||
exc_traceback = e.__traceback__
|
||||
@ -224,14 +225,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def __init__(
|
||||
self,
|
||||
session_runner: Optional[SessionRunnerBase] = None,
|
||||
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
|
||||
on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None,
|
||||
thread_limit: int = 1,
|
||||
polling_interval: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
||||
self.on_non_fatal_processor_error = on_non_fatal_processor_error
|
||||
self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
|
||||
self._thread_limit = thread_limit
|
||||
self._polling_interval = polling_interval
|
||||
|
||||
@ -250,8 +251,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
||||
|
||||
if self.on_non_fatal_processor_error:
|
||||
self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item)
|
||||
for callback in self._on_non_fatal_processor_error_callbacks:
|
||||
callback(exc_type, exc_value, exc_traceback, queue_item)
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
@ -377,6 +378,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self.session_runner.run(queue_item=self._queue_item)
|
||||
|
||||
except Exception as e:
|
||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
||||
exc_type = type(e)
|
||||
exc_value = e
|
||||
exc_traceback = e.__traceback__
|
||||
|
Loading…
Reference in New Issue
Block a user