stats: handle exceptions (#4320)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

[fix(stats): fix fail case when previous graph is
invalid](d1d2d5a47d)

When retrieving a graph, it is parsed through pydantic. It is possible
that this graph is invalid, and an error is thrown.

Handle this by deleting the failed graph from the stats if this occurs.

[fix(stats): fix InvocationStatsService
types](1b70bd1380)

- move docstrings to ABC
- `start_time: int` -> `start_time: float`
- remove class attribute assignments in `StatsContext`
- add `update_mem_stats()` to ABC
- add class attributes to ABC, because they are referenced in instances
of the class. if they should not be on the ABC, then maybe there needs
to be some restructuring

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

On `main` (not this PR), create a situation in which an graph is valid
but will be rendered invalid on invoke. Easy way in node editor:
- create an `Integer Primitive` node, set value to 3
- create a `Resize Image` node and add an image to it
- route the output of `Integer Primitive` to the `width` of `Resize
Image`
- Invoke - this will cause first a `Validation Error` (expected), and if
you inspect the error in the JS console, you'll see it is a "session
retrieval error"
- Invoke again - this will also cause a `Validation Error`, but if you
inspect the error you should see it originates in the stats module (this
is the error this PR fixes)
- Fix the graph by setting the `Integer Primitive` to 512
- Invoke again - you get the same `Validation Error` originating from
stats, even tho there are no issues

Switch to this PR, and then you should only ever get the `Validation
Error` that that is classified as a "session retrieval error".
This commit is contained in:
Lincoln Stein 2023-08-21 19:47:21 -04:00 committed by GitHub
commit 572e6b892a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats
GIG = 1073741824 GIG = 1073741824
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsServiceBase(ABC): class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics" "Abstract base class for recording node memory/time performance statistics"
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
# {graph_id => NodeLog}
_stats: Dict[str, NodeLog]
_cache_stats: Dict[str, CacheStats]
ram_used: float
ram_changed: float
@abstractmethod @abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
""" """
@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC):
invocation_type: str, invocation_type: str,
time_used: float, time_used: float,
vram_used: float, vram_used: float,
ram_used: float,
ram_changed: float,
): ):
""" """
Add timing information on execution of a node. Usually Add timing information on execution of a node. Usually
@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC):
:param invocation_type: String literal type of the node :param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec) :param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB) :param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
""" """
pass pass
@ -116,25 +139,19 @@ class InvocationStatsServiceBase(ABC):
""" """
pass pass
@abstractmethod
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
@dataclass :param ram_used: How much RAM is currently in use.
class NodeStats: :param ram_changed: How much RAM changed since last generation.
"""Class for tracking execution stats of an invocation node""" """
pass
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsService(InvocationStatsServiceBase): class InvocationStatsService(InvocationStatsServiceBase):
@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
class StatsContext: class StatsContext:
"""Context manager for collecting statistics.""" """Context manager for collecting statistics."""
invocation: BaseInvocation = None invocation: BaseInvocation
collector: "InvocationStatsServiceBase" = None collector: "InvocationStatsServiceBase"
graph_id: str = None graph_id: str
start_time: int = 0 start_time: float
ram_used: int = 0 ram_used: int
model_manager: ModelManagerService = None model_manager: ModelManagerService
def __init__( def __init__(
self, self,
@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
self.invocation = invocation self.invocation = invocation
self.collector = collector self.collector = collector
self.graph_id = graph_id self.graph_id = graph_id
self.start_time = 0 self.start_time = 0.0
self.ram_used = 0 self.ram_used = 0
self.model_manager = model_manager self.model_manager = model_manager
@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
) )
self.collector.update_invocation_stats( self.collector.update_invocation_stats(
graph_id=self.graph_id, graph_id=self.graph_id,
invocation_type=self.invocation.type, invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
time_used=time.time() - self.start_time, time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0, vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
) )
@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
graph_execution_state_id: str, graph_execution_state_id: str,
model_manager: ModelManagerService, model_manager: ModelManagerService,
) -> StatsContext: ) -> StatsContext:
"""
Return a context object that will capture the statistics.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog() self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats() self._cache_stats[graph_execution_state_id] = CacheStats()
@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._stats = {} self._stats = {}
def reset_stats(self, graph_execution_id: str): def reset_stats(self, graph_execution_id: str):
"""Zero the statistics for the indicated graph."""
try: try:
self._stats.pop(graph_execution_id) self._stats.pop(graph_execution_id)
except KeyError: except KeyError:
@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
ram_used: float, ram_used: float,
ram_changed: float, ram_changed: float,
): ):
"""
Update the collector with RAM memory usage info.
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
self.ram_used = ram_used self.ram_used = ram_used
self.ram_changed = ram_changed self.ram_changed = ram_changed
@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
time_used: float, time_used: float,
vram_used: float, vram_used: float,
): ):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
if not self._stats[graph_id].nodes.get(invocation_type): if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats() self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type] stats = self._stats[graph_id].nodes[invocation_type]
@ -262,14 +257,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
stats.max_vram = max(stats.max_vram, vram_used) stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self): def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed when the execution of the graph
is complete.
"""
completed = set() completed = set()
errored = set()
for graph_id, node_log in self._stats.items(): for graph_id, node_log in self._stats.items():
current_graph_state = self.graph_execution_manager.get(graph_id) try:
current_graph_state = self.graph_execution_manager.get(graph_id)
except Exception:
errored.add(graph_id)
continue
if not current_graph_state.is_complete(): if not current_graph_state.is_complete():
continue continue
@ -302,3 +298,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
for graph_id in completed: for graph_id in completed:
del self._stats[graph_id] del self._stats[graph_id]
del self._cache_stats[graph_id] del self._cache_stats[graph_id]
for graph_id in errored:
del self._stats[graph_id]
del self._cache_stats[graph_id]