mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): making invocation class var in processor
This commit is contained in:
parent
62199b0fb1
commit
198ed222c4
@ -7,6 +7,7 @@ from typing import Optional
|
|||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.typing import Event as FastAPIEvent
|
from fastapi_events.typing import Event as FastAPIEvent
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||||
@ -23,6 +24,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
||||||
self._invoker: Invoker = invoker
|
self._invoker: Invoker = invoker
|
||||||
self._queue_item: Optional[SessionQueueItem] = None
|
self._queue_item: Optional[SessionQueueItem] = None
|
||||||
|
self._invocation: Optional[BaseInvocation] = None
|
||||||
|
|
||||||
self._resume_event = ThreadEvent()
|
self._resume_event = ThreadEvent()
|
||||||
self._stop_event = ThreadEvent()
|
self._stop_event = ThreadEvent()
|
||||||
@ -134,12 +136,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
profiler.start(profile_id=self._queue_item.session_id)
|
profiler.start(profile_id=self._queue_item.session_id)
|
||||||
|
|
||||||
# Prepare invocations and take the first
|
# Prepare invocations and take the first
|
||||||
invocation = self._queue_item.session.next()
|
self._invocation = self._queue_item.session.next()
|
||||||
|
|
||||||
# Loop over invocations until the session is complete or canceled
|
# Loop over invocations until the session is complete or canceled
|
||||||
while invocation is not None and not cancel_event.is_set():
|
while self._invocation is not None and not cancel_event.is_set():
|
||||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[invocation.id]
|
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||||
|
|
||||||
# Send starting event
|
# Send starting event
|
||||||
self._invoker.services.events.emit_invocation_started(
|
self._invoker.services.events.emit_invocation_started(
|
||||||
@ -147,18 +149,18 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
queue_item_id=self._queue_item.item_id,
|
queue_item_id=self._queue_item.item_id,
|
||||||
queue_id=self._queue_item.queue_id,
|
queue_id=self._queue_item.queue_id,
|
||||||
graph_execution_state_id=self._queue_item.session_id,
|
graph_execution_state_id=self._queue_item.session_id,
|
||||||
node=invocation.model_dump(),
|
node=self._invocation.model_dump(),
|
||||||
source_node_id=source_invocation_id,
|
source_node_id=source_invocation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||||
try:
|
try:
|
||||||
with self._invoker.services.performance_statistics.collect_stats(
|
with self._invoker.services.performance_statistics.collect_stats(
|
||||||
invocation, self._queue_item.session.id
|
self._invocation, self._queue_item.session.id
|
||||||
):
|
):
|
||||||
# Build invocation context (the node-facing API)
|
# Build invocation context (the node-facing API)
|
||||||
data = InvocationContextData(
|
data = InvocationContextData(
|
||||||
invocation=invocation,
|
invocation=self._invocation,
|
||||||
source_invocation_id=source_invocation_id,
|
source_invocation_id=source_invocation_id,
|
||||||
queue_item=self._queue_item,
|
queue_item=self._queue_item,
|
||||||
)
|
)
|
||||||
@ -169,12 +171,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Invoke the node
|
# Invoke the node
|
||||||
outputs = invocation.invoke_internal(
|
outputs = self._invocation.invoke_internal(
|
||||||
context=context, services=self._invoker.services
|
context=context, services=self._invoker.services
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save outputs and history
|
# Save outputs and history
|
||||||
self._queue_item.session.complete(invocation.id, outputs)
|
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||||
|
|
||||||
# Send complete event
|
# Send complete event
|
||||||
self._invoker.services.events.emit_invocation_complete(
|
self._invoker.services.events.emit_invocation_complete(
|
||||||
@ -182,7 +184,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
queue_item_id=self._queue_item.item_id,
|
queue_item_id=self._queue_item.item_id,
|
||||||
queue_id=self._queue_item.queue_id,
|
queue_id=self._queue_item.queue_id,
|
||||||
graph_execution_state_id=self._queue_item.session.id,
|
graph_execution_state_id=self._queue_item.session.id,
|
||||||
node=invocation.model_dump(),
|
node=self._invocation.model_dump(),
|
||||||
source_node_id=source_invocation_id,
|
source_node_id=source_invocation_id,
|
||||||
result=outputs.model_dump(),
|
result=outputs.model_dump(),
|
||||||
)
|
)
|
||||||
@ -208,9 +210,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
error = traceback.format_exc()
|
error = traceback.format_exc()
|
||||||
|
|
||||||
# Save error
|
# Save error
|
||||||
self._queue_item.session.set_node_error(invocation.id, error)
|
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||||
self._invoker.services.logger.error(
|
self._invoker.services.logger.error(
|
||||||
f"Error while invoking session {self._queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
|
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send error event
|
# Send error event
|
||||||
@ -219,7 +221,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
queue_item_id=self._queue_item.item_id,
|
queue_item_id=self._queue_item.item_id,
|
||||||
queue_id=self._queue_item.queue_id,
|
queue_id=self._queue_item.queue_id,
|
||||||
graph_execution_state_id=self._queue_item.session.id,
|
graph_execution_state_id=self._queue_item.session.id,
|
||||||
node=invocation.model_dump(),
|
node=self._invocation.model_dump(),
|
||||||
source_node_id=source_invocation_id,
|
source_node_id=source_invocation_id,
|
||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=error,
|
error=error,
|
||||||
@ -236,10 +238,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
)
|
)
|
||||||
# Save the stats and stop the profiler if it's running
|
# Save the stats and stop the profiler if it's running
|
||||||
stats_cleanup(self._queue_item.session.id)
|
stats_cleanup(self._queue_item.session.id)
|
||||||
invocation = None
|
self._invocation = None
|
||||||
else:
|
else:
|
||||||
# Prepare the next invocation
|
# Prepare the next invocation
|
||||||
invocation = self._queue_item.session.next()
|
self._invocation = self._queue_item.session.next()
|
||||||
|
|
||||||
# The session is complete, immediately poll for next session
|
# The session is complete, immediately poll for next session
|
||||||
self._queue_item = None
|
self._queue_item = None
|
||||||
|
Loading…
Reference in New Issue
Block a user