From 487815b1816308925a6eb20c7d234d8548c6ac96 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 3 Aug 2024 22:01:36 +1000 Subject: [PATCH] feat(app): generic progress events Some processes have steps, like denoising or a tiled spandel. Denoising has its own step callback but we don't have any generic way to signal progress. Processes like a tiled spandrel run show indeterminate progress in the client. This change introduces a new event to handle this: `InvocationGenericProgressEvent` A simplified helper is added to the invocation API so nodes can easily emit progress as they do their thing. --- invokeai/app/api/sockets.py | 2 + invokeai/app/services/events/events_base.py | 26 ++++++++ invokeai/app/services/events/events_common.py | 63 ++++++++++++++++++- .../app/services/shared/invocation_context.py | 44 +++++++++++++ 4 files changed, 134 insertions(+), 1 deletion(-) diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index b39922c69b..da009c3940 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -22,6 +22,7 @@ from invokeai.app.services.events.events_common import ( InvocationCompleteEvent, InvocationDenoiseProgressEvent, InvocationErrorEvent, + InvocationGenericProgressEvent, InvocationStartedEvent, ModelEventBase, ModelInstallCancelledEvent, @@ -56,6 +57,7 @@ class BulkDownloadSubscriptionEvent(BaseModel): QUEUE_EVENTS = { InvocationStartedEvent, InvocationDenoiseProgressEvent, + InvocationGenericProgressEvent, InvocationCompleteEvent, InvocationErrorEvent, QueueItemStatusChangedEvent, diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index bb578c23e8..5c1c8d2f93 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING, Optional +from PIL.Image import Image as PILImageType + from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, BulkDownloadCompleteEvent, @@ -17,6 +19,7 @@ from invokeai.app.services.events.events_common import ( InvocationCompleteEvent, InvocationDenoiseProgressEvent, InvocationErrorEvent, + InvocationGenericProgressEvent, InvocationStartedEvent, ModelInstallCancelledEvent, ModelInstallCompleteEvent, @@ -58,6 +61,29 @@ class EventServiceBase: """Emitted when an invocation is started""" self.dispatch(InvocationStartedEvent.build(queue_item, invocation)) + def emit_invocation_generic_progress( + self, + queue_item: "SessionQueueItem", + invocation: "BaseInvocation", + name: str, + step: int | None = None, + total_steps: int | None = None, + message: str | None = None, + image: PILImageType | None = None, + ) -> None: + """Emitted at each step during an invocation""" + self.dispatch( + InvocationGenericProgressEvent.build( + queue_item, + invocation, + name, + step, + total_steps, + message, + image, + ) + ) + def emit_invocation_denoise_progress( self, queue_item: "SessionQueueItem", diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index c6a867fb08..41a2b4566c 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P from fastapi_events.handlers.local import local_handler from fastapi_events.registry.payload_schema import registry as payload_schema -from pydantic import BaseModel, ConfigDict, Field +from PIL.Image import Image as PILImageType +from pydantic import BaseModel, ConfigDict, Field, model_validator from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( @@ -17,6 +18,7 @@ from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutpu from invokeai.app.util.misc import get_timestamp from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.util.util import image_to_dataURL if TYPE_CHECKING: from invokeai.app.services.download.download_base import DownloadJob @@ -120,6 +122,65 @@ class InvocationStartedEvent(InvocationEventBase): ) +@payload_schema.register +class InvocationGenericProgressEvent(InvocationEventBase): + """Event model for invocation_generic_progress""" + + __event_name__ = "invocation_generic_progress" + + name: str = Field(description="The name of the progress type") + step: int | None = Field( + default=None, + description="The current step. Omit for indeterminate progress.", + ) + total_steps: int | None = Field( + default=None, + description="The total number of steps. Omit for indeterminate progress.", + ) + image: ProgressImage | None = Field(default=None, description="An image sent at each step during processing") + message: str | None = Field(default=None, description="A message to display with the progress") + + @model_validator(mode="after") + def validate_step_total_steps(self): + if (self.step is None) is not (self.total_steps is None): + raise ValueError("must provide both step and total_steps or neither") + return self + + @classmethod + def build( + cls, + queue_item: SessionQueueItem, + invocation: AnyInvocation, + name: str, + step: int | None = None, + total_steps: int | None = None, + message: str | None = None, + image: PILImageType | None = None, + ) -> "InvocationGenericProgressEvent": + image_ = ( + ProgressImage( + dataURL=image_to_dataURL(image, image_format="JPEG"), + width=image.width, + height=image.height, + ) + if image + else None + ) + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation=invocation, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + name=name, + step=step, + total_steps=total_steps, + image=image_, + message=message, + ) + + @payload_schema.register class InvocationDenoiseProgressEvent(InvocationEventBase): """Event model for invocation_denoise_progress""" diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 01662335e4..5c0ba97a05 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -557,6 +557,50 @@ class UtilInterface(InvocationContextInterface): is_canceled=self.is_canceled, ) + def signal_progress( + self, + name: str, + step: int | None = None, + total_steps: int | None = None, + message: str | None = None, + image: Image | None = None, + ) -> None: + """Signals the progress of some long-running invocation process. The progress is displayed in the UI. + + Each progress event is grouped by both the given `name` and the invocation's ID. Once the invocation completes, + future progress events with the same name will be grouped separately. + + For progress that has a known number of steps, provide both `step` and `total_steps`. For indeterminate + progress, omit both `step` and `total_steps`. An error will be raised if only one of `step` and `total_steps` + is provided. + + For the best user experience: + - Signal process once with `step=0, total_steps=total_steps` before processing begins. + - Signal process after each step completes with `step=current_step, total_steps=total_steps`. + - Signal process once with `step=total_steps, total_steps=total_steps` after processing completes, if this + wasn't already done. + - If the process is indeterminate, signal progress with `step=None, total_steps=None` at regular intervals. + + Args: + name: The name of the action. This is used to group progress events together. + step: The current step of the action. Omit for indeterminate progress. + total_steps: The total number of steps of the action. Omit for indeterminate progress. + message: An optional message to display. If omitted, no message will be displayed. + image: An optional image to display. If omitted, no image will be displayed. + + Raises: + pydantic.ValidationError: If only one of `step` and `total_steps` is provided. + """ + self._services.events.emit_invocation_generic_progress( + queue_item=self._data.queue_item, + invocation=self._data.invocation, + name=name, + step=step, + total_steps=total_steps, + message=message, + image=image, + ) + class InvocationContext: """Provides access to various services and data for the current invocation.