integrate correctly into app API and add features

- Create abstract base class InvocationStatsServiceBase
- Store InvocationStatsService in the InvocationServices object
- Collect and report stats on simultaneous graph execution
  independently for each graph id
- Track VRAM usage for each node
- Handle cancellations and other exceptions gracefully
This commit is contained in:
Lincoln Stein
2023-08-02 18:10:52 -04:00
parent 8a4e5f73aa
commit 8fc75a71ee
4 changed files with 182 additions and 69 deletions

View File

@ -2,7 +2,6 @@
from typing import Optional from typing import Optional
from logging import Logger from logging import Logger
import os
from invokeai.app.services.board_image_record_storage import ( from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage, SqliteBoardImageRecordStorage,
) )
@ -30,6 +29,7 @@ from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService from ..services.model_manager_service import ModelManagerService
from ..services.invocation_stats import InvocationStatsService
from .events import FastAPIEventService from .events import FastAPIEventService
@ -128,6 +128,7 @@ class ApiDependencies:
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
configuration=config, configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger, logger=logger,
) )

View File

@ -32,6 +32,7 @@ class InvocationServices:
logger: "Logger" logger: "Logger"
model_manager: "ModelManagerServiceBase" model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC" queue: "InvocationQueueABC"
def __init__( def __init__(
@ -47,6 +48,7 @@ class InvocationServices:
logger: "Logger", logger: "Logger",
model_manager: "ModelManagerServiceBase", model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC", processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC", queue: "InvocationQueueABC",
): ):
self.board_images = board_images self.board_images = board_images
@ -61,4 +63,5 @@ class InvocationServices:
self.logger = logger self.logger = logger
self.model_manager = model_manager self.model_manager = model_manager
self.processor = processor self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue self.queue = queue

View File

