2024-04-01 07:16:33 +00:00
|
|
|
from math import floor
|
2024-05-28 05:32:47 +00:00
|
|
|
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
|
2024-03-14 08:04:19 +00:00
|
|
|
|
|
|
|
from fastapi_events.handlers.local import local_handler
|
2024-05-27 00:59:36 +00:00
|
|
|
from fastapi_events.registry.payload_schema import registry as payload_schema
|
2024-05-29 07:29:51 +00:00
|
|
|
from pydantic import BaseModel, ConfigDict, Field
|
2024-03-14 08:04:19 +00:00
|
|
|
|
|
|
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
|
|
|
from invokeai.app.services.session_queue.session_queue_common import (
|
|
|
|
QUEUE_ITEM_STATUS,
|
|
|
|
BatchStatus,
|
|
|
|
EnqueueBatchResult,
|
|
|
|
SessionQueueItem,
|
|
|
|
SessionQueueStatus,
|
|
|
|
)
|
2024-05-29 07:29:51 +00:00
|
|
|
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
|
2024-03-14 08:04:19 +00:00
|
|
|
from invokeai.app.util.misc import get_timestamp
|
|
|
|
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
2024-04-01 07:16:33 +00:00
|
|
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
2024-03-14 08:04:19 +00:00
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
2024-03-31 01:02:49 +00:00
|
|
|
from invokeai.app.services.download.download_base import DownloadJob
|
2024-03-14 08:04:19 +00:00
|
|
|
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
|
|
|
|
|
|
|
|
2024-03-10 12:23:11 +00:00
|
|
|
class EventBase(BaseModel):
    """Base class for all events. All events must inherit from this class.

    Events must define a class attribute `__event_name__` to identify the event.

    All other attributes should be defined as normal for a pydantic model.

    A timestamp is automatically added to the event when it is created.
    """

    # Only annotated here, never assigned: concrete event models assign it, intermediate
    # base classes do not, and `get_events` relies on that difference.
    __event_name__: ClassVar[str]

    timestamp: int = Field(default_factory=get_timestamp, description="The timestamp of the event")

    # Mark fields that have defaults (such as `timestamp`) as required in the
    # serialization-mode JSON schema, since they are always present once serialized.
    model_config = ConfigDict(json_schema_serialization_defaults_required=True)

    @classmethod
    def get_events(cls) -> set[type["EventBase"]]:
        """Return every event model that descends from this class.

        The whole subclass tree is walked. Only classes that carry an `__event_name__`
        are collected; intermediary base classes are traversed but not included.
        """
        collected: set[type["EventBase"]] = set()
        to_visit: list[type["EventBase"]] = list(cls.__subclasses__())
        while to_visit:
            candidate = to_visit.pop()
            if hasattr(candidate, "__event_name__"):
                collected.add(candidate)
            # Descend regardless: an intermediary base may still have event subclasses.
            to_visit.extend(candidate.__subclasses__())
        return collected
|
|
|
|
|
2024-03-14 08:04:19 +00:00
|
|
|
|
2024-05-27 00:01:38 +00:00
|
|
|
# Payload type variable for events, bounded to `EventBase`. It parameterizes both the
# `FastAPIEvent` alias and the `FastAPIEventFunc` handler protocol. In that protocol's
# `__call__`, events only appear as an argument, which is presumably why it is
# declared contravariant.
TEvent = TypeVar("TEvent", contravariant=True, bound=EventBase)

FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
A tuple representing a `fastapi-events` event, with the event name and payload.
Provide a generic type to `TEvent` to specify the payload type.
"""
|
|
|
|
|
|
|
|
|
2024-05-27 00:01:38 +00:00
|
|
|
class FastAPIEventFunc(Protocol, Generic[TEvent]):
    """Structural type of an event handler accepted by `register_events`.

    A handler receives a `(event_name, payload)` tuple. It may be a plain function
    returning None, or it may return a coroutine (i.e. be an `async def`).
    """

    def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]:
        ...
|
2024-03-14 08:04:19 +00:00
|
|
|
|
|
|
|
|
2024-05-27 00:01:38 +00:00
|
|
|
def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None:
    """Register a function to handle specific events.

    :param events: An event or set of events to handle. Besides a `set`, any other
        iterable of event classes (frozenset, list, tuple, ...) is also accepted.
    :param func: The function to handle the events
    :raises TypeError: If a given class does not define `__event_name__`, i.e. it is not
        a concrete event model.
    """
    # A single event model is itself a class. Anything else is treated as a collection of
    # classes. The old `isinstance(events, set)` test wrapped a frozenset, list or tuple
    # as if it were one event. `dict.fromkeys` dedupes while keeping the caller's order, so
    # no handler is registered twice for the same event.
    event_classes = [events] if isinstance(events, type) else list(dict.fromkeys(events))
    for event in event_classes:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not hasattr(event, "__event_name__"):
            raise TypeError(f"{event!r} is not an event model: it does not define `__event_name__`")
        local_handler.register(event_name=event.__event_name__, _func=func)  # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
2024-03-14 08:04:19 +00:00
|
|
|
|
|
|
|
|
2024-03-10 12:23:11 +00:00
|
|
|
class QueueEventBase(EventBase):
    """Base class for queue events"""

    # Identifies the session queue the event relates to. Every queue-scoped event model
    # (queue item, batch and queue-cleared events) inherits this field.
    queue_id: str = Field(description="The ID of the queue")
|
|
|
|
|
|
|
|
|
2024-03-10 12:23:11 +00:00
|
|
|
class QueueItemEventBase(QueueEventBase):
    """Base class for queue item events"""

    # Narrows a queue event down to one queue item and the batch it was enqueued with.
    item_id: int = Field(description="The ID of the queue item")
    batch_id: str = Field(description="The ID of the queue batch")
|
|
|
|
|
|
|
|
|
2024-05-26 01:40:38 +00:00
|
|
|
class InvocationEventBase(QueueItemEventBase):
    """Base class for invocation events"""

    # `queue_id`, `item_id` and `batch_id` are inherited unchanged from QueueItemEventBase.
    # They used to be re-declared here with identical definitions, and `session_id` was
    # declared twice. Dropping the duplicates leaves the model's fields and their order as
    # before: timestamp, queue_id, item_id, batch_id, session_id, invocation,
    # invocation_source_id.
    session_id: str = Field(description="The ID of the session (aka graph execution state)")
    # This field holds the full invocation object, not just its ID (the old description was wrong).
    invocation: AnyInvocation = Field(description="The invocation")
    # ID of the node in the source graph that the prepared (executed) invocation came from.
    invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
    """Event model for invocation_started"""

    __event_name__ = "invocation_started"

    @classmethod
    def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
        """Create the event for `invocation`, which is running as part of `queue_item`.

        Raises KeyError if the invocation's id is missing from the session's
        `prepared_source_mapping`.
        """
        # Map the prepared (executed) node id back to the node id in the source graph.
        source_node_id = queue_item.session.prepared_source_mapping[invocation.id]
        return cls(
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            batch_id=queue_item.batch_id,
            session_id=queue_item.session_id,
            invocation=invocation,
            invocation_source_id=source_node_id,
        )
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
    """Event model for invocation_denoise_progress"""

    __event_name__ = "invocation_denoise_progress"

    progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
    step: int = Field(description="The current step of the invocation")
    total_steps: int = Field(description="The total number of steps in the invocation")
    order: int = Field(description="The order of the invocation in the session")
    percentage: float = Field(description="The percentage of completion of the invocation")

    @classmethod
    def build(
        cls,
        queue_item: SessionQueueItem,
        invocation: AnyInvocation,
        intermediate_state: PipelineIntermediateState,
        progress_image: ProgressImage,
    ) -> "InvocationDenoiseProgressEvent":
        """Create a progress event from the pipeline's intermediate denoising state."""
        # Read the counters once. They populate the raw fields and the derived percentage.
        current, total, scheduler_order = (
            intermediate_state.step,
            intermediate_state.total_steps,
            intermediate_state.order,
        )
        # Map the prepared (executed) node id back to the node id in the source graph.
        source_node_id = queue_item.session.prepared_source_mapping[invocation.id]
        return cls(
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            batch_id=queue_item.batch_id,
            session_id=queue_item.session_id,
            invocation=invocation,
            invocation_source_id=source_node_id,
            progress_image=progress_image,
            step=current,
            total_steps=total,
            order=scheduler_order,
            percentage=cls.calc_percentage(current, total, scheduler_order),
        )

    @staticmethod
    def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
        """Calculate the percentage of completion of denoising.

        The result is a ratio (not multiplied by 100), and 0.0 when `total_steps` is 0.
        For a second-order scheduler, both sides of the ratio are halved and rounded
        down before dividing. Any other order uses the plain ratio.
        """
        if total_steps == 0:
            return 0.0
        # NOTE(review): the +2 / +1 offsets suggest `step` is 0-based and counted
        # differently from `total_steps`; confirm against the diffusers pipeline.
        numerator = step + 1 + 1
        denominator = total_steps + 1
        if scheduler_order != 2:
            # order == 1
            return numerator / denominator
        return floor(numerator / 2) / floor(denominator / 2)
