feat(events): register event schemas

This allows for events to be dispatched using dicts as payloads, and have the dicts validated as pydantic schemas.
This commit is contained in:
psychedelicious 2024-05-27 10:59:36 +10:00
parent 5388f5a817
commit b50133d5e1

View File

@ -2,6 +2,7 @@ from math import floor
from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
@ -100,6 +101,7 @@ class InvocationEventBase(QueueItemEventBase):
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
"""Event model for invocation_started"""
@ -117,6 +119,7 @@ class InvocationStartedEvent(InvocationEventBase):
)
@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
@ -164,6 +167,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
return (step + 1 + 1) / (total_steps + 1)
@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete"""
@ -186,6 +190,7 @@ class InvocationCompleteEvent(InvocationEventBase):
)
@payload_schema.register
class InvocationErrorEvent(InvocationEventBase):
"""Event model for invocation_error"""
@ -221,6 +226,7 @@ class InvocationErrorEvent(InvocationEventBase):
)
@payload_schema.register
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Event model for queue_item_status_changed"""
@ -260,6 +266,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
)
@payload_schema.register
class BatchEnqueuedEvent(QueueEventBase):
"""Event model for batch_enqueued"""
@ -283,6 +290,7 @@ class BatchEnqueuedEvent(QueueEventBase):
)
@payload_schema.register
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
@ -299,6 +307,7 @@ class DownloadEventBase(EventBase):
source: str = Field(description="The source of the download")
@payload_schema.register
class DownloadStartedEvent(DownloadEventBase):
"""Event model for download_started"""
@ -312,6 +321,7 @@ class DownloadStartedEvent(DownloadEventBase):
return cls(source=str(job.source), download_path=job.download_path.as_posix())
@payload_schema.register
class DownloadProgressEvent(DownloadEventBase):
"""Event model for download_progress"""
@ -332,6 +342,7 @@ class DownloadProgressEvent(DownloadEventBase):
)
@payload_schema.register
class DownloadCompleteEvent(DownloadEventBase):
"""Event model for download_complete"""
@ -346,6 +357,7 @@ class DownloadCompleteEvent(DownloadEventBase):
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
@payload_schema.register
class DownloadCancelledEvent(DownloadEventBase):
"""Event model for download_cancelled"""
@ -356,6 +368,7 @@ class DownloadCancelledEvent(DownloadEventBase):
return cls(source=str(job.source))
@payload_schema.register
class DownloadErrorEvent(DownloadEventBase):
"""Event model for download_error"""
@ -375,6 +388,7 @@ class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
@payload_schema.register
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
@ -388,6 +402,7 @@ class ModelLoadStartedEvent(ModelEventBase):
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
@ -401,6 +416,7 @@ class ModelLoadCompleteEvent(ModelEventBase):
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
@ -436,6 +452,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
)
@payload_schema.register
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
@ -449,6 +466,7 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase):
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallStartedEvent(ModelEventBase):
"""Event model for model_install_started"""
@ -462,6 +480,7 @@ class ModelInstallStartedEvent(ModelEventBase):
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallCompleteEvent(ModelEventBase):
"""Event model for model_install_complete"""
@ -478,6 +497,7 @@ class ModelInstallCompleteEvent(ModelEventBase):
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
@payload_schema.register
class ModelInstallCancelledEvent(ModelEventBase):
"""Event model for model_install_cancelled"""
@ -491,6 +511,7 @@ class ModelInstallCancelledEvent(ModelEventBase):
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallErrorEvent(ModelEventBase):
"""Event model for model_install_error"""
@ -516,6 +537,7 @@ class BulkDownloadEventBase(EventBase):
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Event model for bulk_download_started"""
@ -532,6 +554,7 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase):
)
@payload_schema.register
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Event model for bulk_download_complete"""
@ -548,6 +571,7 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase):
)
@payload_schema.register
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Event model for bulk_download_error"""