mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
stats: handle exceptions (#4320)
## What type of PR is this? (check all applicable) - [ ] Refactor - [ ] Feature - [x] Bug Fix - [ ] Optimization - [ ] Documentation Update - [ ] Community Node Submission ## Description [fix(stats): fix fail case when previous graph is invalid](d1d2d5a47d
) When retrieving a graph, it is parsed through pydantic. It is possible that this graph is invalid, and an error is thrown. Handle this by deleting the failed graph from the stats if this occurs. [fix(stats): fix InvocationStatsService types](1b70bd1380
) - move docstrings to ABC - `start_time: int` -> `start_time: float` - remove class attribute assignments in `StatsContext` - add `update_mem_stats()` to ABC - add class attributes to ABC, because they are referenced in instances of the class. if they should not be on the ABC, then maybe there needs to be some restructuring ## QA Instructions, Screenshots, Recordings <!-- Please provide steps on how to test changes, any hardware or software specifications as well as any other pertinent information. --> On `main` (not this PR), create a situation in which an graph is valid but will be rendered invalid on invoke. Easy way in node editor: - create an `Integer Primitive` node, set value to 3 - create a `Resize Image` node and add an image to it - route the output of `Integer Primitive` to the `width` of `Resize Image` - Invoke - this will cause first a `Validation Error` (expected), and if you inspect the error in the JS console, you'll see it is a "session retrieval error" - Invoke again - this will also cause a `Validation Error`, but if you inspect the error you should see it originates in the stats module (this is the error this PR fixes) - Fix the graph by setting the `Integer Primitive` to 512 - Invoke again - you get the same `Validation Error` originating from stats, even tho there are no issues Switch to this PR, and then you should only ever get the `Validation Error` that that is classified as a "session retrieval error".
This commit is contained in:
commit
572e6b892a
@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeStats:
|
||||
"""Class for tracking execution stats of an invocation node"""
|
||||
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeLog:
|
||||
"""Class for tracking node usage"""
|
||||
|
||||
# {node_type => NodeStats}
|
||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||
|
||||
|
||||
class InvocationStatsServiceBase(ABC):
|
||||
"Abstract base class for recording node memory/time performance statistics"
|
||||
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
# {graph_id => NodeLog}
|
||||
_stats: Dict[str, NodeLog]
|
||||
_cache_stats: Dict[str, CacheStats]
|
||||
ram_used: float
|
||||
ram_changed: float
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||
"""
|
||||
@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC):
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC):
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
:param ram_used: Current RAM available (GB)
|
||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -116,25 +139,19 @@ class InvocationStatsServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_mem_stats(
|
||||
self,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
@dataclass
|
||||
class NodeStats:
|
||||
"""Class for tracking execution stats of an invocation node"""
|
||||
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeLog:
|
||||
"""Class for tracking node usage"""
|
||||
|
||||
# {node_type => NodeStats}
|
||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InvocationStatsService(InvocationStatsServiceBase):
|
||||
@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
class StatsContext:
|
||||
"""Context manager for collecting statistics."""
|
||||
|
||||
invocation: BaseInvocation = None
|
||||
collector: "InvocationStatsServiceBase" = None
|
||||
graph_id: str = None
|
||||
start_time: int = 0
|
||||
ram_used: int = 0
|
||||
model_manager: ModelManagerService = None
|
||||
invocation: BaseInvocation
|
||||
collector: "InvocationStatsServiceBase"
|
||||
graph_id: str
|
||||
start_time: float
|
||||
ram_used: int
|
||||
model_manager: ModelManagerService
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self.invocation = invocation
|
||||
self.collector = collector
|
||||
self.graph_id = graph_id
|
||||
self.start_time = 0
|
||||
self.start_time = 0.0
|
||||
self.ram_used = 0
|
||||
self.model_manager = model_manager
|
||||
|
||||
@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self.collector.update_invocation_stats(
|
||||
graph_id=self.graph_id,
|
||||
invocation_type=self.invocation.type,
|
||||
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||
time_used=time.time() - self.start_time,
|
||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||
)
|
||||
@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
graph_execution_state_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
) -> StatsContext:
|
||||
"""
|
||||
Return a context object that will capture the statistics.
|
||||
:param invocation: BaseInvocation object from the current graph.
|
||||
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||
"""
|
||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||
self._stats[graph_execution_state_id] = NodeLog()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self._stats = {}
|
||||
|
||||
def reset_stats(self, graph_execution_id: str):
|
||||
"""Zero the statistics for the indicated graph."""
|
||||
try:
|
||||
self._stats.pop(graph_execution_id)
|
||||
except KeyError:
|
||||
@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
self.ram_used = ram_used
|
||||
self.ram_changed = ram_changed
|
||||
|
||||
@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
used internally.
|
||||
:param graph_id: ID of the graph that is currently executing
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
:param ram_used: Current RAM available (GB)
|
||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
||||
"""
|
||||
if not self._stats[graph_id].nodes.get(invocation_type):
|
||||
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
||||
stats = self._stats[graph_id].nodes[invocation_type]
|
||||
@ -262,14 +257,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
stats.max_vram = max(stats.max_vram, vram_used)
|
||||
|
||||
def log_stats(self):
|
||||
"""
|
||||
Send the statistics to the system logger at the info level.
|
||||
Stats will only be printed when the execution of the graph
|
||||
is complete.
|
||||
"""
|
||||
completed = set()
|
||||
errored = set()
|
||||
for graph_id, node_log in self._stats.items():
|
||||
try:
|
||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||
except Exception:
|
||||
errored.add(graph_id)
|
||||
continue
|
||||
|
||||
if not current_graph_state.is_complete():
|
||||
continue
|
||||
|
||||
@ -302,3 +298,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
for graph_id in completed:
|
||||
del self._stats[graph_id]
|
||||
del self._cache_stats[graph_id]
|
||||
|
||||
for graph_id in errored:
|
||||
del self._stats[graph_id]
|
||||
del self._cache_stats[graph_id]
|
||||
|
Loading…
Reference in New Issue
Block a user