mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Reduce the number of graph_execution_manager.get(...) calls from the InvocationStatsService.
This commit is contained in:
parent
ac42513da9
commit
aa45d21fd2
@@ -132,7 +132,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
result=outputs.model_dump(),
|
result=outputs.model_dump(),
|
||||||
)
|
)
|
||||||
self.__invoker.services.performance_statistics.log_stats()
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
@@ -195,6 +194,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error=traceback.format_exc(),
|
error=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
elif is_complete:
|
elif is_complete:
|
||||||
|
self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
|
||||||
self.__invoker.services.events.emit_graph_execution_complete(
|
self.__invoker.services.events.emit_graph_execution_complete(
|
||||||
queue_batch_id=queue_item.session_queue_batch_id,
|
queue_batch_id=queue_item.session_queue_batch_id,
|
||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
|
@@ -67,7 +67,7 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def log_stats(self):
|
def log_stats(self, graph_execution_state_id: str):
|
||||||
"""
|
"""
|
||||||
Write out the accumulated statistics to the log or somewhere else.
|
Write out the accumulated statistics to the log or somewhere else.
|
||||||
"""
|
"""
|
||||||
|
@@ -36,6 +36,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
self._stats[graph_execution_state_id] = GraphExecutionStats()
|
self._stats[graph_execution_state_id] = GraphExecutionStats()
|
||||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||||
|
|
||||||
|
# Prune stale stats. There should be none since we're starting a new graph, but just in case.
|
||||||
|
self._prune_stale_stats()
|
||||||
|
|
||||||
# Record state before the invocation.
|
# Record state before the invocation.
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
start_ram = psutil.Process().memory_info().rss
|
start_ram = psutil.Process().memory_info().rss
|
||||||
@@ -59,49 +62,60 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
)
|
)
|
||||||
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
|
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
|
||||||
|
|
||||||
def reset_stats(self, graph_execution_id: str):
|
def _prune_stale_stats(self):
|
||||||
try:
|
"""Check all graphs being tracked and prune any that have completed/errored.
|
||||||
self._stats.pop(graph_execution_id)
|
|
||||||
except KeyError:
|
|
||||||
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
|
|
||||||
|
|
||||||
def log_stats(self):
|
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
|
||||||
completed = set()
|
for now we call this function periodically to prevent them from accumulating.
|
||||||
errored = set()
|
"""
|
||||||
for graph_id, _node_log in self._stats.items():
|
to_prune = []
|
||||||
|
for graph_execution_state_id in self._stats:
|
||||||
try:
|
try:
|
||||||
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
|
graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
errored.add(graph_id)
|
# TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
|
||||||
|
logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not current_graph_state.is_complete():
|
if not graph_execution_state.is_complete():
|
||||||
|
# The graph is still running, don't prune it.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
graph_stats = self._stats[graph_id]
|
to_prune.append(graph_execution_state_id)
|
||||||
log = graph_stats.get_pretty_log(graph_id)
|
|
||||||
|
|
||||||
cache_stats = self._cache_stats[graph_id]
|
for graph_execution_state_id in to_prune:
|
||||||
hwm = cache_stats.high_watermark / GB
|
del self._stats[graph_execution_state_id]
|
||||||
tot = cache_stats.cache_size / GB
|
del self._cache_stats[graph_execution_state_id]
|
||||||
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GB
|
|
||||||
log += f"RAM used to load models: {loaded:4.2f}G\n"
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
log += f"VRAM in use: {(torch.cuda.memory_allocated() / GB):4.3f}G\n"
|
|
||||||
log += "RAM cache statistics:\n"
|
|
||||||
log += f" Model cache hits: {cache_stats.hits}\n"
|
|
||||||
log += f" Model cache misses: {cache_stats.misses}\n"
|
|
||||||
log += f" Models cached: {cache_stats.in_cache}\n"
|
|
||||||
log += f" Models cleared from cache: {cache_stats.cleared}\n"
|
|
||||||
log += f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
|
|
||||||
logger.info(log)
|
|
||||||
|
|
||||||
completed.add(graph_id)
|
if len(to_prune) > 0:
|
||||||
|
logger.info(f"Pruned stale graph stats for {to_prune}.")
|
||||||
|
|
||||||
for graph_id in completed:
|
def reset_stats(self, graph_execution_state_id: str):
|
||||||
del self._stats[graph_id]
|
try:
|
||||||
del self._cache_stats[graph_id]
|
del self._stats[graph_execution_state_id]
|
||||||
|
del self._cache_stats[graph_execution_state_id]
|
||||||
|
except KeyError as e:
|
||||||
|
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}.")
|
||||||
|
|
||||||
for graph_id in errored:
|
def log_stats(self, graph_execution_state_id: str):
|
||||||
del self._stats[graph_id]
|
graph_stats = self._stats[graph_execution_state_id]
|
||||||
del self._cache_stats[graph_id]
|
cache_stats = self._cache_stats[graph_execution_state_id]
|
||||||
|
|
||||||
|
log = graph_stats.get_pretty_log(graph_execution_state_id)
|
||||||
|
|
||||||
|
hwm = cache_stats.high_watermark / GB
|
||||||
|
tot = cache_stats.cache_size / GB
|
||||||
|
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GB
|
||||||
|
log += f"RAM used to load models: {loaded:4.2f}G\n"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log += f"VRAM in use: {(torch.cuda.memory_allocated() / GB):4.3f}G\n"
|
||||||
|
log += "RAM cache statistics:\n"
|
||||||
|
log += f" Model cache hits: {cache_stats.hits}\n"
|
||||||
|
log += f" Model cache misses: {cache_stats.misses}\n"
|
||||||
|
log += f" Models cached: {cache_stats.in_cache}\n"
|
||||||
|
log += f" Models cleared from cache: {cache_stats.cleared}\n"
|
||||||
|
log += f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
|
||||||
|
logger.info(log)
|
||||||
|
|
||||||
|
del self._stats[graph_execution_state_id]
|
||||||
|
del self._cache_stats[graph_execution_state_id]
|
||||||
|
Loading…
Reference in New Issue
Block a user