import time
import traceback
from threading import Event, Thread, BoundedSemaphore
from typing import Any, TypeGuard

from invokeai.app.invocations.image import ImageOutput
from invokeai.app.models.image import ImageType

from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..models.exceptions import CanceledException
import invokeai.backend.util.logging as logger


def _is_image_output(obj: Any) -> TypeGuard[ImageOutput]:
    """Return True when ``obj``'s class is exactly ``ImageOutput``.

    An exact class comparison (not ``isinstance``) is used deliberately, so
    subclasses of ``ImageOutput`` are not treated as image outputs.
    """
    return obj.__class__ == ImageOutput


class DefaultInvocationProcessor(InvocationProcessorABC):
    """Background worker that executes queued invocations one at a time.

    A daemon thread repeatedly takes an ``InvocationQueueItem`` from
    ``invoker.services.queue``, then:

    1. loads the referenced graph execution state,
    2. runs the node,
    3. persists the updated state, and
    4. emits started / complete / error events.

    When the item requests ``invoke_all`` and the graph is not yet complete,
    the next node is handed back to the invoker. When the graph is complete,
    a graph-execution-complete event is emitted.
    """

    __invoker_thread: Thread
    __stop_event: Event
    __invoker: Invoker
    # Caps how many worker threads may run ``__process`` concurrently (1 today).
    __threadLimit: BoundedSemaphore

    def start(self, invoker) -> None:
        """Spawn the daemon worker thread that processes the invoker's queue.

        :param invoker: The ``Invoker`` whose ``services`` (queue, events,
            graph execution manager, images) the worker uses.
        """
        # if we do want multithreading at some point, we could make this configurable
        self.__threadLimit = BoundedSemaphore(1)
        self.__invoker = invoker
        self.__stop_event = Event()
        self.__invoker_thread = Thread(
            name="invoker_processor",
            target=self.__process,
            kwargs=dict(stop_event=self.__stop_event),
        )
        self.__invoker_thread.daemon = True  # TODO: make async and do not use threads
        self.__invoker_thread.start()

    def stop(self, *args, **kwargs) -> None:
        """Ask the worker loop to exit.

        The stop flag is only checked between queue items. If ``queue.get()``
        blocks, the loop presumably needs the queue to return a falsy item
        before it notices the flag — TODO confirm against the queue service.
        """
        self.__stop_event.set()

    def __get_result_urls(self, outputs: Any) -> tuple[Any, Any]:
        """Return ``(image_url, thumbnail_url)`` for an image result.

        Returns ``(None, None)`` when ``outputs`` is not exactly an
        ``ImageOutput``.
        """
        if not _is_image_output(outputs):
            return None, None
        images = self.__invoker.services.images_new
        image_name = outputs.image.image_name
        image_url = images.get_url(ImageType.RESULT, image_name)
        thumbnail_url = images.get_url(ImageType.RESULT, image_name, True)
        return image_url, thumbnail_url

    def __process(self, stop_event: Event):
        """Worker loop: execute queued invocations until ``stop_event`` is set.

        :param stop_event: Event checked before each queue read; once set, the
            loop exits after the current item.
        """
        # ``with`` releases the semaphore only if it was actually acquired;
        # releasing an unacquired BoundedSemaphore raises ValueError.
        with self.__threadLimit:
            try:
                while not stop_event.is_set():
                    # Reset every iteration. Otherwise a failing ``get()``
                    # would leave the previous item in place and re-run it
                    # (or hit an UnboundLocalError on the first pass).
                    queue_item: InvocationQueueItem | None = None
                    try:
                        queue_item = self.__invoker.services.queue.get()
                    except Exception as e:
                        logger.debug("Exception while getting from queue: %s" % e)

                    if not queue_item:  # Probably stopping
                        # do not hammer the queue
                        time.sleep(0.5)
                        continue

                    graph_execution_state = (
                        self.__invoker.services.graph_execution_manager.get(
                            queue_item.graph_execution_state_id
                        )
                    )
                    invocation = graph_execution_state.execution_graph.get_node(
                        queue_item.invocation_id
                    )

                    # get the source node id to provide to clients (the prepared node id is not as useful)
                    source_node_id = graph_execution_state.prepared_source_mapping[
                        invocation.id
                    ]

                    # Send starting event
                    self.__invoker.services.events.emit_invocation_started(
                        graph_execution_state_id=graph_execution_state.id,
                        node=invocation.dict(),
                        source_node_id=source_node_id,
                    )

                    # Invoke
                    try:
                        outputs = invocation.invoke(
                            InvocationContext(
                                services=self.__invoker.services,
                                graph_execution_state_id=graph_execution_state.id,
                            )
                        )

                        # Check queue to see if this is canceled, and skip if so
                        if self.__invoker.services.queue.is_canceled(
                            graph_execution_state.id
                        ):
                            continue

                        # Save outputs and history
                        graph_execution_state.complete(invocation.id, outputs)

                        # Save the state changes
                        self.__invoker.services.graph_execution_manager.set(
                            graph_execution_state
                        )

                        outputs_dict = outputs.dict()
                        image_url, thumbnail_url = self.__get_result_urls(outputs)

                        # Send complete event
                        self.__invoker.services.events.emit_invocation_complete(
                            graph_execution_state_id=graph_execution_state.id,
                            node=invocation.dict(),
                            source_node_id=source_node_id,
                            result=outputs_dict,
                            image_url=image_url,
                            thumbnail_url=thumbnail_url,
                        )

                    except (KeyboardInterrupt, CanceledException):
                        # Interrupted/canceled nodes are dropped silently; the
                        # cancellation check below decides whether to go on.
                        pass

                    except Exception:
                        error = traceback.format_exc()

                        # Save error
                        graph_execution_state.set_node_error(invocation.id, error)

                        # Save the state changes
                        self.__invoker.services.graph_execution_manager.set(
                            graph_execution_state
                        )

                        # Send error event
                        self.__invoker.services.events.emit_invocation_error(
                            graph_execution_state_id=graph_execution_state.id,
                            node=invocation.dict(),
                            source_node_id=source_node_id,
                            error=error,
                        )

                    # Check queue to see if this is canceled, and skip if so
                    if self.__invoker.services.queue.is_canceled(
                        graph_execution_state.id
                    ):
                        continue

                    # Queue any further commands if invoking all
                    is_complete = graph_execution_state.is_complete()
                    if queue_item.invoke_all and not is_complete:
                        try:
                            self.__invoker.invoke(graph_execution_state, invoke_all=True)
                        except Exception as e:
                            logger.error("Error while invoking: %s" % e)
                            self.__invoker.services.events.emit_invocation_error(
                                graph_execution_state_id=graph_execution_state.id,
                                node=invocation.dict(),
                                source_node_id=source_node_id,
                                error=traceback.format_exc(),
                            )
                    elif is_complete:
                        self.__invoker.services.events.emit_graph_execution_complete(
                            graph_execution_state.id
                        )

            except KeyboardInterrupt:
                pass  # Log something? KeyboardInterrupt is probably not going to be seen by the processor