mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(events): remove payload registry, add method to get event classes
We don't need to use the payload schema registry. All our events are dispatched as pydantic models, which are already validated on instantiation. We do want to add all events to the OpenAPI schema, and we referred to the payload schema registry for this. To get all events, add a simple helper to EventBase. This is functionally identical to using the schema registry.
This commit is contained in:
parent
9aeabf10df
commit
a48ef9f7a7
@ -5,7 +5,7 @@ import socket
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, cast
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@ -17,8 +17,6 @@ from fastapi.openapi.utils import get_openapi
|
|||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||||
from fastapi_events.registry.payload_schema import registry as fastapi_events_registry
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from pydantic.json_schema import models_json_schema
|
from pydantic.json_schema import models_json_schema
|
||||||
from torch.backends.mps import is_available as is_mps_available
|
from torch.backends.mps import is_available as is_mps_available
|
||||||
|
|
||||||
@ -29,6 +27,7 @@ import invokeai.frontend.web as web_dir
|
|||||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
from invokeai.app.invocations.model import ModelIdentifierField
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
|
from invokeai.app.services.events.events_common import EventBase
|
||||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
@ -185,15 +184,13 @@ def custom_openapi() -> dict[str, Any]:
|
|||||||
invoker_schema["class"] = "invocation"
|
invoker_schema["class"] = "invocation"
|
||||||
|
|
||||||
# Add all pydantic event schemas registered with fastapi-events
|
# Add all pydantic event schemas registered with fastapi-events
|
||||||
for payload in fastapi_events_registry.data.values():
|
for event in EventBase.get_events():
|
||||||
json_schema = cast(BaseModel, payload).model_json_schema(
|
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||||
mode="serialization", ref_template="#/components/schemas/{model}"
|
|
||||||
)
|
|
||||||
if "$defs" in json_schema:
|
if "$defs" in json_schema:
|
||||||
for schema_key, schema in json_schema["$defs"].items():
|
for schema_key, schema in json_schema["$defs"].items():
|
||||||
openapi_schema["components"]["schemas"][schema_key] = schema
|
openapi_schema["components"]["schemas"][schema_key] = schema
|
||||||
del json_schema["$defs"]
|
del json_schema["$defs"]
|
||||||
openapi_schema["components"]["schemas"][payload.__name__] = json_schema
|
openapi_schema["components"]["schemas"][event.__name__] = json_schema
|
||||||
|
|
||||||
app.openapi_schema = openapi_schema
|
app.openapi_schema = openapi_schema
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
|
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
|
||||||
|
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.registry.payload_schema import registry as payload_schema
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
|
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
@ -31,12 +30,23 @@ class EventBase(BaseModel):
|
|||||||
A timestamp is automatically added to the event when it is created.
|
A timestamp is automatically added to the event when it is created.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__event_name__ = "event_base"
|
|
||||||
|
|
||||||
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
||||||
|
|
||||||
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_events(cls) -> set[type["EventBase"]]:
|
||||||
|
"""Get a set of all event models."""
|
||||||
|
|
||||||
|
event_subclasses: set[type["EventBase"]] = set()
|
||||||
|
for subclass in cls.__subclasses__():
|
||||||
|
# We only want to include subclasses that are event models, not intermediary classes
|
||||||
|
if hasattr(subclass, "__event_name__"):
|
||||||
|
event_subclasses.add(subclass)
|
||||||
|
event_subclasses.update(subclass.get_events())
|
||||||
|
|
||||||
|
return event_subclasses
|
||||||
|
|
||||||
|
|
||||||
TEvent = TypeVar("TEvent", bound=EventBase)
|
TEvent = TypeVar("TEvent", bound=EventBase)
|
||||||
|
|
||||||
@ -92,7 +102,6 @@ class InvocationEventBase(SessionEventBase):
|
|||||||
invocation_type: str = Field(description="The type of invocation")
|
invocation_type: str = Field(description="The type of invocation")
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class InvocationStartedEvent(InvocationEventBase):
|
class InvocationStartedEvent(InvocationEventBase):
|
||||||
"""Event model for invocation_started"""
|
"""Event model for invocation_started"""
|
||||||
|
|
||||||
@ -111,7 +120,6 @@ class InvocationStartedEvent(InvocationEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class InvocationDenoiseProgressEvent(InvocationEventBase):
|
class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||||
"""Event model for invocation_denoise_progress"""
|
"""Event model for invocation_denoise_progress"""
|
||||||
|
|
||||||
@ -144,7 +152,6 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class InvocationCompleteEvent(InvocationEventBase):
|
class InvocationCompleteEvent(InvocationEventBase):
|
||||||
"""Event model for invocation_complete"""
|
"""Event model for invocation_complete"""
|
||||||
|
|
||||||
@ -168,7 +175,6 @@ class InvocationCompleteEvent(InvocationEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class InvocationErrorEvent(InvocationEventBase):
|
class InvocationErrorEvent(InvocationEventBase):
|
||||||
"""Event model for invocation_error"""
|
"""Event model for invocation_error"""
|
||||||
|
|
||||||
@ -194,7 +200,6 @@ class InvocationErrorEvent(InvocationEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class SessionStartedEvent(SessionEventBase):
|
class SessionStartedEvent(SessionEventBase):
|
||||||
"""Event model for session_started"""
|
"""Event model for session_started"""
|
||||||
|
|
||||||
@ -210,7 +215,6 @@ class SessionStartedEvent(SessionEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class SessionCompleteEvent(SessionEventBase):
|
class SessionCompleteEvent(SessionEventBase):
|
||||||
"""Event model for session_complete"""
|
"""Event model for session_complete"""
|
||||||
|
|
||||||
@ -226,7 +230,6 @@ class SessionCompleteEvent(SessionEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class SessionCanceledEvent(SessionEventBase):
|
class SessionCanceledEvent(SessionEventBase):
|
||||||
"""Event model for session_canceled"""
|
"""Event model for session_canceled"""
|
||||||
|
|
||||||
@ -242,7 +245,6 @@ class SessionCanceledEvent(SessionEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class QueueItemStatusChangedEvent(QueueItemEventBase):
|
class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||||
"""Event model for queue_item_status_changed"""
|
"""Event model for queue_item_status_changed"""
|
||||||
|
|
||||||
@ -276,7 +278,6 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class BatchEnqueuedEvent(QueueEventBase):
|
class BatchEnqueuedEvent(QueueEventBase):
|
||||||
"""Event model for batch_enqueued"""
|
"""Event model for batch_enqueued"""
|
||||||
|
|
||||||
@ -300,7 +301,6 @@ class BatchEnqueuedEvent(QueueEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class QueueClearedEvent(QueueEventBase):
|
class QueueClearedEvent(QueueEventBase):
|
||||||
"""Event model for queue_cleared"""
|
"""Event model for queue_cleared"""
|
||||||
|
|
||||||
@ -317,7 +317,6 @@ class DownloadEventBase(EventBase):
|
|||||||
source: str = Field(description="The source of the download")
|
source: str = Field(description="The source of the download")
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class DownloadStartedEvent(DownloadEventBase):
|
class DownloadStartedEvent(DownloadEventBase):
|
||||||
"""Event model for download_started"""
|
"""Event model for download_started"""
|
||||||
|
|
||||||
@ -331,7 +330,6 @@ class DownloadStartedEvent(DownloadEventBase):
|
|||||||
return cls(source=str(job.source), download_path=job.download_path.as_posix())
|
return cls(source=str(job.source), download_path=job.download_path.as_posix())
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class DownloadProgressEvent(DownloadEventBase):
|
class DownloadProgressEvent(DownloadEventBase):
|
||||||
"""Event model for download_progress"""
|
"""Event model for download_progress"""
|
||||||
|
|
||||||
@ -352,7 +350,6 @@ class DownloadProgressEvent(DownloadEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class DownloadCompleteEvent(DownloadEventBase):
|
class DownloadCompleteEvent(DownloadEventBase):
|
||||||
"""Event model for download_complete"""
|
"""Event model for download_complete"""
|
||||||
|
|
||||||
@ -367,7 +364,6 @@ class DownloadCompleteEvent(DownloadEventBase):
|
|||||||
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
|
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class DownloadCancelledEvent(DownloadEventBase):
|
class DownloadCancelledEvent(DownloadEventBase):
|
||||||
"""Event model for download_cancelled"""
|
"""Event model for download_cancelled"""
|
||||||
|
|
||||||
@ -378,7 +374,6 @@ class DownloadCancelledEvent(DownloadEventBase):
|
|||||||
return cls(source=str(job.source))
|
return cls(source=str(job.source))
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class DownloadErrorEvent(DownloadEventBase):
|
class DownloadErrorEvent(DownloadEventBase):
|
||||||
"""Event model for download_error"""
|
"""Event model for download_error"""
|
||||||
|
|
||||||
@ -398,7 +393,6 @@ class ModelEventBase(EventBase):
|
|||||||
"""Base class for events associated with a model"""
|
"""Base class for events associated with a model"""
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class ModelLoadStartedEvent(ModelEventBase):
|
class ModelLoadStartedEvent(ModelEventBase):
|
||||||
"""Event model for model_load_started"""
|
"""Event model for model_load_started"""
|
||||||
|
|
||||||
@ -412,7 +406,6 @@ class ModelLoadStartedEvent(ModelEventBase):
|
|||||||
return cls(config=config, submodel_type=submodel_type)
|
return cls(config=config, submodel_type=submodel_type)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class ModelLoadCompleteEvent(ModelEventBase):
|
class ModelLoadCompleteEvent(ModelEventBase):
|
||||||
"""Event model for model_load_complete"""
|
"""Event model for model_load_complete"""
|
||||||
|
|
||||||
@ -426,7 +419,6 @@ class ModelLoadCompleteEvent(ModelEventBase):
|
|||||||
return cls(config=config, submodel_type=submodel_type)
|
return cls(config=config, submodel_type=submodel_type)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||||
"""Event model for model_install_download_progress"""
|
"""Event model for model_install_download_progress"""
|
||||||
|
|
||||||
@ -462,7 +454,6 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
|
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
|
||||||
"""Emitted once when an install job becomes active."""
|
"""Emitted once when an install job becomes active."""
|
||||||
|
|
||||||
@ -476,7 +467,6 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase):
|
|||||||
return cls(id=job.id, source=str(job.source))
|
return cls(id=job.id, source=str(job.source))
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class ModelInstallStartedEvent(ModelEventBase):
|
class ModelInstallStartedEvent(ModelEventBase):
|
||||||
"""Event model for model_install_started"""
|
"""Event model for model_install_started"""
|
||||||
|
|
||||||
@ -490,7 +480,6 @@ class ModelInstallStartedEvent(ModelEventBase):
|
|||||||
return cls(id=job.id, source=str(job.source))
|
return cls(id=job.id, source=str(job.source))
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class ModelInstallCompleteEvent(ModelEventBase):
|
class ModelInstallCompleteEvent(ModelEventBase):
|
||||||
"""Event model for model_install_complete"""
|
"""Event model for model_install_complete"""
|
||||||
|
|
||||||
@ -507,7 +496,6 @@ class ModelInstallCompleteEvent(ModelEventBase):
|
|||||||
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
|
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class ModelInstallCancelledEvent(ModelEventBase):
|
class ModelInstallCancelledEvent(ModelEventBase):
|
||||||
"""Event model for model_install_cancelled"""
|
"""Event model for model_install_cancelled"""
|
||||||
|
|
||||||
@ -521,7 +509,6 @@ class ModelInstallCancelledEvent(ModelEventBase):
|
|||||||
return cls(id=job.id, source=str(job.source))
|
return cls(id=job.id, source=str(job.source))
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class ModelInstallErrorEvent(ModelEventBase):
|
class ModelInstallErrorEvent(ModelEventBase):
|
||||||
"""Event model for model_install_error"""
|
"""Event model for model_install_error"""
|
||||||
|
|
||||||
@ -547,7 +534,6 @@ class BulkDownloadEventBase(EventBase):
|
|||||||
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
|
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class BulkDownloadStartedEvent(BulkDownloadEventBase):
|
class BulkDownloadStartedEvent(BulkDownloadEventBase):
|
||||||
"""Event model for bulk_download_started"""
|
"""Event model for bulk_download_started"""
|
||||||
|
|
||||||
@ -564,7 +550,6 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
|
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
|
||||||
"""Event model for bulk_download_complete"""
|
"""Event model for bulk_download_complete"""
|
||||||
|
|
||||||
@ -581,7 +566,6 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
|
|
||||||
class BulkDownloadErrorEvent(BulkDownloadEventBase):
|
class BulkDownloadErrorEvent(BulkDownloadEventBase):
|
||||||
"""Event model for bulk_download_error"""
|
"""Event model for bulk_download_error"""
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user