|
|
|
|
|
2024-03-14 08:04:19 +00:00
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
    """Event model for invocation_complete"""

    __event_name__ = "invocation_complete"

    result: AnyInvocationOutput = Field(description="The result of the invocation")

    @classmethod
    def build(
        cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
    ) -> "InvocationCompleteEvent":
        """Create the event for a finished `invocation`, carrying its output `result`."""
        # Map the prepared (executed) node id back to the node id in the source graph.
        source_node_id = queue_item.session.prepared_source_mapping[invocation.id]
        return cls(
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            batch_id=queue_item.batch_id,
            session_id=queue_item.session_id,
            invocation=invocation,
            invocation_source_id=source_node_id,
            result=result,
        )
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class InvocationErrorEvent(InvocationEventBase):
    """Event model for invocation_error"""

    __event_name__ = "invocation_error"

    error_type: str = Field(description="The error type")
    error_message: str = Field(description="The error message")
    error_traceback: str = Field(description="The error traceback")
    user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
    # The description was copy-pasted from `user_id`; it now describes the project.
    project_id: Optional[str] = Field(default=None, description="The ID of the project associated with the invocation")

    @classmethod
    def build(
        cls,
        queue_item: SessionQueueItem,
        invocation: AnyInvocation,
        error_type: str,
        error_message: str,
        error_traceback: str,
    ) -> "InvocationErrorEvent":
        """Create the event for an invocation that raised an error.

        `user_id` and `project_id` are copied from `queue_item` when it has those
        attributes, and default to None otherwise.
        """
        return cls(
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            batch_id=queue_item.batch_id,
            session_id=queue_item.session_id,
            invocation=invocation,
            invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
            error_type=error_type,
            error_message=error_message,
            error_traceback=error_traceback,
            # `getattr` with a default: the queue item may not define these attributes.
            user_id=getattr(queue_item, "user_id", None),
            project_id=getattr(queue_item, "project_id", None),
        )
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class QueueItemStatusChangedEvent(QueueItemEventBase):
    """Event model for queue_item_status_changed"""

    __event_name__ = "queue_item_status_changed"

    status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
    error_type: Optional[str] = Field(default=None, description="The error type, if any")
    error_message: Optional[str] = Field(default=None, description="The error message, if any")
    error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
    created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
    updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
    started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
    completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
    batch_status: BatchStatus = Field(description="The status of the batch")
    queue_status: SessionQueueStatus = Field(description="The status of the queue")
    session_id: str = Field(description="The ID of the session (aka graph execution state)")

    @classmethod
    def build(
        cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
    ) -> "QueueItemStatusChangedEvent":
        """Snapshot `queue_item`'s current status along with the batch and queue status."""

        def as_text(value: Any) -> Optional[str]:
            # Falsy timestamps (e.g. None) become None; anything else is stringified.
            return str(value) if value else None

        return cls(
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            batch_id=queue_item.batch_id,
            session_id=queue_item.session_id,
            status=queue_item.status,
            error_type=queue_item.error_type,
            error_message=queue_item.error_message,
            error_traceback=queue_item.error_traceback,
            created_at=as_text(queue_item.created_at),
            updated_at=as_text(queue_item.updated_at),
            started_at=as_text(queue_item.started_at),
            completed_at=as_text(queue_item.completed_at),
            batch_status=batch_status,
            queue_status=queue_status,
        )
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class BatchEnqueuedEvent(QueueEventBase):
    """Event model for batch_enqueued"""

    __event_name__ = "batch_enqueued"

    batch_id: str = Field(description="The ID of the batch")
    enqueued: int = Field(description="The number of invocations enqueued")
    # The old description inverted the relation: when the queue is full, fewer items are
    # enqueued than were requested, so `requested` may be MORE than `enqueued`.
    requested: int = Field(
        description="The number of invocations initially requested to be enqueued (may be more than enqueued if queue was full)"
    )
    priority: int = Field(description="The priority of the batch")

    @classmethod
    def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
        """Create the event from the result of enqueueing a batch."""
        return cls(
            queue_id=enqueue_result.queue_id,
            batch_id=enqueue_result.batch.batch_id,
            enqueued=enqueue_result.enqueued,
            requested=enqueue_result.requested,
            priority=enqueue_result.priority,
        )
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class QueueClearedEvent(QueueEventBase):
    """Event model for queue_cleared"""

    __event_name__ = "queue_cleared"

    @classmethod
    def build(cls, queue_id: str) -> "QueueClearedEvent":
        """Create the event for the queue identified by `queue_id`."""
        event = cls(queue_id=queue_id)
        return event
