Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
d14a7d756e
On hyperthreaded CPUs we get two threads operating on the queue by default on each core, so two threads end up processing queue items concurrently. This results in PyTorch errors and sometimes generates garbage output. Locking this to a single thread makes sense because we are bound by the number of GPUs in the system, not by CPU cores; to parallelize across GPUs we should instead start multiple processors (and use async instead of threading). Fixes #3289
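
In isolation, the pattern this commit applies looks like the sketch below: a BoundedSemaphore(1) wrapped around the queue-draining loop, so that even if multiple worker threads are started, only one at a time can run GPU-bound work. All names in the sketch are illustrative; the actual change is in the file that follows.

    import queue
    import time
    from threading import BoundedSemaphore, Event, Thread

    work_queue: queue.Queue = queue.Queue()
    stop_event = Event()
    worker_limit = BoundedSemaphore(1)  # at most one active consumer

    def process() -> None:
        # The first thread to acquire the semaphore drains the queue; any other
        # thread blocks here, so no two threads ever run GPU-bound work at once.
        with worker_limit:
            while not stop_event.is_set():
                try:
                    item = work_queue.get(timeout=0.1)
                except queue.Empty:
                    continue
                print(f"processing {item}")  # stand-in for the GPU-bound invocation

    # Start two workers, as a hyperthreaded core might; the semaphore serializes them.
    for _ in range(2):
        Thread(target=process, daemon=True).start()

    for i in range(3):
        work_queue.put(i)
    time.sleep(1)
    stop_event.set()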
137 lines
5.2 KiB
Python
import traceback
from threading import Event, Thread, BoundedSemaphore

from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..models.exceptions import CanceledException


class DefaultInvocationProcessor(InvocationProcessorABC):
    __invoker_thread: Thread
    __stop_event: Event
    __invoker: Invoker
    __threadLimit: BoundedSemaphore

    def start(self, invoker) -> None:
        # if we do want multithreading at some point, we could make this configurable
        self.__threadLimit = BoundedSemaphore(1)
        self.__invoker = invoker
        self.__stop_event = Event()
        self.__invoker_thread = Thread(
            name="invoker_processor",
            target=self.__process,
            kwargs=dict(stop_event=self.__stop_event),
        )
        self.__invoker_thread.daemon = (
            True  # TODO: make async and do not use threads
        )
        self.__invoker_thread.start()

    def stop(self, *args, **kwargs) -> None:
        self.__stop_event.set()

    def __process(self, stop_event: Event):
        try:
            self.__threadLimit.acquire()
            while not stop_event.is_set():
                queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                if not queue_item:  # Probably stopping
                    continue

                graph_execution_state = (
                    self.__invoker.services.graph_execution_manager.get(
                        queue_item.graph_execution_state_id
                    )
                )
                invocation = graph_execution_state.execution_graph.get_node(
                    queue_item.invocation_id
                )

                # get the source node id to provide to clients (the prepared node id is not as useful)
                source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]

                # Send starting event
                self.__invoker.services.events.emit_invocation_started(
                    graph_execution_state_id=graph_execution_state.id,
                    node=invocation.dict(),
                    source_node_id=source_node_id,
                )

                # Invoke
                try:
                    outputs = invocation.invoke(
                        InvocationContext(
                            services=self.__invoker.services,
                            graph_execution_state_id=graph_execution_state.id,
                        )
                    )

                    # Check queue to see if this is canceled, and skip if so
                    if self.__invoker.services.queue.is_canceled(
                        graph_execution_state.id
                    ):
                        continue

                    # Save outputs and history
                    graph_execution_state.complete(invocation.id, outputs)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(
                        graph_execution_state
                    )

                    # Send complete event
                    self.__invoker.services.events.emit_invocation_complete(
                        graph_execution_state_id=graph_execution_state.id,
                        node=invocation.dict(),
                        source_node_id=source_node_id,
                        result=outputs.dict(),
                    )

                except KeyboardInterrupt:
                    pass

                except CanceledException:
                    pass

                except Exception:
                    error = traceback.format_exc()

                    # Save error
                    graph_execution_state.set_node_error(invocation.id, error)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(
                        graph_execution_state
                    )

                    # Send error event
                    self.__invoker.services.events.emit_invocation_error(
                        graph_execution_state_id=graph_execution_state.id,
                        node=invocation.dict(),
                        source_node_id=source_node_id,
                        error=error,
                    )

                # Check queue to see if this is canceled, and skip if so
                if self.__invoker.services.queue.is_canceled(
                    graph_execution_state.id
                ):
                    continue

                # Queue any further commands if invoking all
                is_complete = graph_execution_state.is_complete()
                if queue_item.invoke_all and not is_complete:
                    self.__invoker.invoke(graph_execution_state, invoke_all=True)
                elif is_complete:
                    self.__invoker.services.events.emit_graph_execution_complete(
                        graph_execution_state.id
                    )

        except KeyboardInterrupt:
            pass  # Log something? KeyboardInterrupt is probably not going to be seen by the processor

        finally:
            self.__threadLimit.release()
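
The commit message's longer-term suggestion (scale across GPUs by starting multiple processors, using async rather than threads) might look roughly like the hypothetical sketch below. Nothing in it exists in InvokeAI; every name is made up for illustration.

    import asyncio
    import random

    async def processor(gpu_id: int, work: "asyncio.Queue[int]") -> None:
        while True:
            item = await work.get()
            # Stand-in for the GPU-bound invocation pinned to this device.
            await asyncio.sleep(random.uniform(0.01, 0.05))
            print(f"GPU {gpu_id} finished item {item}")
            work.task_done()

    async def main(num_gpus: int = 2) -> None:
        work: "asyncio.Queue[int]" = asyncio.Queue()
        # One processor task per GPU: parallelism is capped by GPU count,
        # not by how many CPU threads the runtime happens to start.
        tasks = [asyncio.create_task(processor(i, work)) for i in range(num_gpus)]
        for item in range(6):
            work.put_nowait(item)
        await work.join()  # wait for every queued item to be processed
        for task in tasks:
            task.cancel()  # the processors loop forever otherwise

    asyncio.run(main())

Cancellation and error events would still need the per-item handling that __process does today; the sketch only shows the scheduling shape.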