mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
stats: handle exceptions (#4320)
## What type of PR is this? (check all applicable) - [ ] Refactor - [ ] Feature - [x] Bug Fix - [ ] Optimization - [ ] Documentation Update - [ ] Community Node Submission ## Description [fix(stats): fix fail case when previous graph is invalid](d1d2d5a47d
) When retrieving a graph, it is parsed through pydantic. It is possible that this graph is invalid, and an error is thrown. Handle this by deleting the failed graph from the stats if this occurs. [fix(stats): fix InvocationStatsService types](1b70bd1380
) - move docstrings to ABC - `start_time: int` -> `start_time: float` - remove class attribute assignments in `StatsContext` - add `update_mem_stats()` to ABC - add class attributes to ABC, because they are referenced in instances of the class. if they should not be on the ABC, then maybe there needs to be some restructuring ## QA Instructions, Screenshots, Recordings <!-- Please provide steps on how to test changes, any hardware or software specifications as well as any other pertinent information. --> On `main` (not this PR), create a situation in which a graph is valid but will be rendered invalid on invoke. Easy way in node editor: - create an `Integer Primitive` node, set value to 3 - create a `Resize Image` node and add an image to it - route the output of `Integer Primitive` to the `width` of `Resize Image` - Invoke - this will cause first a `Validation Error` (expected), and if you inspect the error in the JS console, you'll see it is a "session retrieval error" - Invoke again - this will also cause a `Validation Error`, but if you inspect the error you should see it originates in the stats module (this is the error this PR fixes) - Fix the graph by setting the `Integer Primitive` to 512 - Invoke again - you get the same `Validation Error` originating from stats, even though there are no issues Switch to this PR, and then you should only ever get the `Validation Error` that is classified as a "session retrieval error".
This commit is contained in:
commit
572e6b892a
@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats
|
|||||||
GIG = 1073741824
|
GIG = 1073741824
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeStats:
|
||||||
|
"""Class for tracking execution stats of an invocation node"""
|
||||||
|
|
||||||
|
calls: int = 0
|
||||||
|
time_used: float = 0.0 # seconds
|
||||||
|
max_vram: float = 0.0 # GB
|
||||||
|
cache_hits: int = 0
|
||||||
|
cache_misses: int = 0
|
||||||
|
cache_high_watermark: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeLog:
|
||||||
|
"""Class for tracking node usage"""
|
||||||
|
|
||||||
|
# {node_type => NodeStats}
|
||||||
|
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class InvocationStatsServiceBase(ABC):
|
class InvocationStatsServiceBase(ABC):
|
||||||
"Abstract base class for recording node memory/time performance statistics"
|
"Abstract base class for recording node memory/time performance statistics"
|
||||||
|
|
||||||
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||||
|
# {graph_id => NodeLog}
|
||||||
|
_stats: Dict[str, NodeLog]
|
||||||
|
_cache_stats: Dict[str, CacheStats]
|
||||||
|
ram_used: float
|
||||||
|
ram_changed: float
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||||
"""
|
"""
|
||||||
@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
invocation_type: str,
|
invocation_type: str,
|
||||||
time_used: float,
|
time_used: float,
|
||||||
vram_used: float,
|
vram_used: float,
|
||||||
ram_used: float,
|
|
||||||
ram_changed: float,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Add timing information on execution of a node. Usually
|
Add timing information on execution of a node. Usually
|
||||||
@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
:param invocation_type: String literal type of the node
|
:param invocation_type: String literal type of the node
|
||||||
:param time_used: Time used by node's exection (sec)
|
:param time_used: Time used by node's exection (sec)
|
||||||
:param vram_used: Maximum VRAM used during exection (GB)
|
:param vram_used: Maximum VRAM used during exection (GB)
|
||||||
:param ram_used: Current RAM available (GB)
|
|
||||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -116,25 +139,19 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_mem_stats(
|
||||||
|
self,
|
||||||
|
ram_used: float,
|
||||||
|
ram_changed: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update the collector with RAM memory usage info.
|
||||||
|
|
||||||
@dataclass
|
:param ram_used: How much RAM is currently in use.
|
||||||
class NodeStats:
|
:param ram_changed: How much RAM changed since last generation.
|
||||||
"""Class for tracking execution stats of an invocation node"""
|
"""
|
||||||
|
pass
|
||||||
calls: int = 0
|
|
||||||
time_used: float = 0.0 # seconds
|
|
||||||
max_vram: float = 0.0 # GB
|
|
||||||
cache_hits: int = 0
|
|
||||||
cache_misses: int = 0
|
|
||||||
cache_high_watermark: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NodeLog:
|
|
||||||
"""Class for tracking node usage"""
|
|
||||||
|
|
||||||
# {node_type => NodeStats}
|
|
||||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationStatsService(InvocationStatsServiceBase):
|
class InvocationStatsService(InvocationStatsServiceBase):
|
||||||
@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
class StatsContext:
|
class StatsContext:
|
||||||
"""Context manager for collecting statistics."""
|
"""Context manager for collecting statistics."""
|
||||||
|
|
||||||
invocation: BaseInvocation = None
|
invocation: BaseInvocation
|
||||||
collector: "InvocationStatsServiceBase" = None
|
collector: "InvocationStatsServiceBase"
|
||||||
graph_id: str = None
|
graph_id: str
|
||||||
start_time: int = 0
|
start_time: float
|
||||||
ram_used: int = 0
|
ram_used: int
|
||||||
model_manager: ModelManagerService = None
|
model_manager: ModelManagerService
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
self.invocation = invocation
|
self.invocation = invocation
|
||||||
self.collector = collector
|
self.collector = collector
|
||||||
self.graph_id = graph_id
|
self.graph_id = graph_id
|
||||||
self.start_time = 0
|
self.start_time = 0.0
|
||||||
self.ram_used = 0
|
self.ram_used = 0
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
|
|
||||||
@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
)
|
)
|
||||||
self.collector.update_invocation_stats(
|
self.collector.update_invocation_stats(
|
||||||
graph_id=self.graph_id,
|
graph_id=self.graph_id,
|
||||||
invocation_type=self.invocation.type,
|
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||||
time_used=time.time() - self.start_time,
|
time_used=time.time() - self.start_time,
|
||||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||||
)
|
)
|
||||||
@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
model_manager: ModelManagerService,
|
model_manager: ModelManagerService,
|
||||||
) -> StatsContext:
|
) -> StatsContext:
|
||||||
"""
|
|
||||||
Return a context object that will capture the statistics.
|
|
||||||
:param invocation: BaseInvocation object from the current graph.
|
|
||||||
:param graph_execution_state: GraphExecutionState object from the current session.
|
|
||||||
"""
|
|
||||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||||
self._stats[graph_execution_state_id] = NodeLog()
|
self._stats[graph_execution_state_id] = NodeLog()
|
||||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||||
@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
self._stats = {}
|
self._stats = {}
|
||||||
|
|
||||||
def reset_stats(self, graph_execution_id: str):
|
def reset_stats(self, graph_execution_id: str):
|
||||||
"""Zero the statistics for the indicated graph."""
|
|
||||||
try:
|
try:
|
||||||
self._stats.pop(graph_execution_id)
|
self._stats.pop(graph_execution_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
ram_used: float,
|
ram_used: float,
|
||||||
ram_changed: float,
|
ram_changed: float,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Update the collector with RAM memory usage info.
|
|
||||||
|
|
||||||
:param ram_used: How much RAM is currently in use.
|
|
||||||
:param ram_changed: How much RAM changed since last generation.
|
|
||||||
"""
|
|
||||||
self.ram_used = ram_used
|
self.ram_used = ram_used
|
||||||
self.ram_changed = ram_changed
|
self.ram_changed = ram_changed
|
||||||
|
|
||||||
@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
time_used: float,
|
time_used: float,
|
||||||
vram_used: float,
|
vram_used: float,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Add timing information on execution of a node. Usually
|
|
||||||
used internally.
|
|
||||||
:param graph_id: ID of the graph that is currently executing
|
|
||||||
:param invocation_type: String literal type of the node
|
|
||||||
:param time_used: Time used by node's exection (sec)
|
|
||||||
:param vram_used: Maximum VRAM used during exection (GB)
|
|
||||||
:param ram_used: Current RAM available (GB)
|
|
||||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
|
||||||
"""
|
|
||||||
if not self._stats[graph_id].nodes.get(invocation_type):
|
if not self._stats[graph_id].nodes.get(invocation_type):
|
||||||
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
||||||
stats = self._stats[graph_id].nodes[invocation_type]
|
stats = self._stats[graph_id].nodes[invocation_type]
|
||||||
@ -262,14 +257,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
stats.max_vram = max(stats.max_vram, vram_used)
|
stats.max_vram = max(stats.max_vram, vram_used)
|
||||||
|
|
||||||
def log_stats(self):
|
def log_stats(self):
|
||||||
"""
|
|
||||||
Send the statistics to the system logger at the info level.
|
|
||||||
Stats will only be printed when the execution of the graph
|
|
||||||
is complete.
|
|
||||||
"""
|
|
||||||
completed = set()
|
completed = set()
|
||||||
|
errored = set()
|
||||||
for graph_id, node_log in self._stats.items():
|
for graph_id, node_log in self._stats.items():
|
||||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
try:
|
||||||
|
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||||
|
except Exception:
|
||||||
|
errored.add(graph_id)
|
||||||
|
continue
|
||||||
|
|
||||||
if not current_graph_state.is_complete():
|
if not current_graph_state.is_complete():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -302,3 +298,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
for graph_id in completed:
|
for graph_id in completed:
|
||||||
del self._stats[graph_id]
|
del self._stats[graph_id]
|
||||||
del self._cache_stats[graph_id]
|
del self._cache_stats[graph_id]
|
||||||
|
|
||||||
|
for graph_id in errored:
|
||||||
|
del self._stats[graph_id]
|
||||||
|
del self._cache_stats[graph_id]
|
||||||
|
Loading…
Reference in New Issue
Block a user