diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 15c5f17438..072fc01cc9 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -1,6 +1,7 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) from datetime import datetime, timezone +from functools import partial from typing import Any, Literal, Optional, Union import numpy as np @@ -12,6 +13,7 @@ from ..services.image_storage import ImageType from ..services.invocation_services import InvocationServices from .baseinvocation import BaseInvocation, InvocationContext from .image import ImageField, ImageOutput +from ...backend.stable_diffusion import PipelineIntermediateState SAMPLER_NAME_VALUES = Literal[ "ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun" @@ -41,8 +43,15 @@ class TextToImageInvocation(BaseInvocation): # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( - self, context: InvocationContext, sample: Any = None, step: int = 0 + self, context: InvocationContext, intermediate_state: PipelineIntermediateState ) -> None: + step = intermediate_state.step + # if intermediate_state.predicted_original is not None: + # # Some schedulers report not only the noisy latents at the current timestep, + # # but also their estimate so far of what the de-noised latents will be. + # sample = intermediate_state.predicted_original + # else: + # sample = intermediate_state.latents context.services.events.emit_generator_progress( context.graph_execution_state_id, self.id, @@ -51,9 +60,6 @@ class TextToImageInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> ImageOutput: - def step_callback(sample, step=0): - self.dispatch_progress(context, sample, step) - # Handle invalid model parameter # TODO: figure out if this can be done via a validator that uses the model_cache # TODO: How to get the default model name now? @@ -65,7 +71,7 @@ class TextToImageInvocation(BaseInvocation): results = context.services.generate.prompt2image( prompt=self.prompt, - step_callback=step_callback, + step_callback=partial(self.dispatch_progress, context), **self.dict( exclude={"prompt"} ), # Shorthand for passing all of the parameters above manually @@ -109,9 +115,6 @@ class ImageToImageInvocation(TextToImageInvocation): ) mask = None - def step_callback(sample, step=0): - self.dispatch_progress(context, sample, step) - # Handle invalid model parameter # TODO: figure out if this can be done via a validator that uses the model_cache # TODO: How to get the default model name now? @@ -125,7 +128,7 @@ class ImageToImageInvocation(TextToImageInvocation): prompt=self.prompt, init_img=image, init_mask=mask, - step_callback=step_callback, + step_callback=partial(self.dispatch_progress, context), **self.dict( exclude={"prompt", "image", "mask"} ), # Shorthand for passing all of the parameters above manually @@ -174,9 +177,6 @@ class InpaintInvocation(ImageToImageInvocation): else context.services.images.get(self.mask.image_type, self.mask.image_name) ) - def step_callback(sample, step=0): - self.dispatch_progress(context, sample, step) - # Handle invalid model parameter # TODO: figure out if this can be done via a validator that uses the model_cache # TODO: How to get the default model name now? @@ -190,7 +190,7 @@ class InpaintInvocation(ImageToImageInvocation): prompt=self.prompt, init_img=image, init_mask=mask, - step_callback=step_callback, + step_callback=partial(self.dispatch_progress, context), **self.dict( exclude={"prompt", "image", "mask"} ), # Shorthand for passing all of the parameters above manually diff --git a/invokeai/backend/web/invoke_ai_web_server.py b/invokeai/backend/web/invoke_ai_web_server.py index dc77ff4723..7209e31449 100644 --- a/invokeai/backend/web/invoke_ai_web_server.py +++ b/invokeai/backend/web/invoke_ai_web_server.py @@ -1022,7 +1022,7 @@ class InvokeAIWebServer: "RGB" ) - def image_progress(sample, step): + def image_progress(intermediate_state: PipelineIntermediateState): if self.canceled.is_set(): raise CanceledException @@ -1030,6 +1030,14 @@ class InvokeAIWebServer: nonlocal generation_parameters nonlocal progress + step = intermediate_state.step + if intermediate_state.predicted_original is not None: + # Some schedulers report not only the noisy latents at the current timestep, + # but also their estimate so far of what the de-noised latents will be. + sample = intermediate_state.predicted_original + else: + sample = intermediate_state.latents + generation_messages = { "txt2img": "common.statusGeneratingTextToImage", "img2img": "common.statusGeneratingImageToImage", @@ -1302,16 +1310,9 @@ class InvokeAIWebServer: progress.set_current_iteration(progress.current_iteration + 1) - def diffusers_step_callback_adapter(*cb_args, **kwargs): - if isinstance(cb_args[0], PipelineIntermediateState): - progress_state: PipelineIntermediateState = cb_args[0] - return image_progress(progress_state.latents, progress_state.step) - else: - return image_progress(*cb_args, **kwargs) - self.generate.prompt2image( **generation_parameters, - step_callback=diffusers_step_callback_adapter, + step_callback=image_progress, image_callback=image_done, )