feat(nodes): free gpu mem after invocation

psychedelicious 2023-05-03 19:27:06 +10:00
parent f7bbc4004a
commit a75148cb16


@@ -1,11 +1,15 @@
+import gc
 import traceback
 from threading import Event, Thread, BoundedSemaphore
+
+import torch
+
 from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
 from ..models.exceptions import CanceledException


 class DefaultInvocationProcessor(InvocationProcessorABC):
     __invoker_thread: Thread
     __stop_event: Event
@@ -22,9 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
             target=self.__process,
             kwargs=dict(stop_event=self.__stop_event),
         )
-        self.__invoker_thread.daemon = (
-            True  # TODO: make async and do not use threads
-        )
+        self.__invoker_thread.daemon = True  # TODO: make async and do not use threads
         self.__invoker_thread.start()

     def stop(self, *args, **kwargs) -> None:
@@ -48,13 +50,15 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 )

                 # get the source node id to provide to clients (the prepared node id is not as useful)
-                source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
+                source_node_id = graph_execution_state.prepared_source_mapping[
+                    invocation.id
+                ]

                 # Send starting event
                 self.__invoker.services.events.emit_invocation_started(
                     graph_execution_state_id=graph_execution_state.id,
                     node=invocation.dict(),
-                    source_node_id=source_node_id
+                    source_node_id=source_node_id,
                 )

                 # Invoke
@@ -114,11 +118,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     )
                     pass

+                finally:
+                    gc.collect()
+                    torch.cuda.empty_cache()
+
                 # Check queue to see if this is canceled, and skip if so
-                if self.__invoker.services.queue.is_canceled(
-                    graph_execution_state.id
-                ):
+                if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
                     continue

                 # Queue any further commands if invoking all
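
The substantive change is the new finally block: after every invocation, whether it completes, is canceled, or raises, gc.collect() first reclaims unreferenced Python objects (and with them any CUDA tensors they were keeping alive), and torch.cuda.empty_cache() then returns the caching allocator's now-unused blocks to the driver, so GPU memory does not stay reserved between invocations. A minimal standalone sketch of the same pattern, with a hypothetical run_job() standing in for invocation.invoke(...):

import gc

import torch


def run_job() -> None:
    # Hypothetical stand-in for invocation.invoke(...): allocates a GPU
    # tensor when CUDA is available, a CPU tensor otherwise.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    _ = torch.randn(1024, 1024, device=device)


def process_once() -> None:
    try:
        run_job()
    finally:
        # Collect first, so tensors dropped by Python become deallocatable...
        gc.collect()
        # ...then release the allocator's unused cached blocks back to the
        # driver, making that memory visible as free to other processes.
        torch.cuda.empty_cache()


process_once()

The ordering matters: empty_cache() can only release blocks that no live tensor occupies, so collecting garbage first maximizes what it can return. And because empty_cache() is a no-op when CUDA was never initialized, the unguarded call is also safe on CPU-only hosts, matching the commit's unconditional call.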