integrate correctly into app API and add features

- Create abstract base class InvocationStatsServiceBase
- Store InvocationStatsService in the InvocationServices object
- Collect and report stats on simultaneous graph execution
  independently for each graph id
- Track VRAM usage for each node
- Handle cancellations and other exceptions gracefully
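
A rough sketch of the usage pattern these changes enable, mirroring the updated module docstring further down. The `invocation`, `graph_execution_state`, `services`, and `graph_execution_manager` names are assumed to already be in scope, as they are inside the processor loop:

statistics = InvocationStatsService(graph_execution_manager)  # graph_execution_manager: ItemStorageABC["GraphExecutionState"]
with statistics.collect_stats(invocation, graph_execution_state.id):
    outputs = invocation.invoke(
        InvocationContext(
            services=services,
            graph_execution_state_id=graph_execution_state.id,
        )
    )
statistics.log_stats()  # prints per-node calls/seconds/VRAM once the graph is complete
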
This commit is contained in:
Lincoln Stein 2023-08-02 18:10:52 -04:00
parent 8a4e5f73aa
commit 8fc75a71ee
4 changed files with 182 additions and 69 deletions

View File

@@ -2,7 +2,6 @@
from typing import Optional
from logging import Logger
import os
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
@@ -30,6 +29,7 @@ from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService
from ..services.invocation_stats import InvocationStatsService
from .events import FastAPIEventService
@@ -128,6 +128,7 @@ class ApiDependencies:
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
)

View File

@@ -32,6 +32,7 @@ class InvocationServices:
logger: "Logger"
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
def __init__(
@@ -47,6 +48,7 @@ class InvocationServices:
logger: "Logger",
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
):
self.board_images = board_images
@@ -61,4 +63,5 @@
self.logger = logger
self.model_manager = model_manager
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue
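
With the field in place, any consumer of InvocationServices reaches the same collector instance instead of building its own, which is what lets per-graph statistics accumulate in one spot. A minimal sketch, assuming the attribute names above:

stats = services.performance_statistics  # shared InvocationStatsServiceBase instance
# the same object is visible from inside a node via its InvocationContext:
# context.services.performance_statistics is stats
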

View File

@@ -3,99 +3,196 @@
"""
Usage:
statistics = InvocationStats() # keep track of performance metrics
...
with statistics.collect_stats(invocation, graph_execution_state):
outputs = invocation.invoke(
InvocationContext(
services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id,
)
)
...
statistics = InvocationStatsService(graph_execution_manager)
with statistics.collect_stats(invocation, graph_execution_state.id):
... execute graphs...
statistics.log_stats()
Typical output:
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Node Calls Seconds
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> main_model_loader 1 0.006s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> clip_skip 1 0.005s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> compel 2 0.351s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> rand_int 1 0.001s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> range_of_size 1 0.001s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> iterate 1 0.001s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> noise 1 0.002s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> t2l 1 3.117s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> l2i 1 0.377s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> TOTAL: 3.865s
[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Max VRAM used for execution: 3.12G.
[2023-08-01 17:34:44,586]::[InvokeAI]::INFO --> Current VRAM utilization 2.31G.
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
writes to the system log is stored in InvocationServices.performance_statistics.
"""
import time
from typing import Dict, List
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from typing import Dict
import torch
from .graph import GraphExecutionState
from .invocation_queue import InvocationQueueItem
from ..invocations.baseinvocation import BaseInvocation
import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
class InvocationStats:
class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"
@abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
"""
Initialize the InvocationStatsService and reset counters to zero
:param graph_execution_manager: Graph execution manager for this session
"""
pass
@abstractmethod
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> AbstractContextManager:
"""
Return a context object that will capture the statistics on the execution
of the invocation. Wrap the part of the code that executes the invocation in a with-statement.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state_id: The id of the GraphExecutionState for the current session.
"""
pass
@abstractmethod
def reset_stats(self, graph_execution_state_id: str):
"""
Reset all statistics for the indicated graph
:param graph_execution_state_id: The id of the graph whose statistics are to be reset
"""
pass
@abstractmethod
def reset_all_stats(self):
"""Zero all statistics"""
pass
@abstractmethod
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's execution (sec)
:param vram_used: Maximum VRAM used during execution (GB)
"""
pass
@abstractmethod
def log_stats(self):
"""
Write out the accumulated statistics to the log or somewhere else.
"""
pass
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsService(InvocationStatsServiceBase):
"""Accumulate performance information about a running graph. Collects time spent in each node,
as well as the maximum and current VRAM utilisation for CUDA systems"""
def __init__(self):
self._stats: Dict[str, int] = {}
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
self.graph_execution_manager = graph_execution_manager
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
class StatsContext:
def __init__(self, invocation: BaseInvocation, collector):
def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
def __exit__(self, *args):
self.collector.log_time(self.invocation.type, time.time() - self.start_time)
self.collector.update_invocation_stats(
self.graph_id,
self.invocation.type,
time.time() - self.start_time,
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
)
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state: GraphExecutionState,
graph_execution_state_id: str,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state_id: The id of the GraphExecutionState for the current session.
"""
if len(graph_execution_state.executed) == 0: # new graph is starting
self.reset_stats()
self._current_graph_state = graph_execution_state
sc = self.StatsContext(invocation, self)
return self.StatsContext(invocation, self)
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
return self.StatsContext(invocation, graph_execution_state_id, self)
def reset_stats(self):
"""Zero the statistics. Ordinarily called internally."""
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self._stats: Dict[str, List[int, float]] = {}
def reset_all_stats(self):
"""Zero all statistics"""
self._stats = {}
def log_time(self, invocation_type: str, time_used: float):
def reset_stats(self, graph_execution_id: str):
"""Zero the statistics for the indicated graph."""
try:
self._stats.pop(graph_execution_id)
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Floating point seconds used by node's execution
:param vram_used: Maximum VRAM used during execution (GB)
"""
if not self._stats.get(invocation_type):
self._stats[invocation_type] = [0, 0.0]
self._stats[invocation_type][0] += 1
self._stats[invocation_type][1] += time_used
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type]
stats.calls += 1
stats.time_used += time_used
stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self):
"""
@@ -103,13 +200,24 @@ class InvocationStats:
Stats will only be printed when the execution of the graph is complete.
"""
if self._current_graph_state.is_complete():
logger.info("Node Calls Seconds")
for node_type, (calls, time_used) in self._stats.items():
logger.info(f"{node_type:<20} {calls:>5} {time_used:4.3f}s")
completed = set()
for graph_id, node_log in self._stats.items():
current_graph_state = self.graph_execution_manager.get(graph_id)
if not current_graph_state.is_complete():
continue
total_time = sum([ticks for _, ticks in self._stats.values()])
logger.info(f"TOTAL: {total_time:4.3f}s")
total_time = 0
logger.info(f"Graph stats: {graph_id}")
logger.info("Node Calls Seconds VRAM Used")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:4.3f}s {stats.max_vram:4.2f}G")
total_time += stats.time_used
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:4.3f}s")
if torch.cuda.is_available():
logger.info("Max VRAM used for execution: " + "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9))
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
completed.add(graph_id)
for graph_id in completed:
del self._stats[graph_id]
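
The StatsContext above is what makes the per-node VRAM column possible: it clears torch's peak-allocation counter on entry and reads it back on exit. A self-contained sketch of the same measurement pattern (TimedVRAMBlock is an illustrative name, not part of the codebase):

import time
import torch

class TimedVRAMBlock:
    """Illustration only: reset the CUDA peak-memory counter on entry,
    record elapsed seconds and peak allocation (GB) on exit."""

    def __enter__(self):
        self.start = time.time()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        return self

    def __exit__(self, *exc):
        self.seconds = time.time() - self.start
        self.vram_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
        return False  # never suppress exceptions from the wrapped code

with TimedVRAMBlock() as block:
    pass  # run one node here
# block.seconds and block.vram_gb correspond to the time_used / vram_used
# arguments that StatsContext passes to update_invocation_stats()

Because the peak counter is device-global rather than per-stream, nodes from graphs executing simultaneously on the same GPU will see each other's allocations in their peaks; the numbers are a reliable ceiling rather than an exact per-node figure.
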

View File

@@ -1,15 +1,15 @@
import time
import traceback
from threading import Event, Thread, BoundedSemaphore
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from .invocation_stats import InvocationStats
from ..models.exceptions import CanceledException
from threading import BoundedSemaphore, Event, Thread
import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import InvocationContext
from ..models.exceptions import CanceledException
from .invocation_queue import InvocationQueueItem
from .invocation_stats import InvocationStatsServiceBase
from .invoker import InvocationProcessorABC, Invoker
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
@@ -36,7 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
statistics = InvocationStats() # keep track of performance metrics
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
while not stop_event.is_set():
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
@@ -85,7 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
with statistics.collect_stats(invocation, graph_execution_state):
with statistics.collect_stats(invocation, graph_execution_state.id):
outputs = invocation.invoke(
InvocationContext(
services=self.__invoker.services,
@@ -116,7 +117,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
pass
except CanceledException:
statistics.reset_stats()
statistics.reset_stats(graph_execution_state.id)
pass
except Exception as e:
@@ -138,7 +139,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error_type=e.__class__.__name__,
error=error,
)
statistics.reset_stats()
statistics.reset_stats(graph_execution_state.id)
pass
# Check queue to see if this is canceled, and skip if so
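
A condensed, hedged view of the processor's new error handling (names as in the hunks above; queue, event, and session-update plumbing omitted):

try:
    with statistics.collect_stats(invocation, graph_execution_state.id):
        outputs = invocation.invoke(
            InvocationContext(
                services=services,
                graph_execution_state_id=graph_execution_state.id,
            )
        )
except CanceledException:
    # a cancelled session should never appear in the report, so drop its partial numbers
    statistics.reset_stats(graph_execution_state.id)
except Exception:
    # the real processor reports the error through its event service here;
    # either way the partial statistics for this graph are discarded
    statistics.reset_stats(graph_execution_state.id)
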