|
|
|
|
|
|
|
|
|
2024-03-10 12:23:11 +00:00
|
|
|
class DownloadEventBase(EventBase):
    """Base class for events associated with a download"""

    # Where the download comes from, stored as a string.
    source: str = Field(description="The source of the download")
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class DownloadStartedEvent(DownloadEventBase):
    """Event model for download_started"""

    __event_name__ = "download_started"

    download_path: str = Field(description="The local path where the download is saved")

    @classmethod
    def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
        """Create the event for `job`, whose local download path must already be set."""
        target = job.download_path
        # Guard against an unset path (skipped when Python runs with -O).
        assert target
        return cls(source=str(job.source), download_path=target.as_posix())
|
2024-03-14 08:04:19 +00:00
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class DownloadProgressEvent(DownloadEventBase):
    """Event model for download_progress"""

    __event_name__ = "download_progress"

    download_path: str = Field(description="The local path where the download is saved")
    current_bytes: int = Field(description="The number of bytes downloaded so far")
    total_bytes: int = Field(description="The total number of bytes to be downloaded")

    @classmethod
    def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
        """Create a progress snapshot (bytes so far vs. total) for `job`."""
        target = job.download_path
        # Guard against an unset path (skipped when Python runs with -O).
        assert target
        return cls(
            source=str(job.source),
            download_path=target.as_posix(),
            current_bytes=job.bytes,
            total_bytes=job.total_bytes,
        )
|
2024-03-14 08:04:19 +00:00
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class DownloadCompleteEvent(DownloadEventBase):
    """Event model for download_complete"""

    __event_name__ = "download_complete"

    download_path: str = Field(description="The local path where the download is saved")
    total_bytes: int = Field(description="The total number of bytes downloaded")

    @classmethod
    def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
        """Create the completion event for `job`, reporting where it was saved and its size."""
        target = job.download_path
        # Guard against an unset path (skipped when Python runs with -O).
        assert target
        return cls(
            source=str(job.source),
            download_path=target.as_posix(),
            total_bytes=job.total_bytes,
        )
|
2024-03-14 08:04:19 +00:00
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class DownloadCancelledEvent(DownloadEventBase):
    """Event model for download_cancelled"""

    __event_name__ = "download_cancelled"

    @classmethod
    def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
        """Create the cancellation event for `job`; only its source is reported."""
        source_text = str(job.source)
        return cls(source=source_text)
