fix(stats): fix InvocationStatsService types

- move docstrings to ABC
- `start_time: int` -> `start_time: float`
- remove class attribute assignments in `StatsContext`
- add `update_mem_stats()` to ABC
- add class attributes to ABC, because they are referenced in instances of the class. if they should not be on the ABC, then maybe there needs to be some restructuring
This commit is contained in:
psychedelicious 2023-08-18 20:16:45 +10:00
parent 81385d7d35
commit cd9baf8092

View File

@ -50,9 +50,36 @@ from invokeai.backend.model_management.model_cache import CacheStats
GIG = 1073741824
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
# {graph_id => NodeLog}
_stats: Dict[str, NodeLog]
_cache_stats: Dict[str, CacheStats]
ram_used: float
ram_changed: float
@abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
"""
@ -95,8 +122,6 @@ class InvocationStatsServiceBase(ABC):
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
"""
Add timing information on execution of a node. Usually
@ -105,8 +130,6 @@ class InvocationStatsServiceBase(ABC):
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
pass
@ -117,25 +140,19 @@ class InvocationStatsServiceBase(ABC):
"""
pass
@abstractmethod
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
pass
class InvocationStatsService(InvocationStatsServiceBase):
@ -153,12 +170,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
class StatsContext:
"""Context manager for collecting statistics."""
invocation: BaseInvocation = None
collector: "InvocationStatsServiceBase" = None
graph_id: str = None
start_time: int = 0
ram_used: int = 0
model_manager: ModelManagerService = None
invocation: BaseInvocation
collector: "InvocationStatsServiceBase"
graph_id: str
start_time: float
ram_used: int
model_manager: ModelManagerService
def __init__(
self,
@ -171,7 +188,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
self.start_time = 0.0
self.ram_used = 0
self.model_manager = model_manager
@ -192,7 +209,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self.collector.update_invocation_stats(
graph_id=self.graph_id,
invocation_type=self.invocation.type,
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
)
@ -203,11 +220,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
graph_execution_state_id: str,
model_manager: ModelManagerService,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats()
@ -218,7 +230,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._stats = {}
def reset_stats(self, graph_execution_id: str):
"""Zero the statistics for the indicated graph."""
try:
self._stats.pop(graph_execution_id)
except KeyError:
@ -229,12 +240,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
self.ram_used = ram_used
self.ram_changed = ram_changed
@ -245,16 +250,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type]
@ -263,11 +258,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed when the execution of the graph
is complete.
"""
completed = set()
for graph_id, node_log in self._stats.items():
try: