mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(app): iterate on processor split 2
- Use protocol to define callbacks, this allows them to have kwargs - Shuffle the profiler around a bit - Move `thread_limit` and `polling_interval` to `__init__`; `start` is called programmatically and will never get these args in practice
This commit is contained in:
parent
be41c84305
commit
47722528a3
@ -112,7 +112,7 @@ class ApiDependencies:
|
||||
print("BEFORE RUN NODE", invocation.id)
|
||||
return True
|
||||
|
||||
def on_after_run_node(invocation, queue_item, outputs):
|
||||
def on_after_run_node(invocation, queue_item, output):
|
||||
print("AFTER RUN NODE", invocation.id)
|
||||
return True
|
||||
|
||||
@ -124,17 +124,17 @@ class ApiDependencies:
|
||||
print("AFTER RUN SESSION", queue_item.item_id)
|
||||
return True
|
||||
|
||||
def on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback):
|
||||
def on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item=None):
|
||||
print("NON FATAL PROCESSOR ERROR", exc_value)
|
||||
return True
|
||||
|
||||
session_processor = DefaultSessionProcessor(
|
||||
DefaultSessionRunner(
|
||||
on_before_run_session,
|
||||
on_before_run_node,
|
||||
on_after_run_node,
|
||||
on_node_error,
|
||||
on_after_run_session,
|
||||
on_before_run_session=on_before_run_session,
|
||||
on_before_run_node=on_before_run_node,
|
||||
on_after_run_node=on_after_run_node,
|
||||
on_node_error=on_node_error,
|
||||
on_after_run_session=on_after_run_session,
|
||||
),
|
||||
on_non_fatal_processor_error,
|
||||
)
|
||||
|
@ -1,10 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from threading import Event
|
||||
from types import TracebackType
|
||||
from typing import Optional, Protocol
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
|
||||
class SessionRunnerBase(ABC):
|
||||
@ -13,7 +16,7 @@ class SessionRunnerBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def start(self, services: InvocationServices, cancel_event: Event) -> None:
|
||||
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
|
||||
"""Starts the session runner"""
|
||||
pass
|
||||
|
||||
@ -51,3 +54,42 @@ class SessionProcessorBase(ABC):
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
"""Gets the status of the session processor"""
|
||||
pass
|
||||
|
||||
|
||||
class OnBeforeRunNode(Protocol):
|
||||
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ...
|
||||
|
||||
|
||||
class OnAfterRunNode(Protocol):
|
||||
def __call__(
|
||||
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class OnNodeError(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
queue_item: SessionQueueItem,
|
||||
exc_type: type,
|
||||
exc_value: BaseException,
|
||||
exc_traceback: TracebackType,
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class OnBeforeRunSession(Protocol):
|
||||
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
|
||||
|
||||
|
||||
class OnAfterRunSession(Protocol):
|
||||
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
|
||||
|
||||
|
||||
class OnNonFatalProcessorError(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
exc_type: type,
|
||||
exc_value: BaseException,
|
||||
exc_traceback: TracebackType,
|
||||
queue_item: Optional[SessionQueueItem] = None,
|
||||
) -> bool: ...
|
||||
|
@ -3,7 +3,7 @@ from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from types import TracebackType
|
||||
from typing import Callable, Optional, TypeAlias
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
@ -11,6 +11,14 @@ from fastapi_events.typing import Event as FastAPIEvent
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_base import (
|
||||
OnAfterRunNode,
|
||||
OnAfterRunSession,
|
||||
OnBeforeRunNode,
|
||||
OnBeforeRunSession,
|
||||
OnNodeError,
|
||||
OnNonFatalProcessorError,
|
||||
)
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
@ -20,13 +28,6 @@ from ..invoker import Invoker
|
||||
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
OnBeforeRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem], bool]
|
||||
OnAfterRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, BaseInvocationOutput], bool]
|
||||
OnNodeError: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, type, BaseException, TracebackType], bool]
|
||||
OnBeforeRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
|
||||
OnAfterRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
|
||||
OnNonFatalProcessorError: TypeAlias = Callable[[Optional[SessionQueueItem], type, BaseException, TracebackType], bool]
|
||||
|
||||
|
||||
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
|
||||
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
||||
@ -49,16 +50,18 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self.on_node_error = on_node_error
|
||||
self.on_after_run_session = on_after_run_session
|
||||
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||
"""Start the session runner"""
|
||||
self.services = services
|
||||
self.cancel_event = cancel_event
|
||||
self.profiler = profiler
|
||||
|
||||
def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None):
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
"""Run the graph"""
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
|
||||
self._on_before_run_session(queue_item=queue_item)
|
||||
|
||||
while True:
|
||||
invocation = queue_item.session.next()
|
||||
if invocation is None or self.cancel_event.is_set():
|
||||
@ -66,20 +69,21 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self.run_node(invocation, queue_item)
|
||||
if queue_item.session.is_complete() or self.cancel_event.is_set():
|
||||
break
|
||||
|
||||
self._on_after_run_session(queue_item=queue_item)
|
||||
|
||||
def _on_before_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
|
||||
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
# If profiling is enabled, start the profiler
|
||||
if profiler is not None:
|
||||
profiler.start(profile_id=queue_item.session_id)
|
||||
if self.profiler is not None:
|
||||
self.profiler.start(profile_id=queue_item.session_id)
|
||||
|
||||
if self.on_before_run_session:
|
||||
self.on_before_run_session(queue_item)
|
||||
self.on_before_run_session(queue_item=queue_item)
|
||||
|
||||
def _on_after_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
|
||||
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if profiler:
|
||||
profile_path = profiler.stop()
|
||||
if self.profiler is not None:
|
||||
profile_path = self.profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
||||
@ -221,11 +225,15 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self,
|
||||
session_runner: Optional[SessionRunnerBase] = None,
|
||||
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
|
||||
thread_limit: int = 1,
|
||||
polling_interval: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
||||
self.on_non_fatal_processor_error = on_non_fatal_processor_error
|
||||
self._thread_limit = thread_limit
|
||||
self._polling_interval = polling_interval
|
||||
|
||||
def _on_non_fatal_processor_error(
|
||||
self,
|
||||
@ -243,14 +251,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
||||
|
||||
if self.on_non_fatal_processor_error:
|
||||
self.on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback)
|
||||
self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item)
|
||||
|
||||
def start(
|
||||
self,
|
||||
invoker: Invoker,
|
||||
thread_limit: int = 1,
|
||||
polling_interval: int = 1,
|
||||
) -> None:
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
@ -262,9 +265,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self._thread_limit = thread_limit
|
||||
self._thread_semaphore = BoundedSemaphore(thread_limit)
|
||||
self._polling_interval = polling_interval
|
||||
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
|
||||
|
||||
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
|
||||
# the profiler will create a new profile for each session.
|
||||
@ -278,7 +279,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
else None
|
||||
)
|
||||
|
||||
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
|
||||
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
|
||||
self._thread = Thread(
|
||||
name="session_processor",
|
||||
target=self._process,
|
||||
|
Loading…
Reference in New Issue
Block a user