mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(events): simplify event classes
- Remove ABCs, they do not work well with pydantic - Remove the event type classvar - unused - Remove clever logic to require an event name - we already get validation for this during schema registration. - Rename event bases to all end in "Base"
This commit is contained in:
parent
1b9bbaa5a4
commit
2dc752ea83
@ -10,14 +10,14 @@ from invokeai.app.services.events.events_common import (
|
|||||||
BatchEnqueuedEvent,
|
BatchEnqueuedEvent,
|
||||||
BulkDownloadCompleteEvent,
|
BulkDownloadCompleteEvent,
|
||||||
BulkDownloadErrorEvent,
|
BulkDownloadErrorEvent,
|
||||||
BulkDownloadEvent,
|
BulkDownloadEventBase,
|
||||||
BulkDownloadStartedEvent,
|
BulkDownloadStartedEvent,
|
||||||
FastAPIEvent,
|
FastAPIEvent,
|
||||||
InvocationCompleteEvent,
|
InvocationCompleteEvent,
|
||||||
InvocationDenoiseProgressEvent,
|
InvocationDenoiseProgressEvent,
|
||||||
InvocationErrorEvent,
|
InvocationErrorEvent,
|
||||||
InvocationStartedEvent,
|
InvocationStartedEvent,
|
||||||
ModelEvent,
|
ModelEventBase,
|
||||||
ModelInstallCancelledEvent,
|
ModelInstallCancelledEvent,
|
||||||
ModelInstallCompleteEvent,
|
ModelInstallCompleteEvent,
|
||||||
ModelInstallDownloadProgressEvent,
|
ModelInstallDownloadProgressEvent,
|
||||||
@ -26,7 +26,7 @@ from invokeai.app.services.events.events_common import (
|
|||||||
ModelLoadCompleteEvent,
|
ModelLoadCompleteEvent,
|
||||||
ModelLoadStartedEvent,
|
ModelLoadStartedEvent,
|
||||||
QueueClearedEvent,
|
QueueClearedEvent,
|
||||||
QueueEvent,
|
QueueEventBase,
|
||||||
QueueItemStatusChangedEvent,
|
QueueItemStatusChangedEvent,
|
||||||
SessionCanceledEvent,
|
SessionCanceledEvent,
|
||||||
SessionCompleteEvent,
|
SessionCompleteEvent,
|
||||||
@ -106,14 +106,14 @@ class SocketIO:
|
|||||||
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
|
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
|
||||||
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||||
|
|
||||||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]):
|
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
|
||||||
event_name, payload = event
|
event_name, payload = event
|
||||||
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id)
|
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id)
|
||||||
|
|
||||||
async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None:
|
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
|
||||||
event_name, payload = event
|
event_name, payload = event
|
||||||
await self._sio.emit(event=event_name, data=payload.model_dump())
|
await self._sio.emit(event=event_name, data=payload.model_dump())
|
||||||
|
|
||||||
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEvent]) -> None:
|
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
|
||||||
event_name, payload = event
|
event_name, payload = event
|
||||||
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.bulk_download_id)
|
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.bulk_download_id)
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from invokeai.app.services.events.events_common import (
|
from invokeai.app.services.events.events_common import (
|
||||||
BaseEvent,
|
|
||||||
BatchEnqueuedEvent,
|
BatchEnqueuedEvent,
|
||||||
BulkDownloadCompleteEvent,
|
BulkDownloadCompleteEvent,
|
||||||
BulkDownloadErrorEvent,
|
BulkDownloadErrorEvent,
|
||||||
@ -14,6 +13,7 @@ from invokeai.app.services.events.events_common import (
|
|||||||
DownloadErrorEvent,
|
DownloadErrorEvent,
|
||||||
DownloadProgressEvent,
|
DownloadProgressEvent,
|
||||||
DownloadStartedEvent,
|
DownloadStartedEvent,
|
||||||
|
EventBase,
|
||||||
InvocationCompleteEvent,
|
InvocationCompleteEvent,
|
||||||
InvocationDenoiseProgressEvent,
|
InvocationDenoiseProgressEvent,
|
||||||
InvocationErrorEvent,
|
InvocationErrorEvent,
|
||||||
@ -35,7 +35,7 @@ from invokeai.app.services.events.events_common import (
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
from invokeai.app.services.events.events_common import BaseEvent
|
from invokeai.app.services.events.events_common import EventBase
|
||||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||||
from invokeai.app.services.session_queue.session_queue_common import (
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
@ -50,7 +50,7 @@ if TYPE_CHECKING:
|
|||||||
class EventServiceBase:
|
class EventServiceBase:
|
||||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||||
|
|
||||||
def dispatch(self, event: "BaseEvent") -> None:
|
def dispatch(self, event: "EventBase") -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# region: Invocation
|
# region: Invocation
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
from abc import ABC
|
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
|
||||||
from enum import Enum
|
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Optional, Protocol, TypeAlias, TypeVar
|
|
||||||
|
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.registry.payload_schema import registry as payload_schema
|
from fastapi_events.registry.payload_schema import registry as payload_schema
|
||||||
@ -22,39 +20,22 @@ if TYPE_CHECKING:
|
|||||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||||
|
|
||||||
|
|
||||||
class EventType(str, Enum):
|
class EventBase(BaseModel):
|
||||||
QUEUE = "queue"
|
|
||||||
MODEL = "model"
|
|
||||||
DOWNLOAD = "download"
|
|
||||||
BULK_IMAGE_DOWNLOAD = "bulk_image_download"
|
|
||||||
|
|
||||||
|
|
||||||
class BaseEvent(BaseModel, ABC):
|
|
||||||
"""Base class for all events. All events must inherit from this class.
|
"""Base class for all events. All events must inherit from this class.
|
||||||
|
|
||||||
Events must define the following class attributes:
|
Events must define a class attribute `__event_name__` to identify the event.
|
||||||
- `__event_name__: str`: The name of the event
|
|
||||||
- `__event_type__: EventType`: The type of the event
|
|
||||||
|
|
||||||
All other attributes should be defined as normal for a pydantic model.
|
All other attributes should be defined as normal for a pydantic model.
|
||||||
|
|
||||||
A timestamp is automatically added to the event when it is created.
|
A timestamp is automatically added to the event when it is created.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__event_name__: ClassVar[str] = ... # pyright: ignore [reportAssignmentType]
|
|
||||||
__event_type__: ClassVar[EventType] = ... # pyright: ignore [reportAssignmentType]
|
|
||||||
|
|
||||||
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs: ConfigDict):
|
|
||||||
for required_attr in ("__event_name__", "__event_type__"):
|
|
||||||
if getattr(cls, required_attr) is ...:
|
|
||||||
raise TypeError(f"{cls.__name__} must define {required_attr}")
|
|
||||||
|
|
||||||
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
||||||
|
|
||||||
|
|
||||||
TEvent = TypeVar("TEvent", bound=BaseEvent)
|
TEvent = TypeVar("TEvent", bound=EventBase)
|
||||||
|
|
||||||
FastAPIEvent: TypeAlias = tuple[str, TEvent]
|
FastAPIEvent: TypeAlias = tuple[str, TEvent]
|
||||||
"""
|
"""
|
||||||
@ -77,37 +58,28 @@ def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None:
|
|||||||
local_handler.register(event_name=event.__event_name__, _func=func)
|
local_handler.register(event_name=event.__event_name__, _func=func)
|
||||||
|
|
||||||
|
|
||||||
class QueueEvent(BaseEvent, ABC):
|
class QueueEventBase(EventBase):
|
||||||
"""Base class for queue events"""
|
"""Base class for queue events"""
|
||||||
|
|
||||||
__event_type__ = EventType.QUEUE
|
|
||||||
__event_name__ = "queue_event"
|
|
||||||
|
|
||||||
queue_id: str = Field(description="The ID of the queue")
|
queue_id: str = Field(description="The ID of the queue")
|
||||||
|
|
||||||
|
|
||||||
class QueueItemEvent(QueueEvent, ABC):
|
class QueueItemEventBase(QueueEventBase):
|
||||||
"""Base class for queue item events"""
|
"""Base class for queue item events"""
|
||||||
|
|
||||||
__event_name__ = "queue_item_event"
|
|
||||||
|
|
||||||
item_id: int = Field(description="The ID of the queue item")
|
item_id: int = Field(description="The ID of the queue item")
|
||||||
batch_id: str = Field(description="The ID of the queue batch")
|
batch_id: str = Field(description="The ID of the queue batch")
|
||||||
|
|
||||||
|
|
||||||
class SessionEvent(QueueItemEvent, ABC):
|
class SessionEventBase(QueueItemEventBase):
|
||||||
"""Base class for session (aka graph execution state) events"""
|
"""Base class for session (aka graph execution state) events"""
|
||||||
|
|
||||||
__event_name__ = "session_event"
|
|
||||||
|
|
||||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||||
|
|
||||||
|
|
||||||
class InvocationEvent(SessionEvent, ABC):
|
class InvocationEventBase(SessionEventBase):
|
||||||
"""Base class for invocation events"""
|
"""Base class for invocation events"""
|
||||||
|
|
||||||
__event_name__ = "invocation_event"
|
|
||||||
|
|
||||||
queue_id: str = Field(description="The ID of the queue")
|
queue_id: str = Field(description="The ID of the queue")
|
||||||
item_id: int = Field(description="The ID of the queue item")
|
item_id: int = Field(description="The ID of the queue item")
|
||||||
batch_id: str = Field(description="The ID of the queue batch")
|
batch_id: str = Field(description="The ID of the queue batch")
|
||||||
@ -118,7 +90,7 @@ class InvocationEvent(SessionEvent, ABC):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class InvocationStartedEvent(InvocationEvent):
|
class InvocationStartedEvent(InvocationEventBase):
|
||||||
"""Emitted when an invocation is started"""
|
"""Emitted when an invocation is started"""
|
||||||
|
|
||||||
__event_name__ = "invocation_started"
|
__event_name__ = "invocation_started"
|
||||||
@ -137,7 +109,7 @@ class InvocationStartedEvent(InvocationEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class InvocationDenoiseProgressEvent(InvocationEvent):
|
class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||||
"""Emitted at each step during denoising of an invocation."""
|
"""Emitted at each step during denoising of an invocation."""
|
||||||
|
|
||||||
__event_name__ = "invocation_denoise_progress"
|
__event_name__ = "invocation_denoise_progress"
|
||||||
@ -170,7 +142,7 @@ class InvocationDenoiseProgressEvent(InvocationEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class InvocationCompleteEvent(InvocationEvent):
|
class InvocationCompleteEvent(InvocationEventBase):
|
||||||
"""Emitted when an invocation is complete"""
|
"""Emitted when an invocation is complete"""
|
||||||
|
|
||||||
__event_name__ = "invocation_complete"
|
__event_name__ = "invocation_complete"
|
||||||
@ -194,7 +166,7 @@ class InvocationCompleteEvent(InvocationEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class InvocationErrorEvent(InvocationEvent):
|
class InvocationErrorEvent(InvocationEventBase):
|
||||||
"""Emitted when an invocation encounters an error"""
|
"""Emitted when an invocation encounters an error"""
|
||||||
|
|
||||||
__event_name__ = "invocation_error"
|
__event_name__ = "invocation_error"
|
||||||
@ -227,7 +199,7 @@ class InvocationErrorEvent(InvocationEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class SessionStartedEvent(SessionEvent):
|
class SessionStartedEvent(SessionEventBase):
|
||||||
"""Emitted when a session has started"""
|
"""Emitted when a session has started"""
|
||||||
|
|
||||||
__event_name__ = "session_started"
|
__event_name__ = "session_started"
|
||||||
@ -243,7 +215,7 @@ class SessionStartedEvent(SessionEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class SessionCompleteEvent(SessionEvent):
|
class SessionCompleteEvent(SessionEventBase):
|
||||||
"""Emitted when a session has completed all invocations"""
|
"""Emitted when a session has completed all invocations"""
|
||||||
|
|
||||||
__event_name__ = "session_complete"
|
__event_name__ = "session_complete"
|
||||||
@ -259,7 +231,7 @@ class SessionCompleteEvent(SessionEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class SessionCanceledEvent(SessionEvent):
|
class SessionCanceledEvent(SessionEventBase):
|
||||||
"""Emitted when a session is canceled"""
|
"""Emitted when a session is canceled"""
|
||||||
|
|
||||||
__event_name__ = "session_canceled"
|
__event_name__ = "session_canceled"
|
||||||
@ -275,7 +247,7 @@ class SessionCanceledEvent(SessionEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class QueueItemStatusChangedEvent(SessionEvent):
|
class QueueItemStatusChangedEvent(SessionEventBase):
|
||||||
"""Emitted when a queue item's status changes"""
|
"""Emitted when a queue item's status changes"""
|
||||||
|
|
||||||
__event_name__ = "queue_item_status_changed"
|
__event_name__ = "queue_item_status_changed"
|
||||||
@ -314,7 +286,7 @@ class QueueItemStatusChangedEvent(SessionEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class BatchEnqueuedEvent(QueueEvent):
|
class BatchEnqueuedEvent(QueueEventBase):
|
||||||
"""Emitted when a batch is enqueued"""
|
"""Emitted when a batch is enqueued"""
|
||||||
|
|
||||||
__event_name__ = "batch_enqueued"
|
__event_name__ = "batch_enqueued"
|
||||||
@ -338,7 +310,7 @@ class BatchEnqueuedEvent(QueueEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class QueueClearedEvent(QueueEvent):
|
class QueueClearedEvent(QueueEventBase):
|
||||||
"""Emitted when a queue is cleared"""
|
"""Emitted when a queue is cleared"""
|
||||||
|
|
||||||
__event_name__ = "queue_cleared"
|
__event_name__ = "queue_cleared"
|
||||||
@ -348,17 +320,14 @@ class QueueClearedEvent(QueueEvent):
|
|||||||
return cls(queue_id=queue_id)
|
return cls(queue_id=queue_id)
|
||||||
|
|
||||||
|
|
||||||
class DownloadEvent(BaseEvent, ABC):
|
class DownloadEventBase(EventBase):
|
||||||
"""Base class for events associated with a download"""
|
"""Base class for events associated with a download"""
|
||||||
|
|
||||||
__event_type__ = EventType.DOWNLOAD
|
|
||||||
__event_name__ = "download_event"
|
|
||||||
|
|
||||||
source: str = Field(description="The source of the download")
|
source: str = Field(description="The source of the download")
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class DownloadStartedEvent(DownloadEvent):
|
class DownloadStartedEvent(DownloadEventBase):
|
||||||
"""Emitted when a download is started"""
|
"""Emitted when a download is started"""
|
||||||
|
|
||||||
__event_name__ = "download_started"
|
__event_name__ = "download_started"
|
||||||
@ -371,7 +340,7 @@ class DownloadStartedEvent(DownloadEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class DownloadProgressEvent(DownloadEvent):
|
class DownloadProgressEvent(DownloadEventBase):
|
||||||
"""Emitted at intervals during a download"""
|
"""Emitted at intervals during a download"""
|
||||||
|
|
||||||
__event_name__ = "download_progress"
|
__event_name__ = "download_progress"
|
||||||
@ -386,7 +355,7 @@ class DownloadProgressEvent(DownloadEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class DownloadCompleteEvent(DownloadEvent):
|
class DownloadCompleteEvent(DownloadEventBase):
|
||||||
"""Emitted when a download is completed"""
|
"""Emitted when a download is completed"""
|
||||||
|
|
||||||
__event_name__ = "download_complete"
|
__event_name__ = "download_complete"
|
||||||
@ -400,7 +369,7 @@ class DownloadCompleteEvent(DownloadEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class DownloadCancelledEvent(DownloadEvent):
|
class DownloadCancelledEvent(DownloadEventBase):
|
||||||
"""Emitted when a download is cancelled"""
|
"""Emitted when a download is cancelled"""
|
||||||
|
|
||||||
__event_name__ = "download_cancelled"
|
__event_name__ = "download_cancelled"
|
||||||
@ -411,7 +380,7 @@ class DownloadCancelledEvent(DownloadEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class DownloadErrorEvent(DownloadEvent):
|
class DownloadErrorEvent(DownloadEventBase):
|
||||||
"""Emitted when a download encounters an error"""
|
"""Emitted when a download encounters an error"""
|
||||||
|
|
||||||
__event_name__ = "download_error"
|
__event_name__ = "download_error"
|
||||||
@ -424,15 +393,12 @@ class DownloadErrorEvent(DownloadEvent):
|
|||||||
return cls(source=source, error_type=error_type, error=error)
|
return cls(source=source, error_type=error_type, error=error)
|
||||||
|
|
||||||
|
|
||||||
class ModelEvent(BaseEvent, ABC):
|
class ModelEventBase(EventBase):
|
||||||
"""Base class for events associated with a model"""
|
"""Base class for events associated with a model"""
|
||||||
|
|
||||||
__event_type__ = EventType.MODEL
|
|
||||||
__event_name__ = "model_event"
|
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class ModelLoadStartedEvent(ModelEvent):
|
class ModelLoadStartedEvent(ModelEventBase):
|
||||||
"""Emitted when a model is requested"""
|
"""Emitted when a model is requested"""
|
||||||
|
|
||||||
__event_name__ = "model_load_started"
|
__event_name__ = "model_load_started"
|
||||||
@ -446,7 +412,7 @@ class ModelLoadStartedEvent(ModelEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class ModelLoadCompleteEvent(ModelEvent):
|
class ModelLoadCompleteEvent(ModelEventBase):
|
||||||
"""Emitted when a model is requested"""
|
"""Emitted when a model is requested"""
|
||||||
|
|
||||||
__event_name__ = "model_load_complete"
|
__event_name__ = "model_load_complete"
|
||||||
@ -460,7 +426,7 @@ class ModelLoadCompleteEvent(ModelEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class ModelInstallDownloadProgressEvent(ModelEvent):
|
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||||
"""Emitted at intervals while the install job is in progress (remote models only)."""
|
"""Emitted at intervals while the install job is in progress (remote models only)."""
|
||||||
|
|
||||||
__event_name__ = "model_install_download_progress"
|
__event_name__ = "model_install_download_progress"
|
||||||
@ -496,7 +462,7 @@ class ModelInstallDownloadProgressEvent(ModelEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class ModelInstallDownloadsCompleteEvent(ModelEvent):
|
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
|
||||||
"""Emitted once when an install job becomes active."""
|
"""Emitted once when an install job becomes active."""
|
||||||
|
|
||||||
__event_name__ = "model_install_downloads_complete"
|
__event_name__ = "model_install_downloads_complete"
|
||||||
@ -510,7 +476,7 @@ class ModelInstallDownloadsCompleteEvent(ModelEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class ModelInstallStartedEvent(ModelEvent):
|
class ModelInstallStartedEvent(ModelEventBase):
|
||||||
"""Emitted once when an install job becomes active."""
|
"""Emitted once when an install job becomes active."""
|
||||||
|
|
||||||
__event_name__ = "model_install_started"
|
__event_name__ = "model_install_started"
|
||||||
@ -524,7 +490,7 @@ class ModelInstallStartedEvent(ModelEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class ModelInstallCompleteEvent(ModelEvent):
|
class ModelInstallCompleteEvent(ModelEventBase):
|
||||||
"""Emitted when an install job is completed successfully."""
|
"""Emitted when an install job is completed successfully."""
|
||||||
|
|
||||||
__event_name__ = "model_install_complete"
|
__event_name__ = "model_install_complete"
|
||||||
@ -541,7 +507,7 @@ class ModelInstallCompleteEvent(ModelEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class ModelInstallCancelledEvent(ModelEvent):
|
class ModelInstallCancelledEvent(ModelEventBase):
|
||||||
"""Emitted when an install job is cancelled."""
|
"""Emitted when an install job is cancelled."""
|
||||||
|
|
||||||
__event_name__ = "model_install_cancelled"
|
__event_name__ = "model_install_cancelled"
|
||||||
@ -555,7 +521,7 @@ class ModelInstallCancelledEvent(ModelEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class ModelInstallErrorEvent(ModelEvent):
|
class ModelInstallErrorEvent(ModelEventBase):
|
||||||
"""Emitted when an install job encounters an exception."""
|
"""Emitted when an install job encounters an exception."""
|
||||||
|
|
||||||
__event_name__ = "model_install_error"
|
__event_name__ = "model_install_error"
|
||||||
@ -572,19 +538,16 @@ class ModelInstallErrorEvent(ModelEvent):
|
|||||||
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
|
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
|
||||||
|
|
||||||
|
|
||||||
class BulkDownloadEvent(BaseEvent, ABC):
|
class BulkDownloadEventBase(EventBase):
|
||||||
"""Base class for events associated with a bulk image download"""
|
"""Base class for events associated with a bulk image download"""
|
||||||
|
|
||||||
__event_type__ = EventType.BULK_IMAGE_DOWNLOAD
|
|
||||||
__event_name__ = "bulk_image_download_event"
|
|
||||||
|
|
||||||
bulk_download_id: str = Field(description="The ID of the bulk image download")
|
bulk_download_id: str = Field(description="The ID of the bulk image download")
|
||||||
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
|
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
|
||||||
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
|
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class BulkDownloadStartedEvent(BulkDownloadEvent):
|
class BulkDownloadStartedEvent(BulkDownloadEventBase):
|
||||||
"""Emitted when a bulk image download is started"""
|
"""Emitted when a bulk image download is started"""
|
||||||
|
|
||||||
__event_name__ = "bulk_download_started"
|
__event_name__ = "bulk_download_started"
|
||||||
@ -601,7 +564,7 @@ class BulkDownloadStartedEvent(BulkDownloadEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class BulkDownloadCompleteEvent(BulkDownloadEvent):
|
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
|
||||||
"""Emitted when a bulk image download is started"""
|
"""Emitted when a bulk image download is started"""
|
||||||
|
|
||||||
__event_name__ = "bulk_download_complete"
|
__event_name__ = "bulk_download_complete"
|
||||||
@ -618,7 +581,7 @@ class BulkDownloadCompleteEvent(BulkDownloadEvent):
|
|||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
||||||
class BulkDownloadErrorEvent(BulkDownloadEvent):
|
class BulkDownloadErrorEvent(BulkDownloadEventBase):
|
||||||
"""Emitted when a bulk image download is started"""
|
"""Emitted when a bulk image download is started"""
|
||||||
|
|
||||||
__event_name__ = "bulk_download_error"
|
__event_name__ = "bulk_download_error"
|
||||||
|
@ -7,7 +7,7 @@ from queue import Empty, Queue
|
|||||||
from fastapi_events.dispatcher import dispatch
|
from fastapi_events.dispatcher import dispatch
|
||||||
|
|
||||||
from invokeai.app.services.events.events_common import (
|
from invokeai.app.services.events.events_common import (
|
||||||
BaseEvent,
|
EventBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .events_base import EventServiceBase
|
from .events_base import EventServiceBase
|
||||||
@ -16,7 +16,7 @@ from .events_base import EventServiceBase
|
|||||||
class FastAPIEventService(EventServiceBase):
|
class FastAPIEventService(EventServiceBase):
|
||||||
def __init__(self, event_handler_id: int) -> None:
|
def __init__(self, event_handler_id: int) -> None:
|
||||||
self.event_handler_id = event_handler_id
|
self.event_handler_id = event_handler_id
|
||||||
self._queue = Queue[BaseEvent | None]()
|
self._queue = Queue[EventBase | None]()
|
||||||
self._stop_event = threading.Event()
|
self._stop_event = threading.Event()
|
||||||
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ class FastAPIEventService(EventServiceBase):
|
|||||||
self._stop_event.set()
|
self._stop_event.set()
|
||||||
self._queue.put(None)
|
self._queue.put(None)
|
||||||
|
|
||||||
def dispatch(self, event: BaseEvent) -> None:
|
def dispatch(self, event: EventBase) -> None:
|
||||||
self._queue.put(event)
|
self._queue.put(event)
|
||||||
|
|
||||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||||
|
@ -9,7 +9,7 @@ from invokeai.app.services.events.events_common import (
|
|||||||
BatchEnqueuedEvent,
|
BatchEnqueuedEvent,
|
||||||
FastAPIEvent,
|
FastAPIEvent,
|
||||||
QueueClearedEvent,
|
QueueClearedEvent,
|
||||||
QueueEvent,
|
QueueEventBase,
|
||||||
QueueItemStatusChangedEvent,
|
QueueItemStatusChangedEvent,
|
||||||
SessionCanceledEvent,
|
SessionCanceledEvent,
|
||||||
register_events,
|
register_events,
|
||||||
@ -332,7 +332,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
def _poll_now(self) -> None:
|
def _poll_now(self) -> None:
|
||||||
self._poll_now_event.set()
|
self._poll_now_event.set()
|
||||||
|
|
||||||
async def _on_queue_event(self, event: FastAPIEvent[QueueEvent]) -> None:
|
async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
|
||||||
_event_name, payload = event
|
_event_name, payload = event
|
||||||
if (
|
if (
|
||||||
isinstance(payload, SessionCanceledEvent)
|
isinstance(payload, SessionCanceledEvent)
|
||||||
|
Loading…
Reference in New Issue
Block a user