diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 7b419db127..786a683ed9 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -1,5 +1,5 @@ from math import floor -from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar from fastapi_events.handlers.local import local_handler from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny @@ -50,7 +50,7 @@ class EventBase(BaseModel): return event_subclasses -TEvent = TypeVar("TEvent", bound=EventBase) +TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True) FastAPIEvent: TypeAlias = tuple[str, TEvent] """ @@ -59,11 +59,11 @@ Provide a generic type to `TEvent` to specify the payload type. """ -class FastAPIEventFunc(Protocol): - def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: ... +class FastAPIEventFunc(Protocol, Generic[TEvent]): + def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ... -def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc) -> None: +def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None: """Register a function to handle specific events. :param events: An event or set of events to handle