fix(nodes): fix typing on stats service context manager

This commit is contained in:
psychedelicious 2024-02-19 12:57:05 +11:00
parent e8725a1099
commit 9d27d354cf
2 changed files with 4 additions and 4 deletions

View File

@ -30,7 +30,7 @@ writes to the system log is stored in InvocationServices.performance_statistics.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Iterator from typing import ContextManager
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
@ -50,7 +50,7 @@ class InvocationStatsServiceBase(ABC):
self, self,
invocation: BaseInvocation, invocation: BaseInvocation,
graph_execution_state_id: str, graph_execution_state_id: str,
) -> Iterator[None]: ) -> ContextManager[None]:
""" """
Return a context object that will capture the statistics on the execution Return a context object that will capture the statistics on the execution
of invocaation. Use with: to place around the part of the code that executes the invocation. of invocaation. Use with: to place around the part of the code that executes the invocation.

View File

@ -2,7 +2,7 @@ import json
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Iterator from typing import Generator
import psutil import psutil
import torch import torch
@ -41,7 +41,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._invoker = invoker self._invoker = invoker
@contextmanager @contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Generator[None, None, None]:
# This is to handle case of the model manager not being initialized, which happens # This is to handle case of the model manager not being initialized, which happens
# during some tests. # during some tests.
services = self._invoker.services services = self._invoker.services