mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename graph processor to session runner to better describe what it's doing, add before/after callbacks for sessions
This commit is contained in:
parent
7d5a88b69d
commit
46c904d08a
@ -21,8 +21,8 @@ from .session_processor_base import SessionProcessorBase
|
|||||||
from .session_processor_common import SessionProcessorStatus
|
from .session_processor_common import SessionProcessorStatus
|
||||||
|
|
||||||
|
|
||||||
class GraphProcessor:
|
class SessionRunner:
|
||||||
"""Process a graph of invocations"""
|
"""Processes a single session's invocations"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -167,6 +167,8 @@ class GraphProcessor:
|
|||||||
|
|
||||||
|
|
||||||
class DefaultSessionProcessor(SessionProcessorBase):
|
class DefaultSessionProcessor(SessionProcessorBase):
|
||||||
|
"""Processes sessions from the session queue"""
|
||||||
|
|
||||||
def start(
|
def start(
|
||||||
self,
|
self,
|
||||||
invoker: Invoker,
|
invoker: Invoker,
|
||||||
@ -174,10 +176,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
polling_interval: int = 1,
|
polling_interval: int = 1,
|
||||||
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||||
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||||
|
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||||
|
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._invoker: Invoker = invoker
|
self._invoker: Invoker = invoker
|
||||||
self._queue_item: Optional[SessionQueueItem] = None
|
self._queue_item: Optional[SessionQueueItem] = None
|
||||||
self._invocation: Optional[BaseInvocation] = None
|
self._invocation: Optional[BaseInvocation] = None
|
||||||
|
self.on_before_run_session = on_before_run_session
|
||||||
|
self.on_after_run_session = on_after_run_session
|
||||||
|
|
||||||
self._resume_event = ThreadEvent()
|
self._resume_event = ThreadEvent()
|
||||||
self._stop_event = ThreadEvent()
|
self._stop_event = ThreadEvent()
|
||||||
@ -202,7 +208,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
self.graph_processor = GraphProcessor(
|
self.session_runner = SessionRunner(
|
||||||
services=self._invoker.services,
|
services=self._invoker.services,
|
||||||
cancel_event=self._cancel_event,
|
cancel_event=self._cancel_event,
|
||||||
profiler=self._profiler,
|
profiler=self._profiler,
|
||||||
@ -275,11 +281,19 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
# Get the next session to process
|
# Get the next session to process
|
||||||
self._queue_item = self._invoker.services.session_queue.dequeue()
|
self._queue_item = self._invoker.services.session_queue.dequeue()
|
||||||
if self._queue_item is not None and resume_event.is_set():
|
if self._queue_item is not None and resume_event.is_set():
|
||||||
|
# If we have a on_before_run_session callback, call it
|
||||||
|
if self.on_before_run_session is not None:
|
||||||
|
self.on_before_run_session(self._queue_item)
|
||||||
|
|
||||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||||
cancel_event.clear()
|
cancel_event.clear()
|
||||||
|
|
||||||
# Run the graph
|
# Run the graph
|
||||||
self.graph_processor.run(queue_item=self._queue_item)
|
self.session_runner.run(queue_item=self._queue_item)
|
||||||
|
|
||||||
|
# If we have a on_after_run_session callback, call it
|
||||||
|
if self.on_after_run_session is not None:
|
||||||
|
self.on_after_run_session(self._queue_item)
|
||||||
|
|
||||||
# The session is complete, immediately poll for next session
|
# The session is complete, immediately poll for next session
|
||||||
self._queue_item = None
|
self._queue_item = None
|
||||||
|
Loading…
Reference in New Issue
Block a user