mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
4b334be7d0
When a queue item is popped for processing, we need to retrieve its session from the DB. Pydantic serializes the graph at this stage. It's possible for a graph to have been made invalid during the graph preparation stage (e.g. an ancestor node executes, and its output is not valid for its successor node's input field). When this occurs, the session in the DB will fail validation, but we don't have a chance to find out until it is retrieved and parsed by pydantic. This logic was previously not wrapped in any exception handling. Just after retrieving a session, we retrieve the specific invocation to execute from the session. It's possible that this could also raise some sort of error, though it should be impossible for it to be a pydantic validation error (that would have been caught during session validation). There was no exception handling here either. When either of these processes fails, the processor gets soft-locked because the processor's cleanup logic is never run. (I didn't dig deeper into exactly which cleanup is not happening, because the fix is simply to handle the exceptions.) This PR adds exception handling to both the session retrieval and the node retrieval, along with an event for each: `session_retrieval_error` and `invocation_retrieval_error`. These events are caught and displayed in the UI as toasts, along with the type of the Python exception (e.g. `ValidationError`). The events are also logged to the browser console.
178 lines
7.3 KiB
Python
178 lines
7.3 KiB
Python
import time
|
|
import traceback
|
|
from threading import Event, Thread, BoundedSemaphore
|
|
|
|
from ..invocations.baseinvocation import InvocationContext
|
|
from .invocation_queue import InvocationQueueItem
|
|
from .invoker import InvocationProcessorABC, Invoker
|
|
from ..models.exceptions import CanceledException
|
|
|
|
import invokeai.backend.util.logging as logger
|
|
class DefaultInvocationProcessor(InvocationProcessorABC):
    """Pulls items off the invocation queue and executes them one at a time.

    Runs in a single daemon thread started by `start()`. Session retrieval,
    node retrieval and node execution are each wrapped in their own exception
    handling so that a failure in any stage logs, emits an error event and
    moves on to the next queue item instead of soft-locking the processor.
    """

    __invoker_thread: Thread
    __stop_event: Event
    __invoker: Invoker
    __threadLimit: BoundedSemaphore

    def start(self, invoker) -> None:
        """Create and start the background processor thread."""
        # if we do want multithreading at some point, we could make this configurable
        self.__threadLimit = BoundedSemaphore(1)
        self.__invoker = invoker
        self.__stop_event = Event()
        self.__invoker_thread = Thread(
            name="invoker_processor",
            target=self.__process,
            kwargs=dict(stop_event=self.__stop_event),
        )
        # Daemonize so a stuck invocation cannot block interpreter shutdown.
        self.__invoker_thread.daemon = True  # TODO: make async and do not use threads
        self.__invoker_thread.start()

    def stop(self, *args, **kwargs) -> None:
        """Signal the processor loop to exit after its current iteration."""
        self.__stop_event.set()

    def __process(self, stop_event: Event):
        """Main processor loop; runs until `stop_event` is set."""
        try:
            self.__threadLimit.acquire()
            while not stop_event.is_set():
                try:
                    queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                except Exception as e:
                    self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
                    # BUGFIX: without this `continue`, `queue_item` is unbound
                    # below and the resulting NameError kills the processor
                    # thread without running the cleanup in `finally`.
                    continue

                if not queue_item:  # Probably stopping
                    # do not hammer the queue
                    time.sleep(0.5)
                    continue

                # Retrieve the session from the DB. Pydantic re-validates the
                # graph here, which can fail if graph preparation produced an
                # invalid graph — hence the explicit handling and error event.
                try:
                    graph_execution_state = self.__invoker.services.graph_execution_manager.get(
                        queue_item.graph_execution_state_id
                    )
                except Exception as e:
                    self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
                    self.__invoker.services.events.emit_session_retrieval_error(
                        graph_execution_state_id=queue_item.graph_execution_state_id,
                        error_type=e.__class__.__name__,
                        error=traceback.format_exc(),
                    )
                    continue

                # Retrieve the specific node to execute from the session.
                try:
                    invocation = graph_execution_state.execution_graph.get_node(
                        queue_item.invocation_id
                    )
                except Exception as e:
                    self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
                    self.__invoker.services.events.emit_invocation_retrieval_error(
                        graph_execution_state_id=queue_item.graph_execution_state_id,
                        node_id=queue_item.invocation_id,
                        error_type=e.__class__.__name__,
                        error=traceback.format_exc(),
                    )
                    continue

                # get the source node id to provide to clients (the prepared node id is not as useful)
                source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]

                # Send starting event
                self.__invoker.services.events.emit_invocation_started(
                    graph_execution_state_id=graph_execution_state.id,
                    node=invocation.dict(),
                    source_node_id=source_node_id,
                )

                # Invoke
                try:
                    outputs = invocation.invoke(
                        InvocationContext(
                            services=self.__invoker.services,
                            graph_execution_state_id=graph_execution_state.id,
                        )
                    )

                    # Check queue to see if this is canceled, and skip if so
                    if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
                        continue

                    # Save outputs and history
                    graph_execution_state.complete(invocation.id, outputs)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(graph_execution_state)

                    # Send complete event
                    self.__invoker.services.events.emit_invocation_complete(
                        graph_execution_state_id=graph_execution_state.id,
                        node=invocation.dict(),
                        source_node_id=source_node_id,
                        result=outputs.dict(),
                    )

                except KeyboardInterrupt:
                    pass

                except CanceledException:
                    pass

                except Exception as e:
                    error = traceback.format_exc()
                    # CONSISTENCY: use the service logger like every other
                    # handler here (the original used the module-level logger
                    # only in this branch).
                    self.__invoker.services.logger.error(error)

                    # Save error
                    graph_execution_state.set_node_error(invocation.id, error)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(graph_execution_state)

                    self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
                    # Send error event
                    self.__invoker.services.events.emit_invocation_error(
                        graph_execution_state_id=graph_execution_state.id,
                        node=invocation.dict(),
                        source_node_id=source_node_id,
                        error_type=e.__class__.__name__,
                        error=error,
                    )

                # Check queue to see if this is canceled, and skip if so
                if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
                    continue

                # Queue any further commands if invoking all
                is_complete = graph_execution_state.is_complete()
                if queue_item.invoke_all and not is_complete:
                    try:
                        self.__invoker.invoke(graph_execution_state, invoke_all=True)
                    except Exception as e:
                        self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
                        self.__invoker.services.events.emit_invocation_error(
                            graph_execution_state_id=graph_execution_state.id,
                            node=invocation.dict(),
                            source_node_id=source_node_id,
                            error_type=e.__class__.__name__,
                            error=traceback.format_exc(),
                        )
                elif is_complete:
                    self.__invoker.services.events.emit_graph_execution_complete(
                        graph_execution_state.id
                    )

        except KeyboardInterrupt:
            pass  # Log something? KeyboardInterrupt is probably not going to be seen by the processor
        finally:
            self.__threadLimit.release()