mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(app): support multiple processor lifecycle callbacks
This commit is contained in:
parent
7f70cde038
commit
1d973f92ff
@ -130,13 +130,13 @@ class ApiDependencies:
|
|||||||
|
|
||||||
session_processor = DefaultSessionProcessor(
|
session_processor = DefaultSessionProcessor(
|
||||||
DefaultSessionRunner(
|
DefaultSessionRunner(
|
||||||
on_before_run_session=on_before_run_session,
|
on_before_run_session_callbacks=[on_before_run_session],
|
||||||
on_before_run_node=on_before_run_node,
|
on_before_run_node_callbacks=[on_before_run_node],
|
||||||
on_after_run_node=on_after_run_node,
|
on_after_run_node_callbacks=[on_after_run_node],
|
||||||
on_node_error=on_node_error,
|
on_node_error_callbacks=[on_node_error],
|
||||||
on_after_run_session=on_after_run_session,
|
on_after_run_session_callbacks=[on_after_run_session],
|
||||||
),
|
),
|
||||||
on_non_fatal_processor_error,
|
on_non_fatal_processor_error_callbacks=[on_non_fatal_processor_error],
|
||||||
)
|
)
|
||||||
session_queue = SqliteSessionQueue(db=db)
|
session_queue = SqliteSessionQueue(db=db)
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
|
@ -38,17 +38,17 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
on_before_run_session: Optional[OnBeforeRunSession] = None,
|
on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None,
|
||||||
on_before_run_node: Optional[OnBeforeRunNode] = None,
|
on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None,
|
||||||
on_after_run_node: Optional[OnAfterRunNode] = None,
|
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
|
||||||
on_node_error: Optional[OnNodeError] = None,
|
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
|
||||||
on_after_run_session: Optional[OnAfterRunSession] = None,
|
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
|
||||||
):
|
):
|
||||||
self.on_before_run_session = on_before_run_session
|
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
|
||||||
self.on_before_run_node = on_before_run_node
|
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
|
||||||
self.on_after_run_node = on_after_run_node
|
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
|
||||||
self.on_node_error = on_node_error
|
self._on_node_error_callbacks = on_node_error_callbacks or []
|
||||||
self.on_after_run_session = on_after_run_session
|
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
|
||||||
|
|
||||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||||
"""Start the session runner"""
|
"""Start the session runner"""
|
||||||
@ -77,8 +77,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
if self._profiler is not None:
|
if self._profiler is not None:
|
||||||
self._profiler.start(profile_id=queue_item.session_id)
|
self._profiler.start(profile_id=queue_item.session_id)
|
||||||
|
|
||||||
if self.on_before_run_session:
|
for callback in self._on_before_run_session_callbacks:
|
||||||
self.on_before_run_session(queue_item=queue_item)
|
callback(queue_item=queue_item)
|
||||||
|
|
||||||
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
# If we are profiling, stop the profiler and dump the profile & stats
|
# If we are profiling, stop the profiler and dump the profile & stats
|
||||||
@ -103,8 +103,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
self._services.performance_statistics.log_stats(queue_item.session.id)
|
self._services.performance_statistics.log_stats(queue_item.session.id)
|
||||||
self._services.performance_statistics.reset_stats()
|
self._services.performance_statistics.reset_stats()
|
||||||
|
|
||||||
if self.on_after_run_session:
|
for callback in self._on_after_run_session_callbacks:
|
||||||
self.on_after_run_session(queue_item)
|
callback(queue_item)
|
||||||
|
|
||||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
"""Run before a node is executed"""
|
"""Run before a node is executed"""
|
||||||
@ -117,9 +117,9 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
node=invocation.model_dump(),
|
node=invocation.model_dump(),
|
||||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
)
|
)
|
||||||
# And run lifecycle callbacks
|
|
||||||
if self.on_before_run_node is not None:
|
for callback in self._on_before_run_node_callbacks:
|
||||||
self.on_before_run_node(invocation, queue_item)
|
callback(invocation, queue_item)
|
||||||
|
|
||||||
def _on_after_run_node(
|
def _on_after_run_node(
|
||||||
self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput
|
self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput
|
||||||
@ -135,9 +135,9 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
result=outputs.model_dump(),
|
result=outputs.model_dump(),
|
||||||
)
|
)
|
||||||
# And run lifecycle callbacks
|
|
||||||
if self.on_after_run_node is not None:
|
for callback in self._on_after_run_node_callbacks:
|
||||||
self.on_after_run_node(invocation, queue_item, outputs)
|
callback(invocation, queue_item, outputs)
|
||||||
|
|
||||||
def _on_node_error(
|
def _on_node_error(
|
||||||
self,
|
self,
|
||||||
@ -169,8 +169,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
project_id=None,
|
project_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.on_node_error is not None:
|
for callback in self._on_node_error_callbacks:
|
||||||
self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||||
|
|
||||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
"""Run a single node in the graph"""
|
"""Run a single node in the graph"""
|
||||||
@ -213,6 +213,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
||||||
exc_type = type(e)
|
exc_type = type(e)
|
||||||
exc_value = e
|
exc_value = e
|
||||||
exc_traceback = e.__traceback__
|
exc_traceback = e.__traceback__
|
||||||
@ -224,14 +225,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session_runner: Optional[SessionRunnerBase] = None,
|
session_runner: Optional[SessionRunnerBase] = None,
|
||||||
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
|
on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None,
|
||||||
thread_limit: int = 1,
|
thread_limit: int = 1,
|
||||||
polling_interval: int = 1,
|
polling_interval: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
||||||
self.on_non_fatal_processor_error = on_non_fatal_processor_error
|
self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
|
||||||
self._thread_limit = thread_limit
|
self._thread_limit = thread_limit
|
||||||
self._polling_interval = polling_interval
|
self._polling_interval = polling_interval
|
||||||
|
|
||||||
@ -250,8 +251,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||||
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
||||||
|
|
||||||
if self.on_non_fatal_processor_error:
|
for callback in self._on_non_fatal_processor_error_callbacks:
|
||||||
self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item)
|
callback(exc_type, exc_value, exc_traceback, queue_item)
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
def start(self, invoker: Invoker) -> None:
|
||||||
self._invoker: Invoker = invoker
|
self._invoker: Invoker = invoker
|
||||||
@ -377,6 +378,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self.session_runner.run(queue_item=self._queue_item)
|
self.session_runner.run(queue_item=self._queue_item)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
||||||
exc_type = type(e)
|
exc_type = type(e)
|
||||||
exc_value = e
|
exc_value = e
|
||||||
exc_traceback = e.__traceback__
|
exc_traceback = e.__traceback__
|
||||||
|
Loading…
Reference in New Issue
Block a user