diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 710c878c9b..e01742a4e9 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -5,7 +5,7 @@ import socket from contextlib import asynccontextmanager from inspect import signature from pathlib import Path -from typing import Any, cast +from typing import Any import torch import uvicorn @@ -17,8 +17,6 @@ from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware -from fastapi_events.registry.payload_schema import registry as fastapi_events_registry -from pydantic import BaseModel from pydantic.json_schema import models_json_schema from torch.backends.mps import is_available as is_mps_available @@ -29,6 +27,7 @@ import invokeai.frontend.web as web_dir from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.backend.util.devices import TorchDevice @@ -185,15 +184,13 @@ def custom_openapi() -> dict[str, Any]: invoker_schema["class"] = "invocation" # Add all pydantic event schemas registered with fastapi-events - for payload in fastapi_events_registry.data.values(): - json_schema = cast(BaseModel, payload).model_json_schema( - mode="serialization", ref_template="#/components/schemas/{model}" - ) + for event in EventBase.get_events(): + json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") if "$defs" in json_schema: for schema_key, schema in json_schema["$defs"].items(): openapi_schema["components"]["schemas"][schema_key] = schema del json_schema["$defs"] - openapi_schema["components"]["schemas"][payload.__name__] = json_schema + openapi_schema["components"]["schemas"][event.__name__] = json_schema app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index e007b5a0d3..e8d835faba 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar from fastapi_events.handlers.local import local_handler -from fastapi_events.registry.payload_schema import registry as payload_schema from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput @@ -31,12 +30,23 @@ class EventBase(BaseModel): A timestamp is automatically added to the event when it is created. """ - __event_name__ = "event_base" - timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp) model_config = ConfigDict(json_schema_serialization_defaults_required=True) + @classmethod + def get_events(cls) -> set[type["EventBase"]]: + """Get a set of all event models.""" + + event_subclasses: set[type["EventBase"]] = set() + for subclass in cls.__subclasses__(): + # We only want to include subclasses that are event models, not intermediary classes + if hasattr(subclass, "__event_name__"): + event_subclasses.add(subclass) + event_subclasses.update(subclass.get_events()) + + return event_subclasses + TEvent = TypeVar("TEvent", bound=EventBase) @@ -92,7 +102,6 @@ class InvocationEventBase(SessionEventBase): invocation_type: str = Field(description="The type of invocation") -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class InvocationStartedEvent(InvocationEventBase): """Event model for invocation_started""" @@ -111,7 +120,6 @@ class InvocationStartedEvent(InvocationEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class InvocationDenoiseProgressEvent(InvocationEventBase): """Event model for invocation_denoise_progress""" @@ -144,7 +152,6 @@ class InvocationDenoiseProgressEvent(InvocationEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class InvocationCompleteEvent(InvocationEventBase): """Event model for invocation_complete""" @@ -168,7 +175,6 @@ class InvocationCompleteEvent(InvocationEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class InvocationErrorEvent(InvocationEventBase): """Event model for invocation_error""" @@ -194,7 +200,6 @@ class InvocationErrorEvent(InvocationEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class SessionStartedEvent(SessionEventBase): """Event model for session_started""" @@ -210,7 +215,6 @@ class SessionStartedEvent(SessionEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class SessionCompleteEvent(SessionEventBase): """Event model for session_complete""" @@ -226,7 +230,6 @@ class SessionCompleteEvent(SessionEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class SessionCanceledEvent(SessionEventBase): """Event model for session_canceled""" @@ -242,7 +245,6 @@ class SessionCanceledEvent(SessionEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class QueueItemStatusChangedEvent(QueueItemEventBase): """Event model for queue_item_status_changed""" @@ -276,7 +278,6 @@ class QueueItemStatusChangedEvent(QueueItemEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class BatchEnqueuedEvent(QueueEventBase): """Event model for batch_enqueued""" @@ -300,7 +301,6 @@ class BatchEnqueuedEvent(QueueEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class QueueClearedEvent(QueueEventBase): """Event model for queue_cleared""" @@ -317,7 +317,6 @@ class DownloadEventBase(EventBase): source: str = Field(description="The source of the download") -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class DownloadStartedEvent(DownloadEventBase): """Event model for download_started""" @@ -331,7 +330,6 @@ class DownloadStartedEvent(DownloadEventBase): return cls(source=str(job.source), download_path=job.download_path.as_posix()) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class DownloadProgressEvent(DownloadEventBase): """Event model for download_progress""" @@ -352,7 +350,6 @@ class DownloadProgressEvent(DownloadEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class DownloadCompleteEvent(DownloadEventBase): """Event model for download_complete""" @@ -367,7 +364,6 @@ class DownloadCompleteEvent(DownloadEventBase): return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class DownloadCancelledEvent(DownloadEventBase): """Event model for download_cancelled""" @@ -378,7 +374,6 @@ class DownloadCancelledEvent(DownloadEventBase): return cls(source=str(job.source)) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class DownloadErrorEvent(DownloadEventBase): """Event model for download_error""" @@ -398,7 +393,6 @@ class ModelEventBase(EventBase): """Base class for events associated with a model""" -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class ModelLoadStartedEvent(ModelEventBase): """Event model for model_load_started""" @@ -412,7 +406,6 @@ class ModelLoadStartedEvent(ModelEventBase): return cls(config=config, submodel_type=submodel_type) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class ModelLoadCompleteEvent(ModelEventBase): """Event model for model_load_complete""" @@ -426,7 +419,6 @@ class ModelLoadCompleteEvent(ModelEventBase): return cls(config=config, submodel_type=submodel_type) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class ModelInstallDownloadProgressEvent(ModelEventBase): """Event model for model_install_download_progress""" @@ -462,7 +454,6 @@ class ModelInstallDownloadProgressEvent(ModelEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class ModelInstallDownloadsCompleteEvent(ModelEventBase): """Emitted once when an install job becomes active.""" @@ -476,7 +467,6 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase): return cls(id=job.id, source=str(job.source)) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class ModelInstallStartedEvent(ModelEventBase): """Event model for model_install_started""" @@ -490,7 +480,6 @@ class ModelInstallStartedEvent(ModelEventBase): return cls(id=job.id, source=str(job.source)) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class ModelInstallCompleteEvent(ModelEventBase): """Event model for model_install_complete""" @@ -507,7 +496,6 @@ class ModelInstallCompleteEvent(ModelEventBase): return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class ModelInstallCancelledEvent(ModelEventBase): """Event model for model_install_cancelled""" @@ -521,7 +509,6 @@ class ModelInstallCancelledEvent(ModelEventBase): return cls(id=job.id, source=str(job.source)) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class ModelInstallErrorEvent(ModelEventBase): """Event model for model_install_error""" @@ -547,7 +534,6 @@ class BulkDownloadEventBase(EventBase): bulk_download_item_name: str = Field(description="The name of the bulk image download item") -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class BulkDownloadStartedEvent(BulkDownloadEventBase): """Event model for bulk_download_started""" @@ -564,7 +550,6 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class BulkDownloadCompleteEvent(BulkDownloadEventBase): """Event model for bulk_download_complete""" @@ -581,7 +566,6 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase): ) -@payload_schema.register # pyright: ignore [reportUnknownMemberType] class BulkDownloadErrorEvent(BulkDownloadEventBase): """Event model for bulk_download_error"""