Compare commits


1 Commit

SHA1         Message                                        Date
e745e6d080   feat(nodes): free gpu mem after invocation     2023-05-03 19:28:11 +10:00


@@ -1,11 +1,15 @@
+import gc
 import traceback
 from threading import Event, Thread, BoundedSemaphore
 
+import torch
+
 from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
 from ..models.exceptions import CanceledException
 
+
 class DefaultInvocationProcessor(InvocationProcessorABC):
     __invoker_thread: Thread
     __stop_event: Event
@@ -22,9 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
             target=self.__process,
             kwargs=dict(stop_event=self.__stop_event),
         )
-        self.__invoker_thread.daemon = (
-            True  # TODO: make async and do not use threads
-        )
+        self.__invoker_thread.daemon = True  # TODO: make async and do not use threads
         self.__invoker_thread.start()
 
     def stop(self, *args, **kwargs) -> None:
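
The hunk above only reformats the thread setup, but it shows the processor's concurrency model: the queue is drained by a single daemon worker thread that polls a stop event. A minimal sketch of that pattern, with hypothetical names (Worker, __process), assuming nothing beyond the standard library:

# Sketch: single daemon worker thread with an Event-based stop signal.
# Names here (Worker, __process) are illustrative, not from the diff.
from threading import Event, Thread


class Worker:
    def start(self) -> None:
        self.__stop_event = Event()
        self.__thread = Thread(
            target=self.__process,
            kwargs=dict(stop_event=self.__stop_event),
        )
        # A daemon thread will not keep the interpreter alive on shutdown.
        self.__thread.daemon = True
        self.__thread.start()

    def stop(self) -> None:
        # Signal the loop to exit; the thread finishes its current item.
        self.__stop_event.set()

    def __process(self, stop_event: Event) -> None:
        while not stop_event.is_set():
            pass  # dequeue and run the next invocation here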
@@ -48,13 +50,15 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 )
 
                 # get the source node id to provide to clients (the prepared node id is not as useful)
-                source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
+                source_node_id = graph_execution_state.prepared_source_mapping[
+                    invocation.id
+                ]
 
                 # Send starting event
                 self.__invoker.services.events.emit_invocation_started(
                     graph_execution_state_id=graph_execution_state.id,
                     node=invocation.dict(),
-                    source_node_id=source_node_id
+                    source_node_id=source_node_id,
                 )
 
                 # Invoke
@@ -114,11 +118,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     )
                     pass
 
+                finally:
+                    gc.collect()
+                    torch.cuda.empty_cache()
+
                 # Check queue to see if this is canceled, and skip if so
-                if self.__invoker.services.queue.is_canceled(
-                    graph_execution_state.id
-                ):
+                if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
                     continue
 
                 # Queue any further commands if invoking all
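
The substantive change of the commit is the finally: block above: after every invocation, whether it succeeds, fails, or is canceled, the processor forces a garbage-collection pass and then asks PyTorch to release its cached GPU allocations. A standalone sketch of what the two calls do (the torch.cuda calls are real APIs; the report helper and tensor size are illustrative), assuming a CUDA device is present:

# Sketch: why gc.collect() comes before torch.cuda.empty_cache().
# empty_cache() can only return cached blocks that no live tensor uses,
# so unreachable tensors must be collected first.
import gc

import torch


def report(label: str) -> None:
    # memory_allocated: bytes held by live tensors;
    # memory_reserved: bytes the caching allocator holds from the driver.
    print(
        f"{label}: allocated={torch.cuda.memory_allocated() / 1e6:.1f} MB, "
        f"reserved={torch.cuda.memory_reserved() / 1e6:.1f} MB"
    )


if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")  # ~67 MB of float32
    report("after allocation")

    del x         # drop the last reference to the tensor
    gc.collect()  # collect cycles that may still pin GPU memory
    report("after gc.collect")   # allocated drops; reserved stays cached

    torch.cuda.empty_cache()     # hand cached blocks back to the driver
    report("after empty_cache")  # reserved drops; nvidia-smi shows it freed

Running the cleanup inside a finally: clause, as the commit does, guarantees it happens even when the invocation raises or is canceled, so each queue item starts from a trimmed GPU cache.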