From 78dd4603484a6baf4fc25b4e35ebc8c970df6a08 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 19 Feb 2024 13:09:00 +1100 Subject: [PATCH] tidy(nodes): clean up profiler/stats in processor, better comments --- .../session_processor_default.py | 65 ++++++++++--------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 9ba726bff3..cff7bb6c6c 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -37,6 +37,18 @@ class DefaultSessionProcessor(SessionProcessorBase): self._thread_semaphore = BoundedSemaphore(thread_limit) self._polling_interval = polling_interval + # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, + # the profiler will create a new profile for each session. + self._profiler = ( + Profiler( + logger=self._invoker.services.logger, + output_dir=self._invoker.services.configuration.profiles_path, + prefix=self._invoker.services.configuration.profile_prefix, + ) + if self._invoker.services.configuration.profile_graphs + else None + ) + self._thread = Thread( name="session_processor", target=self._process, @@ -95,32 +107,6 @@ class DefaultSessionProcessor(SessionProcessorBase): resume_event.set() cancel_event.clear() - # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, - # the profiler will create a new profile for each session. - profiler = ( - Profiler( - logger=self._invoker.services.logger, - output_dir=self._invoker.services.configuration.profiles_path, - prefix=self._invoker.services.configuration.profile_prefix, - ) - if self._invoker.services.configuration.profile_graphs - else None - ) - - # Helper function to stop the profiler and save the stats - def stats_cleanup(graph_execution_state_id: str) -> None: - if profiler: - profile_path = profiler.stop() - stats_path = profile_path.with_suffix(".json") - self._invoker.services.performance_statistics.dump_stats( - graph_execution_state_id=graph_execution_state_id, output_path=stats_path - ) - # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor - # we don't care about that - suppress the error. - with suppress(GESStatsNotFoundError): - self._invoker.services.performance_statistics.log_stats(graph_execution_state_id) - self._invoker.services.performance_statistics.reset_stats() - while not stop_event.is_set(): poll_now_event.clear() # Middle processor try block; any unhandled exception is a non-fatal processor error @@ -132,8 +118,8 @@ class DefaultSessionProcessor(SessionProcessorBase): cancel_event.clear() # If profiling is enabled, start the profiler - if profiler is not None: - profiler.start(profile_id=self._queue_item.session_id) + if self._profiler is not None: + self._profiler.start(profile_id=self._queue_item.session_id) # Prepare invocations and take the first self._invocation = self._queue_item.session.next() @@ -228,6 +214,7 @@ class DefaultSessionProcessor(SessionProcessorBase): ) pass + # The session is complete if the all invocations are complete or there was an error if self._queue_item.session.is_complete() or cancel_event.is_set(): # Send complete event self._invoker.services.events.emit_graph_execution_complete( @@ -236,8 +223,20 @@ class DefaultSessionProcessor(SessionProcessorBase): queue_id=self._queue_item.queue_id, graph_execution_state_id=self._queue_item.session.id, ) - # Save the stats and stop the profiler if it's running - stats_cleanup(self._queue_item.session.id) + # If we are profiling, stop the profiler and dump the profile & stats + if self._profiler: + profile_path = self._profiler.stop() + stats_path = profile_path.with_suffix(".json") + self._invoker.services.performance_statistics.dump_stats( + graph_execution_state_id=self._queue_item.session.id, output_path=stats_path + ) + # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor + # we don't care about that - suppress the error. + with suppress(GESStatsNotFoundError): + self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id) + self._invoker.services.performance_statistics.reset_stats() + + # Set the invocation to None to prepare for the next session self._invocation = None else: # Prepare the next invocation @@ -252,14 +251,18 @@ class DefaultSessionProcessor(SessionProcessorBase): poll_now_event.wait(self._polling_interval) continue except Exception: - # Non-fatal error in processor, cancel the queue item and wait for next polling interval or event + # Non-fatal error in processor self._invoker.services.logger.error( f"Non-fatal error in session processor:\n{traceback.format_exc()}" ) + # Cancel the queue item if self._queue_item is not None: self._invoker.services.session_queue.cancel_queue_item( self._queue_item.item_id, error=traceback.format_exc() ) + # Reset the invocation to None to prepare for the next session + self._invocation = None + # Immediately poll for next queue item poll_now_event.wait(self._polling_interval) continue except Exception: