feat(app): merge progress events into one

- Merged `InvocationGenericProgressEvent` and `InvocationDenoiseProgressEvent` into single `InvocationProgressEvent`
- Simplified API - message is required, percentage and image are optional, no steps/total steps
- Added helper to build a `ProgressImage`
- Added field validation to `ProgressImage` width and height
- Added `ProgressImage` to `invocation_api.py`
- Updated `InvocationContext` utils
This commit is contained in:
psychedelicious
2024-08-04 18:47:45 +10:00
parent 682280683a
commit 5f94340e4f
8 changed files with 137 additions and 194 deletions

View File

@ -20,9 +20,8 @@ from invokeai.app.services.events.events_common import (
DownloadStartedEvent, DownloadStartedEvent,
FastAPIEvent, FastAPIEvent,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationGenericProgressEvent, InvocationProgressEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelEventBase, ModelEventBase,
ModelInstallCancelledEvent, ModelInstallCancelledEvent,
@ -56,8 +55,7 @@ class BulkDownloadSubscriptionEvent(BaseModel):
QUEUE_EVENTS = { QUEUE_EVENTS = {
InvocationStartedEvent, InvocationStartedEvent,
InvocationDenoiseProgressEvent, InvocationProgressEvent,
InvocationGenericProgressEvent,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationErrorEvent, InvocationErrorEvent,
QueueItemStatusChangedEvent, QueueItemStatusChangedEvent,

View File

@ -1,3 +1,4 @@
import functools
from typing import Callable from typing import Callable
import numpy as np import numpy as np
@ -150,19 +151,6 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return pil_image return pil_image
def _get_step_callback(self, context: InvocationContext) -> Callable[[int, int], None]:
invocation_type = self.get_type()
def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
name=invocation_type,
step=step,
total_steps=total_steps,
message="Processing image",
)
return step_callback
@torch.inference_mode() @torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
@ -172,13 +160,19 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# Load the model. # Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model) spandrel_model_info = context.models.load(self.image_to_image_model)
def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
message=f"Processing image (tile {step}/{total_steps})",
percentage=step / total_steps,
)
# Do the upscaling. # Do the upscaling.
with spandrel_model_info as spandrel_model: with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel) assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Upscale the image # Upscale the image
pil_image = self.upscale_image( pil_image = self.upscale_image(
image, self.tile_size, spandrel_model, context.util.is_canceled, self._get_step_callback(context) image, self.tile_size, spandrel_model, context.util.is_canceled, step_callback
) )
image_dto = context.images.save(image=pil_image) image_dto = context.images.save(image=pil_image)
@ -220,13 +214,26 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
target_width = int(image.width * self.scale) target_width = int(image.width * self.scale)
target_height = int(image.height * self.scale) target_height = int(image.height * self.scale)
def step_callback(iteration: int, step: int, total_steps: int) -> None:
context.util.signal_progress(
message=self._get_progress_message(iteration, step, total_steps),
percentage=step / total_steps,
)
# Do the upscaling. # Do the upscaling.
with spandrel_model_info as spandrel_model: with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel) assert isinstance(spandrel_model, SpandrelImageToImageModel)
iteration = 1
context.util.signal_progress(self._get_progress_message(iteration))
# First pass of upscaling. Note: `pil_image` will be mutated. # First pass of upscaling. Note: `pil_image` will be mutated.
pil_image = self.upscale_image( pil_image = self.upscale_image(
image, self.tile_size, spandrel_model, context.util.is_canceled, self._get_step_callback(context) image,
self.tile_size,
spandrel_model,
context.util.is_canceled,
functools.partial(step_callback, iteration),
) )
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model # Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
@ -236,22 +243,22 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
if is_upscale_model: if is_upscale_model:
# This is an upscale model, so we should keep upscaling until we reach the target size. # This is an upscale model, so we should keep upscaling until we reach the target size.
iterations = 1
while pil_image.width < target_width or pil_image.height < target_height: while pil_image.width < target_width or pil_image.height < target_height:
iteration += 1
context.util.signal_progress(self._get_progress_message(iteration))
pil_image = self.upscale_image( pil_image = self.upscale_image(
pil_image, pil_image,
self.tile_size, self.tile_size,
spandrel_model, spandrel_model,
context.util.is_canceled, context.util.is_canceled,
self._get_step_callback(context), functools.partial(step_callback, iteration),
) )
iterations += 1
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x. # Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations. # Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice, # We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
# we should never reach this limit. # we should never reach this limit.
if iterations >= 5: if iteration >= 5:
context.logger.warning( context.logger.warning(
"Upscale loop reached maximum iteration count of 5, stopping upscaling early." "Upscale loop reached maximum iteration count of 5, stopping upscaling early."
) )
@ -282,3 +289,10 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
image_dto = context.images.save(image=pil_image) image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
@classmethod
def _get_progress_message(cls, iteration: int, step: int | None = None, total_steps: int | None = None) -> str:
if step is not None and total_steps is not None:
return f"Processing image (iteration {iteration}, tile {step}/{total_steps})"
return f"Processing image (iteration {iteration})"

View File

@ -3,8 +3,6 @@
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from PIL.Image import Image as PILImageType
from invokeai.app.services.events.events_common import ( from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent, BatchEnqueuedEvent,
BulkDownloadCompleteEvent, BulkDownloadCompleteEvent,
@ -17,9 +15,8 @@ from invokeai.app.services.events.events_common import (
DownloadStartedEvent, DownloadStartedEvent,
EventBase, EventBase,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationGenericProgressEvent, InvocationProgressEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelInstallCancelledEvent, ModelInstallCancelledEvent,
ModelInstallCompleteEvent, ModelInstallCompleteEvent,
@ -33,13 +30,12 @@ from invokeai.app.services.events.events_common import (
QueueClearedEvent, QueueClearedEvent,
QueueItemStatusChangedEvent, QueueItemStatusChangedEvent,
) )
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.app.services.session_processor.session_processor_common import ProgressImage
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus, BatchStatus,
EnqueueBatchResult, EnqueueBatchResult,
@ -61,38 +57,16 @@ class EventServiceBase:
"""Emitted when an invocation is started""" """Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation)) self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
def emit_invocation_generic_progress( def emit_invocation_progress(
self, self,
queue_item: "SessionQueueItem", queue_item: "SessionQueueItem",
invocation: "BaseInvocation", invocation: "BaseInvocation",
name: str, message: str,
step: int | None = None, percentage: float | None = None,
total_steps: int | None = None, image: ProgressImage | None = None,
message: str | None = None,
image: PILImageType | None = None,
) -> None: ) -> None:
"""Emitted at each step during an invocation""" """Emitted at each step during an invocation"""
self.dispatch( self.dispatch(InvocationProgressEvent.build(queue_item, invocation, message, percentage, image))
InvocationGenericProgressEvent.build(
queue_item,
invocation,
name,
step,
total_steps,
message,
image,
)
)
def emit_invocation_denoise_progress(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState,
progress_image: "ProgressImage",
) -> None:
"""Emitted at each step during denoising of an invocation."""
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
def emit_invocation_complete( def emit_invocation_complete(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput" self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"

View File

@ -1,10 +1,8 @@
from math import floor
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema from fastapi_events.registry.payload_schema import registry as payload_schema
from PIL.Image import Image as PILImageType from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
@ -17,8 +15,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.util.util import image_to_dataURL
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob from invokeai.app.services.download.download_base import DownloadJob
@ -123,49 +119,28 @@ class InvocationStartedEvent(InvocationEventBase):
@payload_schema.register @payload_schema.register
class InvocationGenericProgressEvent(InvocationEventBase): class InvocationProgressEvent(InvocationEventBase):
"""Event model for invocation_generic_progress""" """Event model for invocation_progress"""
__event_name__ = "invocation_generic_progress" __event_name__ = "invocation_progress"
name: str = Field(description="The name of the progress type") message: str = Field(description="A message to display")
step: int | None = Field( percentage: float | None = Field(
default=None, default=None, ge=0, le=1, description="The percentage of the progress (omit to indicate indeterminate progress)"
description="The current step. Omit for indeterminate progress.",
) )
total_steps: int | None = Field( image: ProgressImage | None = Field(
default=None, default=None, description="An image representing the current state of the progress"
description="The total number of steps. Omit for indeterminate progress.",
) )
image: ProgressImage | None = Field(default=None, description="An image sent at each step during processing")
message: str | None = Field(default=None, description="A message to display with the progress")
@model_validator(mode="after")
def validate_step_total_steps(self):
if (self.step is None) is not (self.total_steps is None):
raise ValueError("must provide both step and total_steps or neither")
return self
@classmethod @classmethod
def build( def build(
cls, cls,
queue_item: SessionQueueItem, queue_item: SessionQueueItem,
invocation: AnyInvocation, invocation: AnyInvocation,
name: str, message: str,
step: int | None = None, percentage: float | None = None,
total_steps: int | None = None, image: ProgressImage | None = None,
message: str | None = None, ) -> "InvocationProgressEvent":
image: PILImageType | None = None,
) -> "InvocationGenericProgressEvent":
image_ = (
ProgressImage(
dataURL=image_to_dataURL(image, image_format="JPEG"),
width=image.width,
height=image.height,
)
if image
else None
)
return cls( return cls(
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
item_id=queue_item.item_id, item_id=queue_item.item_id,
@ -173,62 +148,12 @@ class InvocationGenericProgressEvent(InvocationEventBase):
session_id=queue_item.session_id, session_id=queue_item.session_id,
invocation=invocation, invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
name=name, percentage=percentage,
step=step, image=image,
total_steps=total_steps,
image=image_,
message=message, message=message,
) )
@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
__event_name__ = "invocation_denoise_progress"
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
order: int = Field(description="The order of the invocation in the session")
percentage: float = Field(description="The percentage of completion of the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
)
@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
@payload_schema.register @payload_schema.register
class InvocationCompleteEvent(InvocationEventBase): class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete""" """Event model for invocation_complete"""

View File

@ -1,5 +1,8 @@
from PIL.Image import Image as PILImageType
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.backend.util.util import image_to_dataURL
class SessionProcessorStatus(BaseModel): class SessionProcessorStatus(BaseModel):
is_started: bool = Field(description="Whether the session processor is started") is_started: bool = Field(description="Whether the session processor is started")
@ -15,6 +18,16 @@ class CanceledException(Exception):
class ProgressImage(BaseModel): class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing""" """The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels") width: int = Field(ge=1, description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels") height: int = Field(ge=1, description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL") dataURL: str = Field(description="The image data as a b64 data URL")
@classmethod
def build(cls, image: PILImageType, size: tuple[int, int] | None = None) -> "ProgressImage":
"""Build a ProgressImage from a PIL image"""
return cls(
width=size[0] if size else image.width,
height=size[1] if size else image.height,
dataURL=image_to_dataURL(image, image_format="JPEG"),
)

View File

@ -14,6 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModel, AnyModel,
@ -550,54 +551,61 @@ class UtilInterface(InvocationContextInterface):
""" """
stable_diffusion_step_callback( stable_diffusion_step_callback(
context_data=self._data, signal_progress=self.signal_progress,
intermediate_state=intermediate_state, intermediate_state=intermediate_state,
base_model=base_model, base_model=base_model,
events=self._services.events,
is_canceled=self.is_canceled, is_canceled=self.is_canceled,
) )
def signal_progress( def signal_progress(
self, self, message: str, percentage: float | None = None, image: ProgressImage | None = None
name: str,
step: int | None = None,
total_steps: int | None = None,
message: str | None = None,
image: Image | None = None,
) -> None: ) -> None:
"""Signals the progress of some long-running invocation process. The progress is displayed in the UI. """Signals the progress of some long-running invocation. The progress is displayed in the UI.
Each progress event is grouped by both the given `name` and the invocation's ID. Once the invocation completes, If you have an image to display, use `ProgressImage.build` to create the object.
future progress events with the same name will be grouped separately.
For progress that has a known number of steps, provide both `step` and `total_steps`. For indeterminate If your progress image should be displayed at a different size, provide a tuple of `(width, height)` when
progress, omit both `step` and `total_steps`. An error will be raised if only one of `step` and `total_steps` building the progress image.
is provided.
For the best user experience: For example, SD denoising progress images are 1/8 the size of the original image. In this case, the progress
- Signal process once with `step=0, total_steps=total_steps` before processing begins. image should be built like this to ensure it displays at the correct size:
- Signal process after each step completes with `step=current_step, total_steps=total_steps`. ```py
- Signal process once with `step=total_steps, total_steps=total_steps` after processing completes, if this progress_image = ProgressImage.build(image, (width * 8, height * 8))
wasn't already done. ```
- If the process is indeterminate, signal progress with `step=None, total_steps=None` at regular intervals.
If your progress image is very large, consider downscaling it to reduce the payload size.
Example:
```py
total_steps = 10
for i in range(total_steps):
# Do some iterative progressing
image = do_iterative_processing(image)
# Calculate the percentage
step = i + 1
percentage = step / total_steps
# Create a short, friendly message
message = f"Processing (step {step}/{total_steps})"
# Build the progress image
progress_image = ProgressImage.build(image)
# Send progress to the UI
context.util.signal_progress(message, percentage, progress_image)
```
Args: Args:
name: The name of the action. This is used to group progress events together. message: A message describing the current status.
step: The current step of the action. Omit for indeterminate progress. percentage: The current percentage completion for the process. Omit for indeterminate progress.
total_steps: The total number of steps of the action. Omit for indeterminate progress. image: An optional progress image to display.
message: An optional message to display. If omitted, no message will be displayed.
image: An optional image to display. If omitted, no image will be displayed.
Raises:
pydantic.ValidationError: If only one of `step` and `total_steps` is provided.
""" """
self._services.events.emit_invocation_generic_progress( self._services.events.emit_invocation_progress(
queue_item=self._data.queue_item, queue_item=self._data.queue_item,
invocation=self._data.invocation, invocation=self._data.invocation,
name=name,
step=step,
total_steps=total_steps,
message=message, message=message,
percentage=percentage,
image=image, image=image,
) )

View File

@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Callable, Optional from math import floor
from typing import Callable, Optional
import torch import torch
from PIL import Image from PIL import Image
@ -6,11 +7,6 @@ from PIL import Image
from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
from invokeai.backend.model_manager.config import BaseModelType from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.util.util import image_to_dataURL
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContextData
# fast latents preview matrix for sdxl # fast latents preview matrix for sdxl
# generated by @StAlKeR7779 # generated by @StAlKeR7779
@ -56,11 +52,25 @@ def sample_to_lowres_estimated_image(
return Image.fromarray(latents_ubyte.numpy()) return Image.fromarray(latents_ubyte.numpy())
def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
"""Calculate the percentage of completion of denoising."""
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
if total_steps == 0:
return 0.0
if order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
def stable_diffusion_step_callback( def stable_diffusion_step_callback(
context_data: "InvocationContextData", signal_progress: Callable[[str, float | None, ProgressImage | None], None],
intermediate_state: PipelineIntermediateState, intermediate_state: PipelineIntermediateState,
base_model: BaseModelType, base_model: BaseModelType,
events: "EventServiceBase",
is_canceled: Callable[[], bool], is_canceled: Callable[[], bool],
) -> None: ) -> None:
if is_canceled(): if is_canceled():
@ -86,11 +96,10 @@ def stable_diffusion_step_callback(
width *= 8 width *= 8
height *= 8 height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG") percentage = calc_percentage(intermediate_state)
events.emit_invocation_denoise_progress( signal_progress(
context_data.queue_item, "Denoising",
context_data.invocation, percentage,
intermediate_state, ProgressImage.build(image=image, size=(width, height)),
ProgressImage(dataURL=dataURL, width=width, height=height),
) )

View File

@ -66,6 +66,7 @@ from invokeai.app.invocations.scheduler import SchedulerOutput
from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
@ -176,4 +177,5 @@ __all__ = [
# invokeai.app.util.misc # invokeai.app.util.misc
"SEED_MAX", "SEED_MAX",
"get_random_seed", "get_random_seed",
"ProgressImage",
] ]