|
2024-03-14 08:04:19 +00:00
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class DownloadErrorEvent(DownloadEventBase):
    """Event model for download_error"""

    __event_name__ = "download_error"

    error_type: str = Field(description="The type of error")
    error: str = Field(description="The error message")

    @classmethod
    def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
        """Create the error event for a failed `job`.

        Both `job.error_type` and `job.error` must be truthy (checked with `assert`,
        which is skipped when Python runs with -O).
        """
        kind, message = job.error_type, job.error
        assert kind
        assert message
        return cls(source=str(job.source), error_type=kind, error=message)
|
2024-03-14 08:04:19 +00:00
|
|
|
|
|
|
|
|
2024-03-10 12:23:11 +00:00
|
|
|
class ModelEventBase(EventBase):
    """Base class for events associated with a model"""

    # Declares no fields of its own; it only groups the model load/install events.
    # It has no `__event_name__`, so `EventBase.get_events` walks through it without
    # listing it as an event.
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class ModelLoadStartedEvent(ModelEventBase):
    """Event model for model_load_started"""

    __event_name__ = "model_load_started"

    config: AnyModelConfig = Field(description="The model's config")
    submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")

    @classmethod
    def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
        """Create the event for a model (or one of its submodels) beginning to load."""
        # pydantic validates fields in declaration order, so keyword order is irrelevant.
        return cls(submodel_type=submodel_type, config=config)
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class ModelLoadCompleteEvent(ModelEventBase):
    """Event model for model_load_complete"""

    __event_name__ = "model_load_complete"

    config: AnyModelConfig = Field(description="The model's config")
    submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")

    @classmethod
    def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
        """Create the event for a model (or one of its submodels) that finished loading."""
        # pydantic validates fields in declaration order, so keyword order is irrelevant.
        return cls(submodel_type=submodel_type, config=config)
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase):
    """Event model for model_install_download_progress"""

    __event_name__ = "model_install_download_progress"

    id: int = Field(description="The ID of the install job")
    source: str = Field(description="Source of the model; local path, repo_id or url")
    local_path: str = Field(description="Where model is downloading to")
    bytes: int = Field(description="Number of bytes downloaded so far")
    total_bytes: int = Field(description="Total size of download, including all files")
    parts: list[dict[str, int | str]] = Field(
        description="Progress of downloading URLs that comprise the model, if any"
    )

    @classmethod
    def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
        """Create an aggregate progress snapshot for `job`, plus one entry per download part."""

        def describe_part(part: Any) -> dict[str, str | int]:
            # URLs and paths are stringified for the JSON payload.
            # NOTE(review): `str(part.download_path)` yields "None" if a part has no path yet.
            return {
                "url": str(part.source),
                "local_path": str(part.download_path),
                "bytes": part.bytes,
                "total_bytes": part.total_bytes,
            }

        part_summaries = [describe_part(part) for part in job.download_parts]
        return cls(
            id=job.id,
            source=str(job.source),
            local_path=job.local_path.as_posix(),
            parts=part_summaries,
            bytes=job.bytes,
            total_bytes=job.total_bytes,
        )
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
    """Event model for model_install_downloads_complete"""

    # The previous docstring ("Emitted once when an install job becomes active.") described
    # `model_install_started`, not this event. pydantic also publishes the class docstring
    # as the JSON-schema description, so the wrong text leaked into the API schema.
    __event_name__ = "model_install_downloads_complete"

    id: int = Field(description="The ID of the install job")
    source: str = Field(description="Source of the model; local path, repo_id or url")

    @classmethod
    def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
        """Create the event for `job` once its downloads have completed."""
        return cls(id=job.id, source=str(job.source))
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class ModelInstallStartedEvent(ModelEventBase):
    """Event model for model_install_started"""

    __event_name__ = "model_install_started"

    id: int = Field(description="The ID of the install job")
    source: str = Field(description="Source of the model; local path, repo_id or url")

    @classmethod
    def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
        """Create the event for an install job that has started."""
        job_id, source_text = job.id, str(job.source)
        return cls(id=job_id, source=source_text)
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class ModelInstallCompleteEvent(ModelEventBase):
    """Event model for model_install_complete"""

    __event_name__ = "model_install_complete"

    id: int = Field(description="The ID of the install job")
    source: str = Field(description="Source of the model; local path, repo_id or url")
    key: str = Field(description="Model config record key")
    # No default: the value must always be passed, but it may be None.
    total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")

    @classmethod
    def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
        """Create the event for a finished install job.

        `job.config_out` must be set (checked with `assert`, which is skipped under -O).
        """
        installed_config = job.config_out
        assert installed_config is not None
        return cls(
            id=job.id,
            source=str(job.source),
            key=installed_config.key,
            total_bytes=job.total_bytes,
        )
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class ModelInstallCancelledEvent(ModelEventBase):
    """Event model for model_install_cancelled"""

    __event_name__ = "model_install_cancelled"

    id: int = Field(description="The ID of the install job")
    source: str = Field(description="Source of the model; local path, repo_id or url")

    @classmethod
    def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
        """Create the event for an install job that was cancelled."""
        job_id, source_text = job.id, str(job.source)
        return cls(id=job_id, source=source_text)
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class ModelInstallErrorEvent(ModelEventBase):
    """Event model for model_install_error"""

    __event_name__ = "model_install_error"

    id: int = Field(description="The ID of the install job")
    source: str = Field(description="Source of the model; local path, repo_id or url")
    error_type: str = Field(description="The name of the exception")
    error: str = Field(description="A text description of the exception")

    @classmethod
    def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
        """Create the error event for a failed install job.

        `job.error_type` and `job.error` must not be None (checked with `assert`, which
        is skipped under -O). Empty strings are accepted.
        """
        kind, message = job.error_type, job.error
        assert kind is not None
        assert message is not None
        return cls(id=job.id, source=str(job.source), error_type=kind, error=message)
