From 5f94340e4f2d8a9c4ab2335030fd02bd163e8a6d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 4 Aug 2024 18:47:45 +1000 Subject: [PATCH] feat(app): merge progress events into one - Merged `InvocationGenericProgressEvent` and `InvocationDenoiseProgressEvent` into single `InvocationProgressEvent` - Simplified API - message is required, percentage and image are optional, no steps/total steps - Added helper to build a `ProgressImage` - Added field validation to `ProgressImage` width and height - Added `ProgressImage` to `invocation_api.py` - Updated `InvocationContext` utils --- invokeai/app/api/sockets.py | 6 +- .../invocations/spandrel_image_to_image.py | 52 +++++---- invokeai/app/services/events/events_base.py | 40 ++----- invokeai/app/services/events/events_common.py | 105 +++--------------- .../session_processor_common.py | 17 ++- .../app/services/shared/invocation_context.py | 72 ++++++------ invokeai/app/util/step_callback.py | 37 +++--- invokeai/invocation_api/__init__.py | 2 + 8 files changed, 137 insertions(+), 194 deletions(-) diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index da009c3940..188f958c88 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -20,9 +20,8 @@ from invokeai.app.services.events.events_common import ( DownloadStartedEvent, FastAPIEvent, InvocationCompleteEvent, - InvocationDenoiseProgressEvent, InvocationErrorEvent, - InvocationGenericProgressEvent, + InvocationProgressEvent, InvocationStartedEvent, ModelEventBase, ModelInstallCancelledEvent, @@ -56,8 +55,7 @@ class BulkDownloadSubscriptionEvent(BaseModel): QUEUE_EVENTS = { InvocationStartedEvent, - InvocationDenoiseProgressEvent, - InvocationGenericProgressEvent, + InvocationProgressEvent, InvocationCompleteEvent, InvocationErrorEvent, QueueItemStatusChangedEvent, diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py index 74b29fbc8c..2da8694ede 100644 --- a/invokeai/app/invocations/spandrel_image_to_image.py +++ b/invokeai/app/invocations/spandrel_image_to_image.py @@ -1,3 +1,4 @@ +import functools from typing import Callable import numpy as np @@ -150,19 +151,6 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): return pil_image - def _get_step_callback(self, context: InvocationContext) -> Callable[[int, int], None]: - invocation_type = self.get_type() - - def step_callback(step: int, total_steps: int) -> None: - context.util.signal_progress( - name=invocation_type, - step=step, - total_steps=total_steps, - message="Processing image", - ) - - return step_callback - @torch.inference_mode() def invoke(self, context: InvocationContext) -> ImageOutput: # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to @@ -172,13 +160,19 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): # Load the model. spandrel_model_info = context.models.load(self.image_to_image_model) + def step_callback(step: int, total_steps: int) -> None: + context.util.signal_progress( + message=f"Processing image (tile {step}/{total_steps})", + percentage=step / total_steps, + ) + # Do the upscaling. with spandrel_model_info as spandrel_model: assert isinstance(spandrel_model, SpandrelImageToImageModel) # Upscale the image pil_image = self.upscale_image( - image, self.tile_size, spandrel_model, context.util.is_canceled, self._get_step_callback(context) + image, self.tile_size, spandrel_model, context.util.is_canceled, step_callback ) image_dto = context.images.save(image=pil_image) @@ -220,13 +214,26 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation): target_width = int(image.width * self.scale) target_height = int(image.height * self.scale) + def step_callback(iteration: int, step: int, total_steps: int) -> None: + context.util.signal_progress( + message=self._get_progress_message(iteration, step, total_steps), + percentage=step / total_steps, + ) + # Do the upscaling. with spandrel_model_info as spandrel_model: assert isinstance(spandrel_model, SpandrelImageToImageModel) + iteration = 1 + context.util.signal_progress(self._get_progress_message(iteration)) + # First pass of upscaling. Note: `pil_image` will be mutated. pil_image = self.upscale_image( - image, self.tile_size, spandrel_model, context.util.is_canceled, self._get_step_callback(context) + image, + self.tile_size, + spandrel_model, + context.util.is_canceled, + functools.partial(step_callback, iteration), ) # Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model @@ -236,22 +243,22 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation): if is_upscale_model: # This is an upscale model, so we should keep upscaling until we reach the target size. - iterations = 1 while pil_image.width < target_width or pil_image.height < target_height: + iteration += 1 + context.util.signal_progress(self._get_progress_message(iteration)) pil_image = self.upscale_image( pil_image, self.tile_size, spandrel_model, context.util.is_canceled, - self._get_step_callback(context), + functools.partial(step_callback, iteration), ) - iterations += 1 # Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x. # Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations. # We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice, # we should never reach this limit. - if iterations >= 5: + if iteration >= 5: context.logger.warning( "Upscale loop reached maximum iteration count of 5, stopping upscaling early." ) @@ -282,3 +289,10 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation): image_dto = context.images.save(image=pil_image) return ImageOutput.build(image_dto) + + @classmethod + def _get_progress_message(cls, iteration: int, step: int | None = None, total_steps: int | None = None) -> str: + if step is not None and total_steps is not None: + return f"Processing image (iteration {iteration}, tile {step}/{total_steps})" + + return f"Processing image (iteration {iteration})" diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 5c1c8d2f93..681386e877 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -3,8 +3,6 @@ from typing import TYPE_CHECKING, Optional -from PIL.Image import Image as PILImageType - from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, BulkDownloadCompleteEvent, @@ -17,9 +15,8 @@ from invokeai.app.services.events.events_common import ( DownloadStartedEvent, EventBase, InvocationCompleteEvent, - InvocationDenoiseProgressEvent, InvocationErrorEvent, - InvocationGenericProgressEvent, + InvocationProgressEvent, InvocationStartedEvent, ModelInstallCancelledEvent, ModelInstallCompleteEvent, @@ -33,13 +30,12 @@ from invokeai.app.services.events.events_common import ( QueueClearedEvent, QueueItemStatusChangedEvent, ) -from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.app.services.session_processor.session_processor_common import ProgressImage if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.download.download_base import DownloadJob from invokeai.app.services.model_install.model_install_common import ModelInstallJob - from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( BatchStatus, EnqueueBatchResult, @@ -61,38 +57,16 @@ class EventServiceBase: """Emitted when an invocation is started""" self.dispatch(InvocationStartedEvent.build(queue_item, invocation)) - def emit_invocation_generic_progress( + def emit_invocation_progress( self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", - name: str, - step: int | None = None, - total_steps: int | None = None, - message: str | None = None, - image: PILImageType | None = None, + message: str, + percentage: float | None = None, + image: ProgressImage | None = None, ) -> None: """Emitted at each step during an invocation""" - self.dispatch( - InvocationGenericProgressEvent.build( - queue_item, - invocation, - name, - step, - total_steps, - message, - image, - ) - ) - - def emit_invocation_denoise_progress( - self, - queue_item: "SessionQueueItem", - invocation: "BaseInvocation", - intermediate_state: PipelineIntermediateState, - progress_image: "ProgressImage", - ) -> None: - """Emitted at each step during denoising of an invocation.""" - self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image)) + self.dispatch(InvocationProgressEvent.build(queue_item, invocation, message, percentage, image)) def emit_invocation_complete( self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput" diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 41a2b4566c..ad84773d9c 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -1,10 +1,8 @@ -from math import floor from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar from fastapi_events.handlers.local import local_handler from fastapi_events.registry.payload_schema import registry as payload_schema -from PIL.Image import Image as PILImageType -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( @@ -17,8 +15,6 @@ from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput from invokeai.app.util.misc import get_timestamp from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType -from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState -from invokeai.backend.util.util import image_to_dataURL if TYPE_CHECKING: from invokeai.app.services.download.download_base import DownloadJob @@ -123,49 +119,28 @@ class InvocationStartedEvent(InvocationEventBase): @payload_schema.register -class InvocationGenericProgressEvent(InvocationEventBase): - """Event model for invocation_generic_progress""" +class InvocationProgressEvent(InvocationEventBase): + """Event model for invocation_progress""" - __event_name__ = "invocation_generic_progress" + __event_name__ = "invocation_progress" - name: str = Field(description="The name of the progress type") - step: int | None = Field( - default=None, - description="The current step. Omit for indeterminate progress.", + message: str = Field(description="A message to display") + percentage: float | None = Field( + default=None, ge=0, le=1, description="The percentage of the progress (omit to indicate indeterminate progress)" ) - total_steps: int | None = Field( - default=None, - description="The total number of steps. Omit for indeterminate progress.", + image: ProgressImage | None = Field( + default=None, description="An image representing the current state of the progress" ) - image: ProgressImage | None = Field(default=None, description="An image sent at each step during processing") - message: str | None = Field(default=None, description="A message to display with the progress") - - @model_validator(mode="after") - def validate_step_total_steps(self): - if (self.step is None) is not (self.total_steps is None): - raise ValueError("must provide both step and total_steps or neither") - return self @classmethod def build( cls, queue_item: SessionQueueItem, invocation: AnyInvocation, - name: str, - step: int | None = None, - total_steps: int | None = None, - message: str | None = None, - image: PILImageType | None = None, - ) -> "InvocationGenericProgressEvent": - image_ = ( - ProgressImage( - dataURL=image_to_dataURL(image, image_format="JPEG"), - width=image.width, - height=image.height, - ) - if image - else None - ) + message: str, + percentage: float | None = None, + image: ProgressImage | None = None, + ) -> "InvocationProgressEvent": return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, @@ -173,62 +148,12 @@ class InvocationGenericProgressEvent(InvocationEventBase): session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], - name=name, - step=step, - total_steps=total_steps, - image=image_, + percentage=percentage, + image=image, message=message, ) -@payload_schema.register -class InvocationDenoiseProgressEvent(InvocationEventBase): - """Event model for invocation_denoise_progress""" - - __event_name__ = "invocation_denoise_progress" - - progress_image: ProgressImage = Field(description="The progress image sent at each step during processing") - step: int = Field(description="The current step of the invocation") - total_steps: int = Field(description="The total number of steps in the invocation") - order: int = Field(description="The order of the invocation in the session") - percentage: float = Field(description="The percentage of completion of the invocation") - - @classmethod - def build( - cls, - queue_item: SessionQueueItem, - invocation: AnyInvocation, - intermediate_state: PipelineIntermediateState, - progress_image: ProgressImage, - ) -> "InvocationDenoiseProgressEvent": - step = intermediate_state.step - total_steps = intermediate_state.total_steps - order = intermediate_state.order - return cls( - queue_id=queue_item.queue_id, - item_id=queue_item.item_id, - batch_id=queue_item.batch_id, - session_id=queue_item.session_id, - invocation=invocation, - invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], - progress_image=progress_image, - step=step, - total_steps=total_steps, - order=order, - percentage=cls.calc_percentage(step, total_steps, order), - ) - - @staticmethod - def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float: - """Calculate the percentage of completion of denoising.""" - if total_steps == 0: - return 0.0 - if scheduler_order == 2: - return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2) - # order == 1 - return (step + 1 + 1) / (total_steps + 1) - - @payload_schema.register class InvocationCompleteEvent(InvocationEventBase): """Event model for invocation_complete""" diff --git a/invokeai/app/services/session_processor/session_processor_common.py b/invokeai/app/services/session_processor/session_processor_common.py index 0ca51de517..346f12d8bb 100644 --- a/invokeai/app/services/session_processor/session_processor_common.py +++ b/invokeai/app/services/session_processor/session_processor_common.py @@ -1,5 +1,8 @@ +from PIL.Image import Image as PILImageType from pydantic import BaseModel, Field +from invokeai.backend.util.util import image_to_dataURL + class SessionProcessorStatus(BaseModel): is_started: bool = Field(description="Whether the session processor is started") @@ -15,6 +18,16 @@ class CanceledException(Exception): class ProgressImage(BaseModel): """The progress image sent intermittently during processing""" - width: int = Field(description="The effective width of the image in pixels") - height: int = Field(description="The effective height of the image in pixels") + width: int = Field(ge=1, description="The effective width of the image in pixels") + height: int = Field(ge=1, description="The effective height of the image in pixels") dataURL: str = Field(description="The image data as a b64 data URL") + + @classmethod + def build(cls, image: PILImageType, size: tuple[int, int] | None = None) -> "ProgressImage": + """Build a ProgressImage from a PIL image""" + + return cls( + width=size[0] if size else image.width, + height=size[1] if size else image.height, + dataURL=image_to_dataURL(image, image_format="JPEG"), + ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 5c0ba97a05..2e5b137c5f 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -14,6 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.model_records.model_records_base import UnknownModelException +from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_manager.config import ( AnyModel, @@ -550,54 +551,61 @@ class UtilInterface(InvocationContextInterface): """ stable_diffusion_step_callback( - context_data=self._data, + signal_progress=self.signal_progress, intermediate_state=intermediate_state, base_model=base_model, - events=self._services.events, is_canceled=self.is_canceled, ) def signal_progress( - self, - name: str, - step: int | None = None, - total_steps: int | None = None, - message: str | None = None, - image: Image | None = None, + self, message: str, percentage: float | None = None, image: ProgressImage | None = None ) -> None: - """Signals the progress of some long-running invocation process. The progress is displayed in the UI. + """Signals the progress of some long-running invocation. The progress is displayed in the UI. - Each progress event is grouped by both the given `name` and the invocation's ID. Once the invocation completes, - future progress events with the same name will be grouped separately. + If you have an image to display, use `ProgressImage.build` to create the object. - For progress that has a known number of steps, provide both `step` and `total_steps`. For indeterminate - progress, omit both `step` and `total_steps`. An error will be raised if only one of `step` and `total_steps` - is provided. + If your progress image should be displayed at a different size, provide a tuple of `(width, height)` when + building the progress image. - For the best user experience: - - Signal process once with `step=0, total_steps=total_steps` before processing begins. - - Signal process after each step completes with `step=current_step, total_steps=total_steps`. - - Signal process once with `step=total_steps, total_steps=total_steps` after processing completes, if this - wasn't already done. - - If the process is indeterminate, signal progress with `step=None, total_steps=None` at regular intervals. + For example, SD denoising progress images are 1/8 the size of the original image. In this case, the progress + image should be built like this to ensure it displays at the correct size: + ```py + progress_image = ProgressImage.build(image, (width * 8, height * 8)) + ``` + + If your progress image is very large, consider downscaling it to reduce the payload size. + + Example: + ```py + total_steps = 10 + for i in range(total_steps): + # Do some iterative progressing + image = do_iterative_processing(image) + + # Calculate the percentage + step = i + 1 + percentage = step / total_steps + + # Create a short, friendly message + message = f"Processing (step {step}/{total_steps})" + + # Build the progress image + progress_image = ProgressImage.build(image) + + # Send progress to the UI + context.util.signal_progress(message, percentage, progress_image) + ``` Args: - name: The name of the action. This is used to group progress events together. - step: The current step of the action. Omit for indeterminate progress. - total_steps: The total number of steps of the action. Omit for indeterminate progress. - message: An optional message to display. If omitted, no message will be displayed. - image: An optional image to display. If omitted, no image will be displayed. - - Raises: - pydantic.ValidationError: If only one of `step` and `total_steps` is provided. + message: A message describing the current status. + percentage: The current percentage completion for the process. Omit for indeterminate progress. + image: An optional progress image to display. """ - self._services.events.emit_invocation_generic_progress( + self._services.events.emit_invocation_progress( queue_item=self._data.queue_item, invocation=self._data.invocation, - name=name, - step=step, - total_steps=total_steps, message=message, + percentage=percentage, image=image, ) diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index c0c101cd75..a5056c9617 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Callable, Optional +from math import floor +from typing import Callable, Optional import torch from PIL import Image @@ -6,11 +7,6 @@ from PIL import Image from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage from invokeai.backend.model_manager.config import BaseModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState -from invokeai.backend.util.util import image_to_dataURL - -if TYPE_CHECKING: - from invokeai.app.services.events.events_base import EventServiceBase - from invokeai.app.services.shared.invocation_context import InvocationContextData # fast latents preview matrix for sdxl # generated by @StAlKeR7779 @@ -56,11 +52,25 @@ def sample_to_lowres_estimated_image( return Image.fromarray(latents_ubyte.numpy()) +def calc_percentage(intermediate_state: PipelineIntermediateState) -> float: + """Calculate the percentage of completion of denoising.""" + + step = intermediate_state.step + total_steps = intermediate_state.total_steps + order = intermediate_state.order + + if total_steps == 0: + return 0.0 + if order == 2: + return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2) + # order == 1 + return (step + 1 + 1) / (total_steps + 1) + + def stable_diffusion_step_callback( - context_data: "InvocationContextData", + signal_progress: Callable[[str, float | None, ProgressImage | None], None], intermediate_state: PipelineIntermediateState, base_model: BaseModelType, - events: "EventServiceBase", is_canceled: Callable[[], bool], ) -> None: if is_canceled(): @@ -86,11 +96,10 @@ def stable_diffusion_step_callback( width *= 8 height *= 8 - dataURL = image_to_dataURL(image, image_format="JPEG") + percentage = calc_percentage(intermediate_state) - events.emit_invocation_denoise_progress( - context_data.queue_item, - context_data.invocation, - intermediate_state, - ProgressImage(dataURL=dataURL, width=width, height=height), + signal_progress( + "Denoising", + percentage, + ProgressImage.build(image=image, size=(width, height)), ) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 586f85b9c2..267c83bb9a 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -66,6 +66,7 @@ from invokeai.app.invocations.scheduler import SchedulerOutput from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory +from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.misc import SEED_MAX, get_random_seed @@ -176,4 +177,5 @@ __all__ = [ # invokeai.app.util.misc "SEED_MAX", "get_random_seed", + "ProgressImage", ]