mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(app): iterate on processor split 2
- Use protocol to define callbacks, this allows them to have kwargs - Shuffle the profiler around a bit - Move `thread_limit` and `polling_interval` to `__init__`; `start` is called programmatically and will never get these args in practice
This commit is contained in:
parent
be41c84305
commit
47722528a3
@ -112,7 +112,7 @@ class ApiDependencies:
|
|||||||
print("BEFORE RUN NODE", invocation.id)
|
print("BEFORE RUN NODE", invocation.id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def on_after_run_node(invocation, queue_item, outputs):
|
def on_after_run_node(invocation, queue_item, output):
|
||||||
print("AFTER RUN NODE", invocation.id)
|
print("AFTER RUN NODE", invocation.id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -124,17 +124,17 @@ class ApiDependencies:
|
|||||||
print("AFTER RUN SESSION", queue_item.item_id)
|
print("AFTER RUN SESSION", queue_item.item_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback):
|
def on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item=None):
|
||||||
print("NON FATAL PROCESSOR ERROR", exc_value)
|
print("NON FATAL PROCESSOR ERROR", exc_value)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
session_processor = DefaultSessionProcessor(
|
session_processor = DefaultSessionProcessor(
|
||||||
DefaultSessionRunner(
|
DefaultSessionRunner(
|
||||||
on_before_run_session,
|
on_before_run_session=on_before_run_session,
|
||||||
on_before_run_node,
|
on_before_run_node=on_before_run_node,
|
||||||
on_after_run_node,
|
on_after_run_node=on_after_run_node,
|
||||||
on_node_error,
|
on_node_error=on_node_error,
|
||||||
on_after_run_session,
|
on_after_run_session=on_after_run_session,
|
||||||
),
|
),
|
||||||
on_non_fatal_processor_error,
|
on_non_fatal_processor_error,
|
||||||
)
|
)
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from threading import Event
|
from threading import Event
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
|
from invokeai.app.util.profiler import Profiler
|
||||||
|
|
||||||
|
|
||||||
class SessionRunnerBase(ABC):
|
class SessionRunnerBase(ABC):
|
||||||
@ -13,7 +16,7 @@ class SessionRunnerBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def start(self, services: InvocationServices, cancel_event: Event) -> None:
|
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
|
||||||
"""Starts the session runner"""
|
"""Starts the session runner"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -51,3 +54,42 @@ class SessionProcessorBase(ABC):
|
|||||||
def get_status(self) -> SessionProcessorStatus:
|
def get_status(self) -> SessionProcessorStatus:
|
||||||
"""Gets the status of the session processor"""
|
"""Gets the status of the session processor"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OnBeforeRunNode(Protocol):
|
||||||
|
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class OnAfterRunNode(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
||||||
|
) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class OnNodeError(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
invocation: BaseInvocation,
|
||||||
|
queue_item: SessionQueueItem,
|
||||||
|
exc_type: type,
|
||||||
|
exc_value: BaseException,
|
||||||
|
exc_traceback: TracebackType,
|
||||||
|
) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class OnBeforeRunSession(Protocol):
|
||||||
|
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class OnAfterRunSession(Protocol):
|
||||||
|
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class OnNonFatalProcessorError(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
exc_type: type,
|
||||||
|
exc_value: BaseException,
|
||||||
|
exc_traceback: TracebackType,
|
||||||
|
queue_item: Optional[SessionQueueItem] = None,
|
||||||
|
) -> bool: ...
|
||||||
|
@ -3,7 +3,7 @@ from contextlib import suppress
|
|||||||
from threading import BoundedSemaphore, Thread
|
from threading import BoundedSemaphore, Thread
|
||||||
from threading import Event as ThreadEvent
|
from threading import Event as ThreadEvent
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Callable, Optional, TypeAlias
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.typing import Event as FastAPIEvent
|
from fastapi_events.typing import Event as FastAPIEvent
|
||||||
@ -11,6 +11,14 @@ from fastapi_events.typing import Event as FastAPIEvent
|
|||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||||
|
from invokeai.app.services.session_processor.session_processor_base import (
|
||||||
|
OnAfterRunNode,
|
||||||
|
OnAfterRunSession,
|
||||||
|
OnBeforeRunNode,
|
||||||
|
OnBeforeRunSession,
|
||||||
|
OnNodeError,
|
||||||
|
OnNonFatalProcessorError,
|
||||||
|
)
|
||||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||||
@ -20,13 +28,6 @@ from ..invoker import Invoker
|
|||||||
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
|
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
|
||||||
from .session_processor_common import SessionProcessorStatus
|
from .session_processor_common import SessionProcessorStatus
|
||||||
|
|
||||||
OnBeforeRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem], bool]
|
|
||||||
OnAfterRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, BaseInvocationOutput], bool]
|
|
||||||
OnNodeError: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, type, BaseException, TracebackType], bool]
|
|
||||||
OnBeforeRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
|
|
||||||
OnAfterRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
|
|
||||||
OnNonFatalProcessorError: TypeAlias = Callable[[Optional[SessionQueueItem], type, BaseException, TracebackType], bool]
|
|
||||||
|
|
||||||
|
|
||||||
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
|
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
|
||||||
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
||||||
@ -49,16 +50,18 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
self.on_node_error = on_node_error
|
self.on_node_error = on_node_error
|
||||||
self.on_after_run_session = on_after_run_session
|
self.on_after_run_session = on_after_run_session
|
||||||
|
|
||||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
|
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||||
"""Start the session runner"""
|
"""Start the session runner"""
|
||||||
self.services = services
|
self.services = services
|
||||||
self.cancel_event = cancel_event
|
self.cancel_event = cancel_event
|
||||||
|
self.profiler = profiler
|
||||||
|
|
||||||
def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None):
|
def run(self, queue_item: SessionQueueItem):
|
||||||
"""Run the graph"""
|
"""Run the graph"""
|
||||||
# Loop over invocations until the session is complete or canceled
|
# Loop over invocations until the session is complete or canceled
|
||||||
|
|
||||||
self._on_before_run_session(queue_item=queue_item)
|
self._on_before_run_session(queue_item=queue_item)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
invocation = queue_item.session.next()
|
invocation = queue_item.session.next()
|
||||||
if invocation is None or self.cancel_event.is_set():
|
if invocation is None or self.cancel_event.is_set():
|
||||||
@ -66,20 +69,21 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
self.run_node(invocation, queue_item)
|
self.run_node(invocation, queue_item)
|
||||||
if queue_item.session.is_complete() or self.cancel_event.is_set():
|
if queue_item.session.is_complete() or self.cancel_event.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
self._on_after_run_session(queue_item=queue_item)
|
self._on_after_run_session(queue_item=queue_item)
|
||||||
|
|
||||||
def _on_before_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
|
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
# If profiling is enabled, start the profiler
|
# If profiling is enabled, start the profiler
|
||||||
if profiler is not None:
|
if self.profiler is not None:
|
||||||
profiler.start(profile_id=queue_item.session_id)
|
self.profiler.start(profile_id=queue_item.session_id)
|
||||||
|
|
||||||
if self.on_before_run_session:
|
if self.on_before_run_session:
|
||||||
self.on_before_run_session(queue_item)
|
self.on_before_run_session(queue_item=queue_item)
|
||||||
|
|
||||||
def _on_after_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
|
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
# If we are profiling, stop the profiler and dump the profile & stats
|
# If we are profiling, stop the profiler and dump the profile & stats
|
||||||
if profiler:
|
if self.profiler is not None:
|
||||||
profile_path = profiler.stop()
|
profile_path = self.profiler.stop()
|
||||||
stats_path = profile_path.with_suffix(".json")
|
stats_path = profile_path.with_suffix(".json")
|
||||||
self.services.performance_statistics.dump_stats(
|
self.services.performance_statistics.dump_stats(
|
||||||
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
||||||
@ -221,11 +225,15 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self,
|
self,
|
||||||
session_runner: Optional[SessionRunnerBase] = None,
|
session_runner: Optional[SessionRunnerBase] = None,
|
||||||
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
|
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
|
||||||
|
thread_limit: int = 1,
|
||||||
|
polling_interval: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
||||||
self.on_non_fatal_processor_error = on_non_fatal_processor_error
|
self.on_non_fatal_processor_error = on_non_fatal_processor_error
|
||||||
|
self._thread_limit = thread_limit
|
||||||
|
self._polling_interval = polling_interval
|
||||||
|
|
||||||
def _on_non_fatal_processor_error(
|
def _on_non_fatal_processor_error(
|
||||||
self,
|
self,
|
||||||
@ -243,14 +251,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
||||||
|
|
||||||
if self.on_non_fatal_processor_error:
|
if self.on_non_fatal_processor_error:
|
||||||
self.on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback)
|
self.on_non_fatal_processor_error(exc_type, exc_value, exc_traceback, queue_item)
|
||||||
|
|
||||||
def start(
|
def start(self, invoker: Invoker) -> None:
|
||||||
self,
|
|
||||||
invoker: Invoker,
|
|
||||||
thread_limit: int = 1,
|
|
||||||
polling_interval: int = 1,
|
|
||||||
) -> None:
|
|
||||||
self._invoker: Invoker = invoker
|
self._invoker: Invoker = invoker
|
||||||
self._queue_item: Optional[SessionQueueItem] = None
|
self._queue_item: Optional[SessionQueueItem] = None
|
||||||
self._invocation: Optional[BaseInvocation] = None
|
self._invocation: Optional[BaseInvocation] = None
|
||||||
@ -262,9 +265,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
|
|
||||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||||
|
|
||||||
self._thread_limit = thread_limit
|
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
|
||||||
self._thread_semaphore = BoundedSemaphore(thread_limit)
|
|
||||||
self._polling_interval = polling_interval
|
|
||||||
|
|
||||||
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
|
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
|
||||||
# the profiler will create a new profile for each session.
|
# the profiler will create a new profile for each session.
|
||||||
@ -278,7 +279,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
|
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
|
||||||
self._thread = Thread(
|
self._thread = Thread(
|
||||||
name="session_processor",
|
name="session_processor",
|
||||||
target=self._process,
|
target=self._process,
|
||||||
|
Loading…
Reference in New Issue
Block a user