feat(events): remove payload registry, add method to get event classes

We don't need to use the payload schema registry. All our events are dispatched as pydantic models, which are already validated on instantiation.

We do want to add all events to the OpenAPI schema, and we referred to the payload schema registry for this. To get all events, add a simple helper to EventBase. This is functionally identical to using the schema registry.
This commit is contained in:
psychedelicious 2024-04-01 17:29:02 +11:00
parent 18b4f1b72a
commit d97186dfc8
2 changed files with 18 additions and 37 deletions

View File

@ -5,7 +5,7 @@ import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any, cast
from typing import Any
import torch
import uvicorn
@ -17,8 +17,6 @@ from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from fastapi_events.registry.payload_schema import registry as fastapi_events_registry
from pydantic import BaseModel
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
@ -29,6 +27,7 @@ import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice
@ -185,15 +184,13 @@ def custom_openapi() -> dict[str, Any]:
invoker_schema["class"] = "invocation"
# Add all pydantic event schemas registered with fastapi-events
for payload in fastapi_events_registry.data.values():
json_schema = cast(BaseModel, payload).model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
for event in EventBase.get_events():
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
openapi_schema["components"]["schemas"][payload.__name__] = json_schema
openapi_schema["components"]["schemas"][event.__name__] = json_schema
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
@ -31,12 +30,23 @@ class EventBase(BaseModel):
A timestamp is automatically added to the event when it is created.
"""
__event_name__ = "event_base"
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@classmethod
def get_events(cls) -> set[type["EventBase"]]:
"""Get a set of all event models."""
event_subclasses: set[type["EventBase"]] = set()
for subclass in cls.__subclasses__():
# We only want to include subclasses that are event models, not intermediary classes
if hasattr(subclass, "__event_name__"):
event_subclasses.add(subclass)
event_subclasses.update(subclass.get_events())
return event_subclasses
TEvent = TypeVar("TEvent", bound=EventBase)
@ -92,7 +102,6 @@ class InvocationEventBase(SessionEventBase):
invocation_type: str = Field(description="The type of invocation")
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationStartedEvent(InvocationEventBase):
"""Event model for invocation_started"""
@ -111,7 +120,6 @@ class InvocationStartedEvent(InvocationEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
@ -144,7 +152,6 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete"""
@ -168,7 +175,6 @@ class InvocationCompleteEvent(InvocationEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationErrorEvent(InvocationEventBase):
"""Event model for invocation_error"""
@ -201,7 +207,6 @@ class InvocationErrorEvent(InvocationEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionStartedEvent(SessionEventBase):
"""Event model for session_started"""
@ -217,7 +222,6 @@ class SessionStartedEvent(SessionEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionCompleteEvent(SessionEventBase):
"""Event model for session_complete"""
@ -233,7 +237,6 @@ class SessionCompleteEvent(SessionEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionCanceledEvent(SessionEventBase):
"""Event model for session_canceled"""
@ -249,7 +252,6 @@ class SessionCanceledEvent(SessionEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Event model for queue_item_status_changed"""
@ -289,7 +291,6 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BatchEnqueuedEvent(QueueEventBase):
"""Event model for batch_enqueued"""
@ -313,7 +314,6 @@ class BatchEnqueuedEvent(QueueEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
@ -330,7 +330,6 @@ class DownloadEventBase(EventBase):
source: str = Field(description="The source of the download")
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadStartedEvent(DownloadEventBase):
"""Event model for download_started"""
@ -344,7 +343,6 @@ class DownloadStartedEvent(DownloadEventBase):
return cls(source=str(job.source), download_path=job.download_path.as_posix())
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadProgressEvent(DownloadEventBase):
"""Event model for download_progress"""
@ -365,7 +363,6 @@ class DownloadProgressEvent(DownloadEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadCompleteEvent(DownloadEventBase):
"""Event model for download_complete"""
@ -380,7 +377,6 @@ class DownloadCompleteEvent(DownloadEventBase):
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadCancelledEvent(DownloadEventBase):
"""Event model for download_cancelled"""
@ -391,7 +387,6 @@ class DownloadCancelledEvent(DownloadEventBase):
return cls(source=str(job.source))
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadErrorEvent(DownloadEventBase):
"""Event model for download_error"""
@ -411,7 +406,6 @@ class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
@ -425,7 +419,6 @@ class ModelLoadStartedEvent(ModelEventBase):
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
@ -439,7 +432,6 @@ class ModelLoadCompleteEvent(ModelEventBase):
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
@ -475,7 +467,6 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
@ -489,7 +480,6 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase):
return cls(id=job.id, source=str(job.source))
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallStartedEvent(ModelEventBase):
"""Event model for model_install_started"""
@ -503,7 +493,6 @@ class ModelInstallStartedEvent(ModelEventBase):
return cls(id=job.id, source=str(job.source))
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallCompleteEvent(ModelEventBase):
"""Event model for model_install_complete"""
@ -520,7 +509,6 @@ class ModelInstallCompleteEvent(ModelEventBase):
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallCancelledEvent(ModelEventBase):
"""Event model for model_install_cancelled"""
@ -534,7 +522,6 @@ class ModelInstallCancelledEvent(ModelEventBase):
return cls(id=job.id, source=str(job.source))
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallErrorEvent(ModelEventBase):
"""Event model for model_install_error"""
@ -560,7 +547,6 @@ class BulkDownloadEventBase(EventBase):
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Event model for bulk_download_started"""
@ -577,7 +563,6 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Event model for bulk_download_complete"""
@ -594,7 +579,6 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase):
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Event model for bulk_download_error"""