feat(events): add extra field to event payloads

This allows for arbitrary serializable data to be sent with events.
This commit is contained in:
psychedelicious
2024-05-20 15:34:03 +10:00
parent 87f2d04ddd
commit 452f4fe0e6
2 changed files with 235 additions and 88 deletions

View File

@ -14,6 +14,7 @@ from invokeai.app.services.events.events_common import (
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
ExtraData,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
@ -57,9 +58,11 @@ class EventServiceBase:
# region: Invocation
def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
def emit_invocation_started(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", extra: Optional[ExtraData] = None
) -> None:
"""Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
self.dispatch(InvocationStartedEvent.build(queue_item, invocation, extra))
def emit_invocation_denoise_progress(
self,
@ -67,143 +70,184 @@ class EventServiceBase:
invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState,
progress_image: "ProgressImage",
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted at each step during denoising of an invocation."""
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
self.dispatch(
InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image, extra)
)
def emit_invocation_complete(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
output: "BaseInvocationOutput",
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when an invocation is complete"""
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output, extra))
def emit_invocation_error(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", error_type: str, error: str
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
error_type: str,
error: str,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when an invocation encounters an error"""
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error))
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error, extra))
# endregion
# region Session
def emit_session_started(self, queue_item: "SessionQueueItem") -> None:
def emit_session_started(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session has started"""
self.dispatch(SessionStartedEvent.build(queue_item))
self.dispatch(SessionStartedEvent.build(queue_item, extra))
def emit_session_complete(self, queue_item: "SessionQueueItem") -> None:
def emit_session_complete(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session has completed all invocations"""
self.dispatch(SessionCompleteEvent.build(queue_item))
self.dispatch(SessionCompleteEvent.build(queue_item, extra))
def emit_session_canceled(self, queue_item: "SessionQueueItem") -> None:
def emit_session_canceled(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session is canceled"""
self.dispatch(SessionCanceledEvent.build(queue_item))
self.dispatch(SessionCanceledEvent.build(queue_item, extra))
# endregion
# region Queue
def emit_queue_item_status_changed(
self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
self,
queue_item: "SessionQueueItem",
batch_status: "BatchStatus",
queue_status: "SessionQueueStatus",
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a queue item's status changes"""
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status, extra))
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a batch is enqueued"""
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
self.dispatch(BatchEnqueuedEvent.build(enqueue_result, extra))
def emit_queue_cleared(self, queue_id: str) -> None:
def emit_queue_cleared(self, queue_id: str, extra: Optional[ExtraData] = None) -> None:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id))
self.dispatch(QueueClearedEvent.build(queue_id, extra))
# endregion
# region Download
def emit_download_started(self, job: "DownloadJob") -> None:
def emit_download_started(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download is started"""
self.dispatch(DownloadStartedEvent.build(job))
self.dispatch(DownloadStartedEvent.build(job, extra))
def emit_download_progress(self, job: "DownloadJob") -> None:
def emit_download_progress(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted at intervals during a download"""
self.dispatch(DownloadProgressEvent.build(job))
self.dispatch(DownloadProgressEvent.build(job, extra))
def emit_download_complete(self, job: "DownloadJob") -> None:
def emit_download_complete(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download is completed"""
self.dispatch(DownloadCompleteEvent.build(job))
self.dispatch(DownloadCompleteEvent.build(job, extra))
def emit_download_cancelled(self, job: "DownloadJob") -> None:
def emit_download_cancelled(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download is cancelled"""
self.dispatch(DownloadCancelledEvent.build(job))
self.dispatch(DownloadCancelledEvent.build(job, extra))
def emit_download_error(self, job: "DownloadJob") -> None:
def emit_download_error(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download encounters an error"""
self.dispatch(DownloadErrorEvent.build(job))
self.dispatch(DownloadErrorEvent.build(job, extra))
# endregion
# region Model loading
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
def emit_model_load_started(
self,
config: "AnyModelConfig",
submodel_type: Optional["SubModelType"] = None,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type, extra))
def emit_model_load_complete(
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
self,
config: "AnyModelConfig",
submodel_type: Optional["SubModelType"] = None,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type, extra))
# endregion
# region Model install
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
def emit_model_install_download_progress(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted at intervals while the install job is in progress (remote models only)."""
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
self.dispatch(ModelInstallDownloadProgressEvent.build(job, extra))
def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
def emit_model_install_downloads_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job, extra))
def emit_model_install_started(self, job: "ModelInstallJob") -> None:
def emit_model_install_started(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted once when an install job is started (after any download)."""
self.dispatch(ModelInstallStartedEvent.build(job))
self.dispatch(ModelInstallStartedEvent.build(job, extra))
def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
def emit_model_install_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when an install job is completed successfully."""
self.dispatch(ModelInstallCompleteEvent.build(job))
self.dispatch(ModelInstallCompleteEvent.build(job, extra))
def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
def emit_model_install_cancelled(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when an install job is cancelled."""
self.dispatch(ModelInstallCancelledEvent.build(job))
self.dispatch(ModelInstallCancelledEvent.build(job, extra))
def emit_model_install_error(self, job: "ModelInstallJob") -> None:
def emit_model_install_error(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when an install job encounters an exception."""
self.dispatch(ModelInstallErrorEvent.build(job))
self.dispatch(ModelInstallErrorEvent.build(job, extra))
# endregion
# region Bulk image download
def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a bulk image download is started"""
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
self.dispatch(
BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
)
def emit_bulk_download_complete(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a bulk image download is complete"""
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
self.dispatch(
BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
)
def emit_bulk_download_error(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a bulk image download has an error"""
self.dispatch(
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, extra)
)
# endregion

View File

@ -22,6 +22,9 @@ if TYPE_CHECKING:
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
ExtraData: TypeAlias = dict[str, Any]
class EventBase(BaseModel):
"""Base class for all events. All events must inherit from this class.
@ -33,6 +36,7 @@ class EventBase(BaseModel):
"""
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
extra: Optional[ExtraData] = Field(default=None, description="Extra data to include with the event")
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@ -111,7 +115,9 @@ class InvocationStartedEvent(InvocationEventBase):
__event_name__ = "invocation_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, extra: Optional[ExtraData] = None
) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
@ -120,6 +126,7 @@ class InvocationStartedEvent(InvocationEventBase):
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
extra=extra,
)
@ -141,6 +148,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
invocation: BaseInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
extra: Optional[ExtraData] = None,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
@ -158,6 +166,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
extra=extra,
)
@staticmethod
@ -180,7 +189,11 @@ class InvocationCompleteEvent(InvocationEventBase):
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
result: BaseInvocationOutput,
extra: Optional[ExtraData] = None,
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
@ -191,6 +204,7 @@ class InvocationCompleteEvent(InvocationEventBase):
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
result=result,
extra=extra,
)
@ -204,7 +218,12 @@ class InvocationErrorEvent(InvocationEventBase):
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, error_type: str, error: str
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
error_type: str,
error: str,
extra: Optional[ExtraData] = None,
) -> "InvocationErrorEvent":
return cls(
queue_id=queue_item.queue_id,
@ -216,6 +235,7 @@ class InvocationErrorEvent(InvocationEventBase):
invocation_type=invocation.get_type(),
error_type=error_type,
error=error,
extra=extra,
)
@ -225,12 +245,13 @@ class SessionStartedEvent(SessionEventBase):
__event_name__ = "session_started"
@classmethod
def build(cls, queue_item: SessionQueueItem) -> "SessionStartedEvent":
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
@ -240,12 +261,13 @@ class SessionCompleteEvent(SessionEventBase):
__event_name__ = "session_complete"
@classmethod
def build(cls, queue_item: SessionQueueItem) -> "SessionCompleteEvent":
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
@ -255,12 +277,13 @@ class SessionCanceledEvent(SessionEventBase):
__event_name__ = "session_canceled"
@classmethod
def build(cls, queue_item: SessionQueueItem) -> "SessionCanceledEvent":
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCanceledEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
@ -280,7 +303,11 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
@classmethod
def build(
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
cls,
queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
extra: Optional[ExtraData] = None,
) -> "QueueItemStatusChangedEvent":
return cls(
queue_id=queue_item.queue_id,
@ -294,6 +321,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
extra=extra,
)
@ -310,13 +338,14 @@ class BatchEnqueuedEvent(QueueEventBase):
priority: int = Field(description="The priority of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
def build(cls, enqueue_result: EnqueueBatchResult, extra: Optional[ExtraData] = None) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
extra=extra,
)
@ -326,8 +355,11 @@ class QueueClearedEvent(QueueEventBase):
__event_name__ = "queue_cleared"
@classmethod
def build(cls, queue_id: str) -> "QueueClearedEvent":
return cls(queue_id=queue_id)
def build(cls, queue_id: str, extra: Optional[ExtraData] = None) -> "QueueClearedEvent":
return cls(
queue_id=queue_id,
extra=extra,
)
class DownloadEventBase(EventBase):
@ -344,9 +376,13 @@ class DownloadStartedEvent(DownloadEventBase):
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadStartedEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix())
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
extra=extra,
)
class DownloadProgressEvent(DownloadEventBase):
@ -359,13 +395,14 @@ class DownloadProgressEvent(DownloadEventBase):
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadProgressEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
extra=extra,
)
@ -378,9 +415,14 @@ class DownloadCompleteEvent(DownloadEventBase):
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCompleteEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
total_bytes=job.total_bytes,
extra=extra,
)
class DownloadCancelledEvent(DownloadEventBase):
@ -389,8 +431,11 @@ class DownloadCancelledEvent(DownloadEventBase):
__event_name__ = "download_cancelled"
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
return cls(source=str(job.source))
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCancelledEvent":
return cls(
source=str(job.source),
extra=extra,
)
class DownloadErrorEvent(DownloadEventBase):
@ -402,10 +447,15 @@ class DownloadErrorEvent(DownloadEventBase):
error: str = Field(description="The error message")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadErrorEvent":
assert job.error_type
assert job.error
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
return cls(
source=str(job.source),
error_type=job.error_type,
error=job.error,
extra=extra,
)
class ModelEventBase(EventBase):
@ -421,8 +471,14 @@ class ModelLoadStartedEvent(ModelEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
def build(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
) -> "ModelLoadStartedEvent":
return cls(
config=config,
submodel_type=submodel_type,
extra=extra,
)
class ModelLoadCompleteEvent(ModelEventBase):
@ -434,8 +490,14 @@ class ModelLoadCompleteEvent(ModelEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
def build(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
) -> "ModelLoadCompleteEvent":
return cls(
config=config,
submodel_type=submodel_type,
extra=extra,
)
class ModelInstallDownloadProgressEvent(ModelEventBase):
@ -453,7 +515,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
)
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadProgressEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
@ -470,6 +532,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
extra=extra,
)
@ -482,8 +545,12 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase):
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
return cls(id=job.id, source=str(job.source))
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadsCompleteEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
class ModelInstallStartedEvent(ModelEventBase):
@ -495,8 +562,12 @@ class ModelInstallStartedEvent(ModelEventBase):
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
return cls(id=job.id, source=str(job.source))
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallStartedEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
class ModelInstallCompleteEvent(ModelEventBase):
@ -510,9 +581,15 @@ class ModelInstallCompleteEvent(ModelEventBase):
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
return cls(
id=job.id,
source=str(job.source),
key=(job.config_out.key),
total_bytes=job.total_bytes,
extra=extra,
)
class ModelInstallCancelledEvent(ModelEventBase):
@ -524,8 +601,12 @@ class ModelInstallCancelledEvent(ModelEventBase):
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
return cls(id=job.id, source=str(job.source))
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCancelledEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
class ModelInstallErrorEvent(ModelEventBase):
@ -539,10 +620,16 @@ class ModelInstallErrorEvent(ModelEventBase):
error: str = Field(description="A text description of the exception")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
return cls(
id=job.id,
source=str(job.source),
error_type=job.error_type,
error=job.error,
extra=extra,
)
class BulkDownloadEventBase(EventBase):
@ -560,12 +647,17 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase):
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
extra=extra,
)
@ -576,12 +668,17 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase):
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
extra=extra,
)
@ -594,11 +691,17 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase):
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
extra: Optional[ExtraData] = None,
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
extra=extra,
)