|
|
|
|
|
|
|
|
|
2024-03-10 12:23:11 +00:00
|
|
|
class BulkDownloadEventBase(EventBase):
    """Base class for events associated with a bulk image download"""

    # Identifiers shared by every bulk-download event. The started/complete subclasses add
    # no fields; the error subclass adds only the error message.
    bulk_download_id: str = Field(description="The ID of the bulk image download")
    bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
    bulk_download_item_name: str = Field(description="The name of the bulk image download item")
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class BulkDownloadStartedEvent(BulkDownloadEventBase):
    """Event model for bulk_download_started"""

    __event_name__ = "bulk_download_started"

    @classmethod
    def build(
        cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
    ) -> "BulkDownloadStartedEvent":
        """Create the event for a bulk download that has started."""
        # pydantic validates in field-declaration order, so keyword order here is irrelevant.
        return cls(
            bulk_download_item_name=bulk_download_item_name,
            bulk_download_item_id=bulk_download_item_id,
            bulk_download_id=bulk_download_id,
        )
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
    """Event model for bulk_download_complete"""

    __event_name__ = "bulk_download_complete"

    @classmethod
    def build(
        cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
    ) -> "BulkDownloadCompleteEvent":
        """Create the event for a bulk download that has finished."""
        # pydantic validates in field-declaration order, so keyword order here is irrelevant.
        return cls(
            bulk_download_item_name=bulk_download_item_name,
            bulk_download_item_id=bulk_download_item_id,
            bulk_download_id=bulk_download_id,
        )
|
|
|
|
|
|
|
|
|
2024-05-27 00:59:36 +00:00
|
|
|
@payload_schema.register
class BulkDownloadErrorEvent(BulkDownloadEventBase):
    """Event model for bulk_download_error"""

    __event_name__ = "bulk_download_error"

    error: str = Field(description="The error message")

    @classmethod
    def build(
        cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
    ) -> "BulkDownloadErrorEvent":
        """Create the event for a bulk download that failed with `error`."""
        # pydantic validates in field-declaration order, so keyword order here is irrelevant.
        return cls(
            error=error,
            bulk_download_item_name=bulk_download_item_name,
            bulk_download_item_id=bulk_download_item_id,
            bulk_download_id=bulk_download_id,
        )
|