@ -3,99 +3,196 @@
""" """
Usage: Usage:
statistics = InvocationStats() # keep track of performance metrics
... statistics = InvocationStatsService(graph_execution_manager)
with statistics.collect_stats(invocation, graph_execution_state): with statistics.collect_stats(invocation, graph_execution_state.id):
outputs = invocation.invoke( ... execute graphs...
InvocationContext(
services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id,
)
)
...
statistics.log_stats() statistics.log_stats()
Typical output: Typical output:
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Node Calls Seconds [2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> main_model_loader 1 0.006s [2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> clip_skip 1 0.005s [2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> compel 2 0.351s [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> rand_int 1 0.001s [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> range_of_size 1 0.001s [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> iterate 1 0.001s [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> noise 1 0.002s [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> t2l 1 3.117s [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> l2i 1 0.377s [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> TOTAL: 3.865s [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Max VRAM used for execution: 3.12G. [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
[2023-08-01 17:34:44,586]::[InvokeAI]::INFO --> Current VRAM utilization 2.31G. [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
writes to the system log is stored in InvocationServices.performance_statistics.
""" """
import time import time
from typing import Dict, List from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from typing import Dict
import torch import torch
from .graph import GraphExecutionState
from .invocation_queue import InvocationQueueItem
from ..invocations.baseinvocation import BaseInvocation
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
class InvocationStats:
class InvocationStatsServiceBase(ABC):
    """Abstract interface for services that record node memory/time performance statistics."""

    @abstractmethod
    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
        """
        Set up the statistics service with all counters reset to zero.

        :param graph_execution_manager: Graph execution manager for this session
        """

    @abstractmethod
    def collect_stats(
        self,
        invocation: BaseInvocation,
        graph_execution_state_id: str,
    ) -> AbstractContextManager:
        """
        Return a context manager that captures statistics on the execution of an invocation.

        Place it in a ``with`` statement around the code that executes the invocation.

        :param invocation: BaseInvocation object from the current graph.
        :param graph_execution_state_id: ID of the graph execution state the invocation runs under.
        """

    @abstractmethod
    def reset_stats(self, graph_execution_state_id: str):
        """
        Discard all statistics held for the indicated graph.

        :param graph_execution_state_id: ID of the graph whose statistics are reset.
        """

    @abstractmethod
    def reset_all_stats(self):
        """Zero the statistics of every graph."""

    @abstractmethod
    def update_invocation_stats(
        self,
        graph_id: str,
        invocation_type: str,
        time_used: float,
        vram_used: float,
    ):
        """
        Record timing and memory information for one execution of a node.

        Usually used internally.

        :param graph_id: ID of the graph that is currently executing
        :param invocation_type: String literal type of the node
        :param time_used: Time used by the node's execution (sec)
        :param vram_used: Maximum VRAM used during execution (GB)
        """

    @abstractmethod
    def log_stats(self):
        """Write out the accumulated statistics to the log or somewhere else."""
@dataclass
class NodeStats:
    """Execution statistics tracked for an invocation node."""

    # number of calls recorded for the node
    calls: int = 0
    # time used, in seconds
    time_used: float = 0.0
    # VRAM figure for the node, in GB
    max_vram: float = 0.0
@dataclass
class NodeLog:
    """Container tracking node usage, keyed by node type."""

    # maps node_type -> NodeStats; each instance starts with its own empty dict
    nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsService(InvocationStatsServiceBase):
"""Accumulate performance information about a running graph. Collects time spent in each node, """Accumulate performance information about a running graph. Collects time spent in each node,
as well as the maximum and current VRAM utilisation for CUDA systems""" as well as the maximum and current VRAM utilisation for CUDA systems"""
def __init__(self): def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
self._stats: Dict[str, int] = {} self.graph_execution_manager = graph_execution_manager
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
class StatsContext: class StatsContext:
def __init__(self, invocation: BaseInvocation, collector): def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
self.invocation = invocation self.invocation = invocation
self.collector = collector self.collector = collector
self.graph_id = graph_id
self.start_time = 0 self.start_time = 0
def __enter__(self): def __enter__(self):
self.start_time = time.time() self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
def __exit__(self, *args): def __exit__(self, *args):
self.collector.log_time(self.invocation.type, time.time() - self.start_time) self.collector.update_invocation_stats(
self.graph_id,
self.invocation.type,
time.time() - self.start_time,
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
)
def collect_stats( def collect_stats(
self, self,
invocation: BaseInvocation, invocation: BaseInvocation,
graph_execution_state: GraphExecutionState, graph_execution_state_id: str,
) -> StatsContext: ) -> StatsContext:
""" """
Return a context object that will capture the statistics. Return a context object that will capture the statistics.
:param invocation: BaseInvocation object from the current graph. :param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session. :param graph_execution_state: GraphExecutionState object from the current session.
""" """
if len(graph_execution_state.executed) == 0: # new graph is starting if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self.reset_stats() self._stats[graph_execution_state_id] = NodeLog()
self._current_graph_state = graph_execution_state return self.StatsContext(invocation, graph_execution_state_id, self)
sc = self.StatsContext(invocation, self)
return self.StatsContext(invocation, self)
def reset_stats(self): def reset_all_stats(self):
"""Zero the statistics. Ordinarily called internally.""" """Zero all statistics"""
if torch.cuda.is_available(): self._stats = {}
torch.cuda.reset_peak_memory_stats()
self._stats: Dict[str, List[int, float]] = {}
def log_time(self, invocation_type: str, time_used: float): def reset_stats(self, graph_execution_id: str):
"""Zero the statistics for the indicated graph."""
try:
self._stats.pop(graph_execution_id)
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
""" """
Add timing information on execution of a node. Usually Add timing information on execution of a node. Usually
used internally. used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node :param invocation_type: String literal type of the node
:param time_used: Floating point seconds used by node's execution :param time_used: Floating point seconds used by node's execution
""" """
if not self._stats.get(invocation_type): if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[invocation_type] = [0, 0.0] self._stats[graph_id].nodes[invocation_type] = NodeStats()
self._stats[invocation_type][0] += 1 stats = self._stats[graph_id].nodes[invocation_type]
self._stats[invocation_type][1] += time_used stats.calls += 1
stats.time_used += time_used
stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self): def log_stats(self):
""" """
@ -103,13 +200,24 @@ class InvocationStats:
Stats will only be printed when the execution of the graph Stats will only be printed when the execution of the graph
is complete. is complete.
""" """
if self._current_graph_state.is_complete(): completed = set()
logger.info("Node Calls Seconds") for graph_id, node_log in self._stats.items():
for node_type, (calls, time_used) in self._stats.items(): current_graph_state = self.graph_execution_manager.get(graph_id)
logger.info(f"{node_type:<20} {calls:>5} {time_used:4.3f}s") if not current_graph_state.is_complete():
continue
total_time = sum([ticks for _, ticks in self._stats.values()]) total_time = 0
logger.info(f"TOTAL: {total_time:4.3f}s") logger.info(f"Graph stats: {graph_id}")
logger.info("Node Calls Seconds VRAM Used")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:4.3f}s {stats.max_vram:4.2f}G")
total_time += stats.time_used
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:4.3f}s")
if torch.cuda.is_available(): if torch.cuda.is_available():
logger.info("Max VRAM used for execution: " + "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9))
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
completed.add(graph_id)
for graph_id in completed:
del self._stats[graph_id]

View File

@ -1,15 +1,15 @@
import time import time
import traceback import traceback
from threading import Event, Thread, BoundedSemaphore from threading import BoundedSemaphore, Event, Thread
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from .invocation_stats import InvocationStats
from ..models.exceptions import CanceledException
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import InvocationContext
from ..models.exceptions import CanceledException
from .invocation_queue import InvocationQueueItem
from .invocation_stats import InvocationStatsServiceBase
from .invoker import InvocationProcessorABC, Invoker
class DefaultInvocationProcessor(InvocationProcessorABC): class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread __invoker_thread: Thread
@ -36,7 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event): def __process(self, stop_event: Event):
try: try:
self.__threadLimit.acquire() self.__threadLimit.acquire()
statistics = InvocationStats() # keep track of performance metrics statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get() queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
@ -85,7 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke # Invoke
try: try:
with statistics.collect_stats(invocation, graph_execution_state): with statistics.collect_stats(invocation, graph_execution_state.id):
outputs = invocation.invoke( outputs = invocation.invoke(
InvocationContext( InvocationContext(
services=self.__invoker.services, services=self.__invoker.services,
@ -116,7 +117,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
pass pass
except CanceledException: except CanceledException:
statistics.reset_stats() statistics.reset_stats(graph_execution_state.id)
pass pass
except Exception as e: except Exception as e:
@ -138,7 +139,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error_type=e.__class__.__name__, error_type=e.__class__.__name__,
error=error, error=error,
) )
statistics.reset_stats() statistics.reset_stats(graph_execution_state.id)
pass pass
# Check queue to see if this is canceled, and skip if so # Check queue to see if this is canceled, and skip if so