feat(app): iterate on processor split 2

- Use protocol to define callbacks, this allows them to have kwargs
- Shuffle the profiler around a bit
- Move `thread_limit` and `polling_interval` to `__init__`; `start` is called programmatically and will never get these args in practice
This commit is contained in:
psychedelicious 2024-05-22 18:52:46 +10:00
parent be41c84305
commit 47722528a3
3 changed files with 80 additions and 37 deletions

View File

@ -112,7 +112,7 @@ class ApiDependencies:
print("BEFORE RUN NODE", invocation.id)
return True
def on_after_run_node(invocation, queue_item, outputs):
def on_after_run_node(invocation, queue_item, output):
print("AFTER RUN NODE", invocation.id)
return True
@ -124,17 +124,17 @@ class ApiDependencies:
print("AFTER RUN SESSION", queue_item.item_id)
return True
def on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback):
def on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item=None):
print("NON FATAL PROCESSOR ERROR", exc_value)
return True
session_processor = DefaultSessionProcessor(
DefaultSessionRunner(
on_before_run_session,
on_before_run_node,
on_after_run_node,
on_node_error,
on_after_run_session,
on_before_run_session=on_before_run_session,
on_before_run_node=on_before_run_node,
on_after_run_node=on_after_run_node,
on_node_error=on_node_error,
on_after_run_session=on_after_run_session,
),
on_non_fatal_processor_error,
)

View File

@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from threading import Event
from types import TracebackType
from typing import Optional, Protocol
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.util.profiler import Profiler
class SessionRunnerBase(ABC):
@ -13,7 +16,7 @@ class SessionRunnerBase(ABC):
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event) -> None:
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
"""Starts the session runner"""
pass
@ -51,3 +54,42 @@ class SessionProcessorBase(ABC):
def get_status(self) -> SessionProcessorStatus:
"""Gets the status of the session processor"""
pass
class OnBeforeRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ...
class OnAfterRunNode(Protocol):
def __call__(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
) -> bool: ...
class OnNodeError(Protocol):
def __call__(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
exc_type: type,
exc_value: BaseException,
exc_traceback: TracebackType,
) -> bool: ...
class OnBeforeRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
class OnAfterRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
class OnNonFatalProcessorError(Protocol):
def __call__(
self,
exc_type: type,
exc_value: BaseException,
exc_traceback: TracebackType,
queue_item: Optional[SessionQueueItem] = None,
) -> bool: ...

View File

@ -3,7 +3,7 @@ from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from types import TracebackType
from typing import Callable, Optional, TypeAlias
from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
@ -11,6 +11,14 @@ from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_base import (
OnAfterRunNode,
OnAfterRunSession,
OnBeforeRunNode,
OnBeforeRunSession,
OnNodeError,
OnNonFatalProcessorError,
)
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
@ -20,13 +28,6 @@ from ..invoker import Invoker
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus
OnBeforeRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem], bool]
OnAfterRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, BaseInvocationOutput], bool]
OnNodeError: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, type, BaseException, TracebackType], bool]
OnBeforeRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
OnAfterRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
OnNonFatalProcessorError: TypeAlias = Callable[[Optional[SessionQueueItem], type, BaseException, TracebackType], bool]
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
@ -49,16 +50,18 @@ class DefaultSessionRunner(SessionRunnerBase):
self.on_node_error = on_node_error
self.on_after_run_session = on_after_run_session
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
"""Start the session runner"""
self.services = services
self.cancel_event = cancel_event
self.profiler = profiler
def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None):
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
# Loop over invocations until the session is complete or canceled
self._on_before_run_session(queue_item=queue_item)
while True:
invocation = queue_item.session.next()
if invocation is None or self.cancel_event.is_set():
@ -66,20 +69,21 @@ class DefaultSessionRunner(SessionRunnerBase):
self.run_node(invocation, queue_item)
if queue_item.session.is_complete() or self.cancel_event.is_set():
break
self._on_after_run_session(queue_item=queue_item)
def _on_before_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
# If profiling is enabled, start the profiler
if profiler is not None:
profiler.start(profile_id=queue_item.session_id)
if self.profiler is not None:
self.profiler.start(profile_id=queue_item.session_id)
if self.on_before_run_session:
self.on_before_run_session(queue_item)
self.on_before_run_session(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
# If we are profiling, stop the profiler and dump the profile & stats
if profiler:
profile_path = profiler.stop()
if self.profiler is not None:
profile_path = self.profiler.stop()
stats_path = profile_path.with_suffix(".json")
self.services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path
@ -221,11 +225,15 @@ class DefaultSessionProcessor(SessionProcessorBase):
self,
session_runner: Optional[SessionRunnerBase] = None,
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
thread_limit: int = 1,
polling_interval: int = 1,
) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
self.on_non_fatal_processor_error = on_non_fatal_processor_error
self._thread_limit = thread_limit
self._polling_interval = polling_interval
def _on_non_fatal_processor_error(
self,
@ -243,14 +251,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
if self.on_non_fatal_processor_error:
self.on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback)
self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item)
def start(
self,
invoker: Invoker,
thread_limit: int = 1,
polling_interval: int = 1,
) -> None:
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
@ -262,9 +265,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self._thread_limit = thread_limit
self._thread_semaphore = BoundedSemaphore(thread_limit)
self._polling_interval = polling_interval
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
# the profiler will create a new profile for each session.
@ -278,7 +279,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
self._thread = Thread(
name="session_processor",
target=self._process,