mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(events): add extra
field to event payloads
This allows for arbitrary serializable data to be sent with events.
This commit is contained in:
@ -14,6 +14,7 @@ from invokeai.app.services.events.events_common import (
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
EventBase,
|
||||
ExtraData,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
@ -57,9 +58,11 @@ class EventServiceBase:
|
||||
|
||||
# region: Invocation
|
||||
|
||||
def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
|
||||
def emit_invocation_started(
|
||||
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", extra: Optional[ExtraData] = None
|
||||
) -> None:
|
||||
"""Emitted when an invocation is started"""
|
||||
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
|
||||
self.dispatch(InvocationStartedEvent.build(queue_item, invocation, extra))
|
||||
|
||||
def emit_invocation_denoise_progress(
|
||||
self,
|
||||
@ -67,143 +70,184 @@ class EventServiceBase:
|
||||
invocation: "BaseInvocation",
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: "ProgressImage",
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted at each step during denoising of an invocation."""
|
||||
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
|
||||
self.dispatch(
|
||||
InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image, extra)
|
||||
)
|
||||
|
||||
def emit_invocation_complete(
|
||||
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
|
||||
self,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
output: "BaseInvocationOutput",
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when an invocation is complete"""
|
||||
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
|
||||
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output, extra))
|
||||
|
||||
def emit_invocation_error(
|
||||
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", error_type: str, error: str
|
||||
self,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
error_type: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when an invocation encounters an error"""
|
||||
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error))
|
||||
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error, extra))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Session
|
||||
|
||||
def emit_session_started(self, queue_item: "SessionQueueItem") -> None:
|
||||
def emit_session_started(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a session has started"""
|
||||
self.dispatch(SessionStartedEvent.build(queue_item))
|
||||
self.dispatch(SessionStartedEvent.build(queue_item, extra))
|
||||
|
||||
def emit_session_complete(self, queue_item: "SessionQueueItem") -> None:
|
||||
def emit_session_complete(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.dispatch(SessionCompleteEvent.build(queue_item))
|
||||
self.dispatch(SessionCompleteEvent.build(queue_item, extra))
|
||||
|
||||
def emit_session_canceled(self, queue_item: "SessionQueueItem") -> None:
|
||||
def emit_session_canceled(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.dispatch(SessionCanceledEvent.build(queue_item))
|
||||
self.dispatch(SessionCanceledEvent.build(queue_item, extra))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Queue
|
||||
|
||||
def emit_queue_item_status_changed(
|
||||
self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
|
||||
self,
|
||||
queue_item: "SessionQueueItem",
|
||||
batch_status: "BatchStatus",
|
||||
queue_status: "SessionQueueStatus",
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a queue item's status changes"""
|
||||
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
|
||||
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status, extra))
|
||||
|
||||
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
|
||||
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
|
||||
self.dispatch(BatchEnqueuedEvent.build(enqueue_result, extra))
|
||||
|
||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||
def emit_queue_cleared(self, queue_id: str, extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a queue is cleared"""
|
||||
self.dispatch(QueueClearedEvent.build(queue_id))
|
||||
self.dispatch(QueueClearedEvent.build(queue_id, extra))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Download
|
||||
|
||||
def emit_download_started(self, job: "DownloadJob") -> None:
|
||||
def emit_download_started(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download is started"""
|
||||
self.dispatch(DownloadStartedEvent.build(job))
|
||||
self.dispatch(DownloadStartedEvent.build(job, extra))
|
||||
|
||||
def emit_download_progress(self, job: "DownloadJob") -> None:
|
||||
def emit_download_progress(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted at intervals during a download"""
|
||||
self.dispatch(DownloadProgressEvent.build(job))
|
||||
self.dispatch(DownloadProgressEvent.build(job, extra))
|
||||
|
||||
def emit_download_complete(self, job: "DownloadJob") -> None:
|
||||
def emit_download_complete(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download is completed"""
|
||||
self.dispatch(DownloadCompleteEvent.build(job))
|
||||
self.dispatch(DownloadCompleteEvent.build(job, extra))
|
||||
|
||||
def emit_download_cancelled(self, job: "DownloadJob") -> None:
|
||||
def emit_download_cancelled(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download is cancelled"""
|
||||
self.dispatch(DownloadCancelledEvent.build(job))
|
||||
self.dispatch(DownloadCancelledEvent.build(job, extra))
|
||||
|
||||
def emit_download_error(self, job: "DownloadJob") -> None:
|
||||
def emit_download_error(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download encounters an error"""
|
||||
self.dispatch(DownloadErrorEvent.build(job))
|
||||
self.dispatch(DownloadErrorEvent.build(job, extra))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Model loading
|
||||
|
||||
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
config: "AnyModelConfig",
|
||||
submodel_type: Optional["SubModelType"] = None,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model load is started."""
|
||||
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
|
||||
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type, extra))
|
||||
|
||||
def emit_model_load_complete(
|
||||
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
|
||||
self,
|
||||
config: "AnyModelConfig",
|
||||
submodel_type: Optional["SubModelType"] = None,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model load is complete."""
|
||||
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
|
||||
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type, extra))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Model install
|
||||
|
||||
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
|
||||
def emit_model_install_download_progress(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted at intervals while the install job is in progress (remote models only)."""
|
||||
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
|
||||
self.dispatch(ModelInstallDownloadProgressEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
|
||||
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
|
||||
def emit_model_install_downloads_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_started(self, job: "ModelInstallJob") -> None:
|
||||
def emit_model_install_started(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted once when an install job is started (after any download)."""
|
||||
self.dispatch(ModelInstallStartedEvent.build(job))
|
||||
self.dispatch(ModelInstallStartedEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
|
||||
def emit_model_install_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when an install job is completed successfully."""
|
||||
self.dispatch(ModelInstallCompleteEvent.build(job))
|
||||
self.dispatch(ModelInstallCompleteEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
|
||||
def emit_model_install_cancelled(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when an install job is cancelled."""
|
||||
self.dispatch(ModelInstallCancelledEvent.build(job))
|
||||
self.dispatch(ModelInstallCancelledEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_error(self, job: "ModelInstallJob") -> None:
|
||||
def emit_model_install_error(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when an install job encounters an exception."""
|
||||
self.dispatch(ModelInstallErrorEvent.build(job))
|
||||
self.dispatch(ModelInstallErrorEvent.build(job, extra))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Bulk image download
|
||||
|
||||
def emit_bulk_download_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download is started"""
|
||||
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
|
||||
self.dispatch(
|
||||
BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
|
||||
)
|
||||
|
||||
def emit_bulk_download_complete(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download is complete"""
|
||||
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
|
||||
self.dispatch(
|
||||
BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
|
||||
)
|
||||
|
||||
def emit_bulk_download_error(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download has an error"""
|
||||
self.dispatch(
|
||||
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
|
||||
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, extra)
|
||||
)
|
||||
|
||||
# endregion
|
||||
|
@ -22,6 +22,9 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
|
||||
|
||||
ExtraData: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
class EventBase(BaseModel):
|
||||
"""Base class for all events. All events must inherit from this class.
|
||||
|
||||
@ -33,6 +36,7 @@ class EventBase(BaseModel):
|
||||
"""
|
||||
|
||||
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
||||
extra: Optional[ExtraData] = Field(default=None, description="Extra data to include with the event")
|
||||
|
||||
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
||||
|
||||
@ -111,7 +115,9 @@ class InvocationStartedEvent(InvocationEventBase):
|
||||
__event_name__ = "invocation_started"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, extra: Optional[ExtraData] = None
|
||||
) -> "InvocationStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
@ -120,6 +126,7 @@ class InvocationStartedEvent(InvocationEventBase):
|
||||
invocation_id=invocation.id,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
invocation_type=invocation.get_type(),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -141,6 +148,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
invocation: BaseInvocation,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: ProgressImage,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "InvocationDenoiseProgressEvent":
|
||||
step = intermediate_state.step
|
||||
total_steps = intermediate_state.total_steps
|
||||
@ -158,6 +166,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
total_steps=total_steps,
|
||||
order=order,
|
||||
percentage=cls.calc_percentage(step, total_steps, order),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -180,7 +189,11 @@ class InvocationCompleteEvent(InvocationEventBase):
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
result: BaseInvocationOutput,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "InvocationCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
@ -191,6 +204,7 @@ class InvocationCompleteEvent(InvocationEventBase):
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
invocation_type=invocation.get_type(),
|
||||
result=result,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -204,7 +218,12 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, error_type: str, error: str
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
error_type: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "InvocationErrorEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
@ -216,6 +235,7 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
invocation_type=invocation.get_type(),
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -225,12 +245,13 @@ class SessionStartedEvent(SessionEventBase):
|
||||
__event_name__ = "session_started"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem) -> "SessionStartedEvent":
|
||||
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -240,12 +261,13 @@ class SessionCompleteEvent(SessionEventBase):
|
||||
__event_name__ = "session_complete"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem) -> "SessionCompleteEvent":
|
||||
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -255,12 +277,13 @@ class SessionCanceledEvent(SessionEventBase):
|
||||
__event_name__ = "session_canceled"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem) -> "SessionCanceledEvent":
|
||||
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCanceledEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -280,7 +303,11 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
batch_status: BatchStatus,
|
||||
queue_status: SessionQueueStatus,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "QueueItemStatusChangedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
@ -294,6 +321,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -310,13 +338,14 @@ class BatchEnqueuedEvent(QueueEventBase):
|
||||
priority: int = Field(description="The priority of the batch")
|
||||
|
||||
@classmethod
|
||||
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
|
||||
def build(cls, enqueue_result: EnqueueBatchResult, extra: Optional[ExtraData] = None) -> "BatchEnqueuedEvent":
|
||||
return cls(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
requested=enqueue_result.requested,
|
||||
priority=enqueue_result.priority,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -326,8 +355,11 @@ class QueueClearedEvent(QueueEventBase):
|
||||
__event_name__ = "queue_cleared"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_id: str) -> "QueueClearedEvent":
|
||||
return cls(queue_id=queue_id)
|
||||
def build(cls, queue_id: str, extra: Optional[ExtraData] = None) -> "QueueClearedEvent":
|
||||
return cls(
|
||||
queue_id=queue_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadEventBase(EventBase):
|
||||
@ -344,9 +376,13 @@ class DownloadStartedEvent(DownloadEventBase):
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadStartedEvent":
|
||||
assert job.download_path
|
||||
return cls(source=str(job.source), download_path=job.download_path.as_posix())
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadProgressEvent(DownloadEventBase):
|
||||
@ -359,13 +395,14 @@ class DownloadProgressEvent(DownloadEventBase):
|
||||
total_bytes: int = Field(description="The total number of bytes to be downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadProgressEvent":
|
||||
assert job.download_path
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -378,9 +415,14 @@ class DownloadCompleteEvent(DownloadEventBase):
|
||||
total_bytes: int = Field(description="The total number of bytes downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCompleteEvent":
|
||||
assert job.download_path
|
||||
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadCancelledEvent(DownloadEventBase):
|
||||
@ -389,8 +431,11 @@ class DownloadCancelledEvent(DownloadEventBase):
|
||||
__event_name__ = "download_cancelled"
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
|
||||
return cls(source=str(job.source))
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCancelledEvent":
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadErrorEvent(DownloadEventBase):
|
||||
@ -402,10 +447,15 @@ class DownloadErrorEvent(DownloadEventBase):
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadErrorEvent":
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
error_type=job.error_type,
|
||||
error=job.error,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelEventBase(EventBase):
|
||||
@ -421,8 +471,14 @@ class ModelLoadStartedEvent(ModelEventBase):
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
|
||||
return cls(config=config, submodel_type=submodel_type)
|
||||
def build(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
|
||||
) -> "ModelLoadStartedEvent":
|
||||
return cls(
|
||||
config=config,
|
||||
submodel_type=submodel_type,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelLoadCompleteEvent(ModelEventBase):
|
||||
@ -434,8 +490,14 @@ class ModelLoadCompleteEvent(ModelEventBase):
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
|
||||
return cls(config=config, submodel_type=submodel_type)
|
||||
def build(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
|
||||
) -> "ModelLoadCompleteEvent":
|
||||
return cls(
|
||||
config=config,
|
||||
submodel_type=submodel_type,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||
@ -453,7 +515,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadProgressEvent":
|
||||
parts: list[dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
@ -470,6 +532,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -482,8 +545,12 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase):
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
|
||||
return cls(id=job.id, source=str(job.source))
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadsCompleteEvent":
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallStartedEvent(ModelEventBase):
|
||||
@ -495,8 +562,12 @@ class ModelInstallStartedEvent(ModelEventBase):
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
|
||||
return cls(id=job.id, source=str(job.source))
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallStartedEvent":
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallCompleteEvent(ModelEventBase):
|
||||
@ -510,9 +581,15 @@ class ModelInstallCompleteEvent(ModelEventBase):
|
||||
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCompleteEvent":
|
||||
assert job.config_out is not None
|
||||
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
key=(job.config_out.key),
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallCancelledEvent(ModelEventBase):
|
||||
@ -524,8 +601,12 @@ class ModelInstallCancelledEvent(ModelEventBase):
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
|
||||
return cls(id=job.id, source=str(job.source))
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCancelledEvent":
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallErrorEvent(ModelEventBase):
|
||||
@ -539,10 +620,16 @@ class ModelInstallErrorEvent(ModelEventBase):
|
||||
error: str = Field(description="A text description of the exception")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallErrorEvent":
|
||||
assert job.error_type is not None
|
||||
assert job.error is not None
|
||||
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
error_type=job.error_type,
|
||||
error=job.error,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class BulkDownloadEventBase(EventBase):
|
||||
@ -560,12 +647,17 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase):
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "BulkDownloadStartedEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -576,12 +668,17 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase):
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "BulkDownloadCompleteEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
@ -594,11 +691,17 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase):
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "BulkDownloadErrorEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=error,
|
||||
extra=extra,
|
||||
)
|
||||
|
Reference in New Issue
Block a user