diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 70892ecde9..b8140b11e9 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -4,6 +4,8 @@ from datetime import datetime, timezone from typing import Any, Literal, Optional, Union import numpy as np + +from torch import Tensor from PIL import Image from pydantic import Field from skimage.exposure.histogram_matching import match_histograms @@ -12,7 +14,9 @@ from ..services.image_storage import ImageType from ..services.invocation_services import InvocationServices from .baseinvocation import BaseInvocation, InvocationContext from .image import ImageField, ImageOutput -from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator +from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator +from ...backend.stable_diffusion import PipelineIntermediateState +from ...backend.util.util import image_to_dataURL SAMPLER_NAME_VALUES = Literal[ tuple(InvokeAIGenerator.schedulers()) @@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation): # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( - self, context: InvocationContext, sample: Any = None, step: int = 0 - ) -> None: + self, context: InvocationContext, sample: Tensor, step: int + ) -> None: + # TODO: only output a preview image when requested + image = Generator.sample_to_lowres_estimated_image(sample) + + (width, height) = image.size + width *= 8 + height *= 8 + + dataURL = image_to_dataURL(image, image_format="JPEG") + context.services.events.emit_generator_progress( context.graph_execution_state_id, self.id, + { + "width": width, + "height": height, + "dataURL": dataURL + }, step, - float(step) / float(self.steps), + self.steps, ) def invoke(self, context: InvocationContext) -> ImageOutput: - def step_callback(sample, step=0): - self.dispatch_progress(context, sample, step) + def step_callback(state: PipelineIntermediateState): + self.dispatch_progress(context, state.latents, state.step) # Handle invalid model parameter # TODO: figure out if this can be done via a validator that uses the model_cache diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index e2ab4e61e3..c8eb7671d0 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -1,7 +1,10 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Any, Dict +from typing import Any, Dict, TypedDict +ProgressImage = TypedDict( + "ProgressImage", {"dataURL": str, "width": int, "height": int} +) class EventServiceBase: session_event: str = "session_event" @@ -23,8 +26,9 @@ class EventServiceBase: self, graph_execution_state_id: str, invocation_id: str, + progress_image: ProgressImage | None, step: int, - percent: float, + total_steps: int, ) -> None: """Emitted when there is generation progress""" self.__emit_session_event( @@ -32,8 +36,9 @@ class EventServiceBase: payload=dict( graph_execution_state_id=graph_execution_state_id, invocation_id=invocation_id, + progress_image=progress_image, step=step, - percent=percent, + total_steps=total_steps, ), )