diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 7bf1e910c1..40547dcfed 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -14,6 +14,7 @@ from invokeai.app.services.events.events_common import ( DownloadProgressEvent, DownloadStartedEvent, EventBase, + ExtraData, InvocationCompleteEvent, InvocationDenoiseProgressEvent, InvocationErrorEvent, @@ -57,9 +58,11 @@ class EventServiceBase: # region: Invocation - def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None: + def emit_invocation_started( + self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", extra: Optional[ExtraData] = None + ) -> None: """Emitted when an invocation is started""" - self.dispatch(InvocationStartedEvent.build(queue_item, invocation)) + self.dispatch(InvocationStartedEvent.build(queue_item, invocation, extra)) def emit_invocation_denoise_progress( self, @@ -67,143 +70,184 @@ class EventServiceBase: invocation: "BaseInvocation", intermediate_state: PipelineIntermediateState, progress_image: "ProgressImage", + extra: Optional[ExtraData] = None, ) -> None: """Emitted at each step during denoising of an invocation.""" - self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image)) + self.dispatch( + InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image, extra) + ) def emit_invocation_complete( - self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput" + self, + queue_item: "SessionQueueItem", + invocation: "BaseInvocation", + output: "BaseInvocationOutput", + extra: Optional[ExtraData] = None, ) -> None: """Emitted when an invocation is complete""" - self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output)) + self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output, extra)) def emit_invocation_error( - self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", error_type: str, error: str + self, + queue_item: "SessionQueueItem", + invocation: "BaseInvocation", + error_type: str, + error: str, + extra: Optional[ExtraData] = None, ) -> None: """Emitted when an invocation encounters an error""" - self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error)) + self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error, extra)) # endregion # region Session - def emit_session_started(self, queue_item: "SessionQueueItem") -> None: + def emit_session_started(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None: """Emitted when a session has started""" - self.dispatch(SessionStartedEvent.build(queue_item)) + self.dispatch(SessionStartedEvent.build(queue_item, extra)) - def emit_session_complete(self, queue_item: "SessionQueueItem") -> None: + def emit_session_complete(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None: """Emitted when a session has completed all invocations""" - self.dispatch(SessionCompleteEvent.build(queue_item)) + self.dispatch(SessionCompleteEvent.build(queue_item, extra)) - def emit_session_canceled(self, queue_item: "SessionQueueItem") -> None: + def emit_session_canceled(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None: """Emitted when a session is canceled""" - self.dispatch(SessionCanceledEvent.build(queue_item)) + self.dispatch(SessionCanceledEvent.build(queue_item, extra)) # endregion # region Queue def emit_queue_item_status_changed( - self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus" + self, + queue_item: "SessionQueueItem", + batch_status: "BatchStatus", + queue_status: "SessionQueueStatus", + extra: Optional[ExtraData] = None, ) -> None: """Emitted when a queue item's status changes""" - self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)) + self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status, extra)) - def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None: + def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", extra: Optional[ExtraData] = None) -> None: """Emitted when a batch is enqueued""" - self.dispatch(BatchEnqueuedEvent.build(enqueue_result)) + self.dispatch(BatchEnqueuedEvent.build(enqueue_result, extra)) - def emit_queue_cleared(self, queue_id: str) -> None: + def emit_queue_cleared(self, queue_id: str, extra: Optional[ExtraData] = None) -> None: """Emitted when a queue is cleared""" - self.dispatch(QueueClearedEvent.build(queue_id)) + self.dispatch(QueueClearedEvent.build(queue_id, extra)) # endregion # region Download - def emit_download_started(self, job: "DownloadJob") -> None: + def emit_download_started(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None: """Emitted when a download is started""" - self.dispatch(DownloadStartedEvent.build(job)) + self.dispatch(DownloadStartedEvent.build(job, extra)) - def emit_download_progress(self, job: "DownloadJob") -> None: + def emit_download_progress(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None: """Emitted at intervals during a download""" - self.dispatch(DownloadProgressEvent.build(job)) + self.dispatch(DownloadProgressEvent.build(job, extra)) - def emit_download_complete(self, job: "DownloadJob") -> None: + def emit_download_complete(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None: """Emitted when a download is completed""" - self.dispatch(DownloadCompleteEvent.build(job)) + self.dispatch(DownloadCompleteEvent.build(job, extra)) - def emit_download_cancelled(self, job: "DownloadJob") -> None: + def emit_download_cancelled(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None: """Emitted when a download is cancelled""" - self.dispatch(DownloadCancelledEvent.build(job)) + self.dispatch(DownloadCancelledEvent.build(job, extra)) - def emit_download_error(self, job: "DownloadJob") -> None: + def emit_download_error(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None: """Emitted when a download encounters an error""" - self.dispatch(DownloadErrorEvent.build(job)) + self.dispatch(DownloadErrorEvent.build(job, extra)) # endregion # region Model loading - def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None: + def emit_model_load_started( + self, + config: "AnyModelConfig", + submodel_type: Optional["SubModelType"] = None, + extra: Optional[ExtraData] = None, + ) -> None: """Emitted when a model load is started.""" - self.dispatch(ModelLoadStartedEvent.build(config, submodel_type)) + self.dispatch(ModelLoadStartedEvent.build(config, submodel_type, extra)) def emit_model_load_complete( - self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None + self, + config: "AnyModelConfig", + submodel_type: Optional["SubModelType"] = None, + extra: Optional[ExtraData] = None, ) -> None: """Emitted when a model load is complete.""" - self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type)) + self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type, extra)) # endregion # region Model install - def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None: + def emit_model_install_download_progress(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None: """Emitted at intervals while the install job is in progress (remote models only).""" - self.dispatch(ModelInstallDownloadProgressEvent.build(job)) + self.dispatch(ModelInstallDownloadProgressEvent.build(job, extra)) - def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None: - self.dispatch(ModelInstallDownloadsCompleteEvent.build(job)) + def emit_model_install_downloads_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None: + self.dispatch(ModelInstallDownloadsCompleteEvent.build(job, extra)) - def emit_model_install_started(self, job: "ModelInstallJob") -> None: + def emit_model_install_started(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None: """Emitted once when an install job is started (after any download).""" - self.dispatch(ModelInstallStartedEvent.build(job)) + self.dispatch(ModelInstallStartedEvent.build(job, extra)) - def emit_model_install_complete(self, job: "ModelInstallJob") -> None: + def emit_model_install_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None: """Emitted when an install job is completed successfully.""" - self.dispatch(ModelInstallCompleteEvent.build(job)) + self.dispatch(ModelInstallCompleteEvent.build(job, extra)) - def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None: + def emit_model_install_cancelled(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None: """Emitted when an install job is cancelled.""" - self.dispatch(ModelInstallCancelledEvent.build(job)) + self.dispatch(ModelInstallCancelledEvent.build(job, extra)) - def emit_model_install_error(self, job: "ModelInstallJob") -> None: + def emit_model_install_error(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None: """Emitted when an install job encounters an exception.""" - self.dispatch(ModelInstallErrorEvent.build(job)) + self.dispatch(ModelInstallErrorEvent.build(job, extra)) # endregion # region Bulk image download def emit_bulk_download_started( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + extra: Optional[ExtraData] = None, ) -> None: """Emitted when a bulk image download is started""" - self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) + self.dispatch( + BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra) + ) def emit_bulk_download_complete( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + extra: Optional[ExtraData] = None, ) -> None: """Emitted when a bulk image download is complete""" - self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) + self.dispatch( + BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra) + ) def emit_bulk_download_error( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + error: str, + extra: Optional[ExtraData] = None, ) -> None: """Emitted when a bulk image download has an error""" self.dispatch( - BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error) + BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, extra) ) # endregion diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 75e92569cb..0411b10356 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -22,6 +22,9 @@ if TYPE_CHECKING: from invokeai.app.services.model_install.model_install_common import ModelInstallJob +ExtraData: TypeAlias = dict[str, Any] + + class EventBase(BaseModel): """Base class for all events. All events must inherit from this class. @@ -33,6 +36,7 @@ class EventBase(BaseModel): """ timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp) + extra: Optional[ExtraData] = Field(default=None, description="Extra data to include with the event") model_config = ConfigDict(json_schema_serialization_defaults_required=True) @@ -111,7 +115,9 @@ class InvocationStartedEvent(InvocationEventBase): __event_name__ = "invocation_started" @classmethod - def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": + def build( + cls, queue_item: SessionQueueItem, invocation: BaseInvocation, extra: Optional[ExtraData] = None + ) -> "InvocationStartedEvent": return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, @@ -120,6 +126,7 @@ class InvocationStartedEvent(InvocationEventBase): invocation_id=invocation.id, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_type=invocation.get_type(), + extra=extra, ) @@ -141,6 +148,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase): invocation: BaseInvocation, intermediate_state: PipelineIntermediateState, progress_image: ProgressImage, + extra: Optional[ExtraData] = None, ) -> "InvocationDenoiseProgressEvent": step = intermediate_state.step total_steps = intermediate_state.total_steps @@ -158,6 +166,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase): total_steps=total_steps, order=order, percentage=cls.calc_percentage(step, total_steps, order), + extra=extra, ) @staticmethod @@ -180,7 +189,11 @@ class InvocationCompleteEvent(InvocationEventBase): @classmethod def build( - cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput + cls, + queue_item: SessionQueueItem, + invocation: BaseInvocation, + result: BaseInvocationOutput, + extra: Optional[ExtraData] = None, ) -> "InvocationCompleteEvent": return cls( queue_id=queue_item.queue_id, @@ -191,6 +204,7 @@ class InvocationCompleteEvent(InvocationEventBase): invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_type=invocation.get_type(), result=result, + extra=extra, ) @@ -204,7 +218,12 @@ class InvocationErrorEvent(InvocationEventBase): @classmethod def build( - cls, queue_item: SessionQueueItem, invocation: BaseInvocation, error_type: str, error: str + cls, + queue_item: SessionQueueItem, + invocation: BaseInvocation, + error_type: str, + error: str, + extra: Optional[ExtraData] = None, ) -> "InvocationErrorEvent": return cls( queue_id=queue_item.queue_id, @@ -216,6 +235,7 @@ class InvocationErrorEvent(InvocationEventBase): invocation_type=invocation.get_type(), error_type=error_type, error=error, + extra=extra, ) @@ -225,12 +245,13 @@ class SessionStartedEvent(SessionEventBase): __event_name__ = "session_started" @classmethod - def build(cls, queue_item: SessionQueueItem) -> "SessionStartedEvent": + def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionStartedEvent": return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, session_id=queue_item.session_id, + extra=extra, ) @@ -240,12 +261,13 @@ class SessionCompleteEvent(SessionEventBase): __event_name__ = "session_complete" @classmethod - def build(cls, queue_item: SessionQueueItem) -> "SessionCompleteEvent": + def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCompleteEvent": return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, session_id=queue_item.session_id, + extra=extra, ) @@ -255,12 +277,13 @@ class SessionCanceledEvent(SessionEventBase): __event_name__ = "session_canceled" @classmethod - def build(cls, queue_item: SessionQueueItem) -> "SessionCanceledEvent": + def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCanceledEvent": return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, session_id=queue_item.session_id, + extra=extra, ) @@ -280,7 +303,11 @@ class QueueItemStatusChangedEvent(QueueItemEventBase): @classmethod def build( - cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus + cls, + queue_item: SessionQueueItem, + batch_status: BatchStatus, + queue_status: SessionQueueStatus, + extra: Optional[ExtraData] = None, ) -> "QueueItemStatusChangedEvent": return cls( queue_id=queue_item.queue_id, @@ -294,6 +321,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase): completed_at=str(queue_item.completed_at) if queue_item.completed_at else None, batch_status=batch_status, queue_status=queue_status, + extra=extra, ) @@ -310,13 +338,14 @@ class BatchEnqueuedEvent(QueueEventBase): priority: int = Field(description="The priority of the batch") @classmethod - def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": + def build(cls, enqueue_result: EnqueueBatchResult, extra: Optional[ExtraData] = None) -> "BatchEnqueuedEvent": return cls( queue_id=enqueue_result.queue_id, batch_id=enqueue_result.batch.batch_id, enqueued=enqueue_result.enqueued, requested=enqueue_result.requested, priority=enqueue_result.priority, + extra=extra, ) @@ -326,8 +355,11 @@ class QueueClearedEvent(QueueEventBase): __event_name__ = "queue_cleared" @classmethod - def build(cls, queue_id: str) -> "QueueClearedEvent": - return cls(queue_id=queue_id) + def build(cls, queue_id: str, extra: Optional[ExtraData] = None) -> "QueueClearedEvent": + return cls( + queue_id=queue_id, + extra=extra, + ) class DownloadEventBase(EventBase): @@ -344,9 +376,13 @@ class DownloadStartedEvent(DownloadEventBase): download_path: str = Field(description="The local path where the download is saved") @classmethod - def build(cls, job: "DownloadJob") -> "DownloadStartedEvent": + def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadStartedEvent": assert job.download_path - return cls(source=str(job.source), download_path=job.download_path.as_posix()) + return cls( + source=str(job.source), + download_path=job.download_path.as_posix(), + extra=extra, + ) class DownloadProgressEvent(DownloadEventBase): @@ -359,13 +395,14 @@ class DownloadProgressEvent(DownloadEventBase): total_bytes: int = Field(description="The total number of bytes to be downloaded") @classmethod - def build(cls, job: "DownloadJob") -> "DownloadProgressEvent": + def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadProgressEvent": assert job.download_path return cls( source=str(job.source), download_path=job.download_path.as_posix(), current_bytes=job.bytes, total_bytes=job.total_bytes, + extra=extra, ) @@ -378,9 +415,14 @@ class DownloadCompleteEvent(DownloadEventBase): total_bytes: int = Field(description="The total number of bytes downloaded") @classmethod - def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent": + def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCompleteEvent": assert job.download_path - return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes) + return cls( + source=str(job.source), + download_path=job.download_path.as_posix(), + total_bytes=job.total_bytes, + extra=extra, + ) class DownloadCancelledEvent(DownloadEventBase): @@ -389,8 +431,11 @@ class DownloadCancelledEvent(DownloadEventBase): __event_name__ = "download_cancelled" @classmethod - def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent": - return cls(source=str(job.source)) + def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCancelledEvent": + return cls( + source=str(job.source), + extra=extra, + ) class DownloadErrorEvent(DownloadEventBase): @@ -402,10 +447,15 @@ class DownloadErrorEvent(DownloadEventBase): error: str = Field(description="The error message") @classmethod - def build(cls, job: "DownloadJob") -> "DownloadErrorEvent": + def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadErrorEvent": assert job.error_type assert job.error - return cls(source=str(job.source), error_type=job.error_type, error=job.error) + return cls( + source=str(job.source), + error_type=job.error_type, + error=job.error, + extra=extra, + ) class ModelEventBase(EventBase): @@ -421,8 +471,14 @@ class ModelLoadStartedEvent(ModelEventBase): submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") @classmethod - def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent": - return cls(config=config, submodel_type=submodel_type) + def build( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None + ) -> "ModelLoadStartedEvent": + return cls( + config=config, + submodel_type=submodel_type, + extra=extra, + ) class ModelLoadCompleteEvent(ModelEventBase): @@ -434,8 +490,14 @@ class ModelLoadCompleteEvent(ModelEventBase): submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") @classmethod - def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent": - return cls(config=config, submodel_type=submodel_type) + def build( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None + ) -> "ModelLoadCompleteEvent": + return cls( + config=config, + submodel_type=submodel_type, + extra=extra, + ) class ModelInstallDownloadProgressEvent(ModelEventBase): @@ -453,7 +515,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase): ) @classmethod - def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent": + def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadProgressEvent": parts: list[dict[str, str | int]] = [ { "url": str(x.source), @@ -470,6 +532,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase): parts=parts, bytes=job.bytes, total_bytes=job.total_bytes, + extra=extra, ) @@ -482,8 +545,12 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase): source: str = Field(description="Source of the model; local path, repo_id or url") @classmethod - def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent": - return cls(id=job.id, source=str(job.source)) + def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadsCompleteEvent": + return cls( + id=job.id, + source=str(job.source), + extra=extra, + ) class ModelInstallStartedEvent(ModelEventBase): @@ -495,8 +562,12 @@ class ModelInstallStartedEvent(ModelEventBase): source: str = Field(description="Source of the model; local path, repo_id or url") @classmethod - def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent": - return cls(id=job.id, source=str(job.source)) + def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallStartedEvent": + return cls( + id=job.id, + source=str(job.source), + extra=extra, + ) class ModelInstallCompleteEvent(ModelEventBase): @@ -510,9 +581,15 @@ class ModelInstallCompleteEvent(ModelEventBase): total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)") @classmethod - def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent": + def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCompleteEvent": assert job.config_out is not None - return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes) + return cls( + id=job.id, + source=str(job.source), + key=(job.config_out.key), + total_bytes=job.total_bytes, + extra=extra, + ) class ModelInstallCancelledEvent(ModelEventBase): @@ -524,8 +601,12 @@ class ModelInstallCancelledEvent(ModelEventBase): source: str = Field(description="Source of the model; local path, repo_id or url") @classmethod - def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent": - return cls(id=job.id, source=str(job.source)) + def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCancelledEvent": + return cls( + id=job.id, + source=str(job.source), + extra=extra, + ) class ModelInstallErrorEvent(ModelEventBase): @@ -539,10 +620,16 @@ class ModelInstallErrorEvent(ModelEventBase): error: str = Field(description="A text description of the exception") @classmethod - def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent": + def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallErrorEvent": assert job.error_type is not None assert job.error is not None - return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error) + return cls( + id=job.id, + source=str(job.source), + error_type=job.error_type, + error=job.error, + extra=extra, + ) class BulkDownloadEventBase(EventBase): @@ -560,12 +647,17 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + extra: Optional[ExtraData] = None, ) -> "BulkDownloadStartedEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, + extra=extra, ) @@ -576,12 +668,17 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + extra: Optional[ExtraData] = None, ) -> "BulkDownloadCompleteEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, + extra=extra, ) @@ -594,11 +691,17 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + error: str, + extra: Optional[ExtraData] = None, ) -> "BulkDownloadErrorEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, error=error, + extra=extra, )