# Add execution stat reporting after each invocation (#4125)

## What type of PR is this? (check all applicable)

- [X] Feature


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No

## Description

This PR adds execution time and VRAM usage reporting to each graph
invocation. The log output will look like this:

```
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node                 Calls  Seconds  VRAM Used
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader        1   0.005s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip                1   0.004s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel                   2   0.512s     0.26G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int                 1   0.001s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size            1   0.001s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate                  1   0.001s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator     1   0.002s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise                    1   0.002s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l                      1   3.541s     1.93G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i                      1   0.679s     0.58G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME:  4.749s
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
```
On systems without CUDA, the VRAM stats are not printed.
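
The VRAM figures come from torch's peak-memory counters, guarded by the standard CUDA-availability check. As a minimal sketch (the helper name `peak_vram_gb` is illustrative, not part of the PR):

```python
import torch


def peak_vram_gb() -> float:
    """Peak VRAM allocated since the last counter reset, in GB; 0.0 without CUDA."""
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.max_memory_allocated() / 1e9
```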

The current implementation tracks statistics per graph id, so it will not be
confused when several graphs are executing in parallel. It handles exceptions,
and it is integrated into the app framework by defining an abstract base class
(`InvocationStatsServiceBase`) and storing an implementation instance in
`InvocationServices`. The intended call pattern is sketched below.
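
The usage, taken from the new module's docstring, looks like this; `graph_execution_manager`, `invocation`, and `graph_execution_state` are whatever the caller already has in scope:

```python
statistics = InvocationStatsService(graph_execution_manager)
with statistics.collect_stats(invocation, graph_execution_state.id):
    ...  # execute the invocation
statistics.log_stats()
```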

---

Commit `9091e19de8` by Lincoln Stein, 2023-08-03 20:05:21 -04:00 (committed via GitHub).
8 changed files with 278 additions and 35 deletions.

## Changed files

### `invokeai/app/api/dependencies.py`

```diff
@@ -2,7 +2,6 @@
 from typing import Optional
 from logging import Logger
-import os
 from invokeai.app.services.board_image_record_storage import (
     SqliteBoardImageRecordStorage,
 )
@@ -30,6 +29,7 @@ from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
 from ..services.model_manager_service import ModelManagerService
+from ..services.invocation_stats import InvocationStatsService
 from .events import FastAPIEventService
@@ -128,6 +128,7 @@ class ApiDependencies:
             graph_execution_manager=graph_execution_manager,
             processor=DefaultInvocationProcessor(),
             configuration=config,
+            performance_statistics=InvocationStatsService(graph_execution_manager),
             logger=logger,
         )
```

### `invokeai/app/cli_app.py`

```diff
@@ -37,6 +37,7 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
 from invokeai.app.services.images import ImageService, ImageServiceDependencies
 from invokeai.app.services.resource_name import SimpleNameService
 from invokeai.app.services.urls import LocalUrlService
+from invokeai.app.services.invocation_stats import InvocationStatsService
 
 from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@@ -311,6 +312,7 @@ def invoke_cli():
         graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
+        performance_statistics=InvocationStatsService(graph_execution_manager),
         logger=logger,
         configuration=config,
     )
```

### `invokeai/app/services/invocation_services.py`

```diff
@@ -32,6 +32,7 @@ class InvocationServices:
     logger: "Logger"
     model_manager: "ModelManagerServiceBase"
     processor: "InvocationProcessorABC"
+    performance_statistics: "InvocationStatsServiceBase"
     queue: "InvocationQueueABC"
 
     def __init__(
@@ -47,6 +48,7 @@ class InvocationServices:
         logger: "Logger",
         model_manager: "ModelManagerServiceBase",
         processor: "InvocationProcessorABC",
+        performance_statistics: "InvocationStatsServiceBase",
         queue: "InvocationQueueABC",
     ):
         self.board_images = board_images
@@ -61,4 +63,5 @@ class InvocationServices:
         self.logger = logger
         self.model_manager = model_manager
         self.processor = processor
+        self.performance_statistics = performance_statistics
         self.queue = queue
```

### `invokeai/app/services/invocation_stats.py` (new file, +223 lines)

```python
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
"""Utility to collect execution time and GPU usage stats on invocations in flight."""

"""
Usage:

statistics = InvocationStatsService(graph_execution_manager)
with statistics.collect_stats(invocation, graph_execution_state.id):
      ... execute graphs...
statistics.log_stats()

Typical output:
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node                 Calls  Seconds  VRAM Used
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader        1   0.005s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip                1   0.004s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel                   2   0.512s     0.26G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int                 1   0.001s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size            1   0.001s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate                  1   0.001s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator     1   0.002s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise                    1   0.002s     0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l                      1   3.541s     1.93G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i                      1   0.679s     0.58G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME:  4.749s
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G

The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
writes to the system log is stored in InvocationServices.performance_statistics.
"""

import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from typing import Dict

import torch

import invokeai.backend.util.logging as logger

from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC


class InvocationStatsServiceBase(ABC):
    "Abstract base class for recording node memory/time performance statistics"

    @abstractmethod
    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
        """
        Initialize the InvocationStatsService and reset counters to zero.

        :param graph_execution_manager: Graph execution manager for this session
        """
        pass

    @abstractmethod
    def collect_stats(
        self,
        invocation: BaseInvocation,
        graph_execution_state_id: str,
    ) -> AbstractContextManager:
        """
        Return a context object that will capture the statistics on the execution
        of the invocation. Use "with:" to place it around the part of the code that
        executes the invocation.

        :param invocation: BaseInvocation object from the current graph.
        :param graph_execution_state_id: ID of the GraphExecutionState object from the current session.
        """
        pass

    @abstractmethod
    def reset_stats(self, graph_execution_state_id: str):
        """
        Reset all statistics for the indicated graph.

        :param graph_execution_state_id: ID of the graph whose statistics are to be reset.
        """
        pass

    @abstractmethod
    def reset_all_stats(self):
        """Zero all statistics."""
        pass

    @abstractmethod
    def update_invocation_stats(
        self,
        graph_id: str,
        invocation_type: str,
        time_used: float,
        vram_used: float,
    ):
        """
        Add timing information on execution of a node. Usually used internally.

        :param graph_id: ID of the graph that is currently executing
        :param invocation_type: String literal type of the node
        :param time_used: Time used by node's execution (sec)
        :param vram_used: Maximum VRAM used during execution (GB)
        """
        pass

    @abstractmethod
    def log_stats(self):
        """
        Write out the accumulated statistics to the log or somewhere else.
        """
        pass


@dataclass
class NodeStats:
    """Class for tracking execution stats of an invocation node"""

    calls: int = 0
    time_used: float = 0.0  # seconds
    max_vram: float = 0.0  # GB


@dataclass
class NodeLog:
    """Class for tracking node usage"""

    # {node_type => NodeStats}
    nodes: Dict[str, NodeStats] = field(default_factory=dict)


class InvocationStatsService(InvocationStatsServiceBase):
    """Accumulate performance information about a running graph. Collects time spent in each node,
    as well as the maximum and current VRAM utilization for CUDA systems."""

    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
        self.graph_execution_manager = graph_execution_manager
        # {graph_id => NodeLog}
        self._stats: Dict[str, NodeLog] = {}

    class StatsContext:
        def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
            self.invocation = invocation
            self.collector = collector
            self.graph_id = graph_id
            self.start_time = 0

        def __enter__(self):
            self.start_time = time.time()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()

        def __exit__(self, *args):
            self.collector.update_invocation_stats(
                self.graph_id,
                self.invocation.type,
                time.time() - self.start_time,
                torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
            )

    def collect_stats(
        self,
        invocation: BaseInvocation,
        graph_execution_state_id: str,
    ) -> StatsContext:
        """
        Return a context object that will capture the statistics.

        :param invocation: BaseInvocation object from the current graph.
        :param graph_execution_state_id: ID of the GraphExecutionState object from the current session.
        """
        if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
            self._stats[graph_execution_state_id] = NodeLog()
        return self.StatsContext(invocation, graph_execution_state_id, self)

    def reset_all_stats(self):
        """Zero all statistics."""
        self._stats = {}

    def reset_stats(self, graph_execution_id: str):
        """Zero the statistics for the indicated graph."""
        try:
            self._stats.pop(graph_execution_id)
        except KeyError:
            logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")

    def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
        """
        Add timing information on execution of a node. Usually used internally.

        :param graph_id: ID of the graph that is currently executing
        :param invocation_type: String literal type of the node
        :param time_used: Floating point seconds used by node's execution
        :param vram_used: Maximum VRAM used during execution (GB)
        """
        if not self._stats[graph_id].nodes.get(invocation_type):
            self._stats[graph_id].nodes[invocation_type] = NodeStats()
        stats = self._stats[graph_id].nodes[invocation_type]
        stats.calls += 1
        stats.time_used += time_used
        stats.max_vram = max(stats.max_vram, vram_used)

    def log_stats(self):
        """
        Send the statistics to the system logger at the info level.
        Stats will only be printed when the execution of the graph is complete.
        """
        completed = set()
        for graph_id, node_log in self._stats.items():
            current_graph_state = self.graph_execution_manager.get(graph_id)
            if not current_graph_state.is_complete():
                continue

            total_time = 0
            logger.info(f"Graph stats: {graph_id}")
            logger.info("Node                 Calls  Seconds  VRAM Used")
            for node_type, stats in node_log.nodes.items():
                logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
                total_time += stats.time_used

            logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
            if torch.cuda.is_available():
                logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))

            completed.add(graph_id)

        for graph_id in completed:
            del self._stats[graph_id]
```
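
To see the accounting in isolation, here is a self-contained sketch of the same context-manager pattern, with a hypothetical stand-in for a node invocation (no torch or InvokeAI imports required):

```python
import time
from collections import defaultdict
from contextlib import contextmanager

# Miniature of StatsContext/update_invocation_stats, for illustration only;
# the real service also tracks peak VRAM and keeps separate stats per graph id.
stats = defaultdict(lambda: {"calls": 0, "time_used": 0.0})


@contextmanager
def collect_stats(node_type: str):
    start = time.time()
    try:
        yield
    finally:
        # Runs even if the body raises, just like StatsContext.__exit__ above.
        stats[node_type]["calls"] += 1
        stats[node_type]["time_used"] += time.time() - start


with collect_stats("rand_int"):
    time.sleep(0.01)  # stands in for invocation.invoke(...)

for node_type, s in stats.items():
    print(f"{node_type:<20} {s['calls']:>5} {s['time_used']:7.3f}s")
```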

### `invokeai/app/services/processor.py`

```diff
@@ -1,14 +1,15 @@
 import time
 import traceback
-from threading import Event, Thread, BoundedSemaphore
-
-from ..invocations.baseinvocation import InvocationContext
-from .invocation_queue import InvocationQueueItem
-from .invoker import InvocationProcessorABC, Invoker
-from ..models.exceptions import CanceledException
+from threading import BoundedSemaphore, Event, Thread
 
 import invokeai.backend.util.logging as logger
 
+from ..invocations.baseinvocation import InvocationContext
+from ..models.exceptions import CanceledException
+from .invocation_queue import InvocationQueueItem
+from .invocation_stats import InvocationStatsServiceBase
+from .invoker import InvocationProcessorABC, Invoker
+
 
 class DefaultInvocationProcessor(InvocationProcessorABC):
     __invoker_thread: Thread
@@ -35,6 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
     def __process(self, stop_event: Event):
         try:
             self.__threadLimit.acquire()
+            statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
+
             while not stop_event.is_set():
                 try:
                     queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
@@ -83,35 +86,38 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 # Invoke
                 try:
-                    outputs = invocation.invoke(
-                        InvocationContext(
-                            services=self.__invoker.services,
-                            graph_execution_state_id=graph_execution_state.id,
+                    with statistics.collect_stats(invocation, graph_execution_state.id):
+                        outputs = invocation.invoke(
+                            InvocationContext(
+                                services=self.__invoker.services,
+                                graph_execution_state_id=graph_execution_state.id,
+                            )
                         )
-                    )
 
                     # Check queue to see if this is canceled, and skip if so
                     if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
                         continue
 
                     # Save outputs and history
                     graph_execution_state.complete(invocation.id, outputs)
 
                     # Save the state changes
                     self.__invoker.services.graph_execution_manager.set(graph_execution_state)
 
                     # Send complete event
                     self.__invoker.services.events.emit_invocation_complete(
                         graph_execution_state_id=graph_execution_state.id,
                         node=invocation.dict(),
                         source_node_id=source_node_id,
                         result=outputs.dict(),
                     )
+                    statistics.log_stats()
 
                 except KeyboardInterrupt:
                     pass
 
                 except CanceledException:
+                    statistics.reset_stats(graph_execution_state.id)
                     pass
@@ -133,7 +139,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         error_type=e.__class__.__name__,
                         error=error,
                     )
-
+                    statistics.reset_stats(graph_execution_state.id)
                     pass
 
                 # Check queue to see if this is canceled, and skip if so
```
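
One subtlety worth noting: `StatsContext.__exit__` runs even when `invoke()` raises, so a failed or canceled node still records an entry. The `statistics.reset_stats(graph_execution_state.id)` calls in the `CanceledException` and generic exception handlers then discard the whole graph's partial numbers, preventing stale entries from accumulating in the service's per-graph table.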

### `scripts/dream.py`

```diff
@@ -2,10 +2,12 @@
 # Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
 
 import warnings
-from invokeai.frontend.CLI import invokeai_command_line_interface as main
+
+from invokeai.app.cli_app import invoke_cli
 
 warnings.warn(
     "dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
     DeprecationWarning,
 )
-main()
+
+invoke_cli()
```

### `tests/nodes/test_graph_execution_state.py`

```diff
@@ -16,6 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.collections import RangeInvocation
 from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
 from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.invocation_stats import InvocationStatsService
 from invokeai.app.services.graph import (
     Graph,
     CollectInvocation,
@@ -41,6 +42,9 @@ def simple_graph():
 @pytest.fixture
 def mock_services() -> InvocationServices:
     # NOTE: none of these are actually called by the test invocations
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
+        filename=sqlite_memory, table_name="graph_executions"
+    )
     return InvocationServices(
         model_manager=None,  # type: ignore
         events=TestEventService(),
@@ -51,9 +55,8 @@ def mock_services() -> InvocationServices:
         board_images=None,  # type: ignore
         queue=MemoryInvocationQueue(),
         graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
-        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
-            filename=sqlite_memory, table_name="graph_executions"
-        ),
+        graph_execution_manager=graph_execution_manager,
+        performance_statistics=InvocationStatsService(graph_execution_manager),
         processor=DefaultInvocationProcessor(),
         configuration=None,  # type: ignore
     )
```

### `tests/nodes/test_invoker.py`

```diff
@@ -11,6 +11,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor
 from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.invocation_stats import InvocationStatsService
 from invokeai.app.services.graph import (
     Graph,
     GraphExecutionState,
@@ -34,6 +35,9 @@ def simple_graph():
 @pytest.fixture
 def mock_services() -> InvocationServices:
     # NOTE: none of these are actually called by the test invocations
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
+        filename=sqlite_memory, table_name="graph_executions"
+    )
     return InvocationServices(
         model_manager=None,  # type: ignore
         events=TestEventService(),
@@ -44,10 +48,9 @@ def mock_services() -> InvocationServices:
         board_images=None,  # type: ignore
         queue=MemoryInvocationQueue(),
         graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
-        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
-            filename=sqlite_memory, table_name="graph_executions"
-        ),
+        graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
+        performance_statistics=InvocationStatsService(graph_execution_manager),
         configuration=None,  # type: ignore
     )
```