From c5ee415607c33e0dd4f5227738c6fdf3afef8ffa Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 19 Jun 2024 13:29:42 -0400 Subject: [PATCH] Add progress image callbacks to TiledMultiDiffusionDenoiseLatentsInvocation. --- .../tiled_multi_diffusion_denoise_latents.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py index 53b8ed7fef..6274204c14 100644 --- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py +++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py @@ -24,7 +24,7 @@ from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_patcher import ModelPatcher -from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData +from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import ( MultiDiffusionPipeline, MultiDiffusionRegionConditioning, @@ -170,6 +170,12 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation): min_overlap=self.tile_min_overlap, ) + # Get the unet's config so that we can pass the base to dispatch_progress(). + unet_config = context.models.get_config(self.unet.unet.key) + + def step_callback(state: PipelineIntermediateState) -> None: + context.util.sd_step_callback(state, unet_config.base) + # Prepare an iterator that yields the UNet's LoRA models and their weights. def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.unet.loras: @@ -250,8 +256,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation): scheduler_step_kwargs=scheduler_step_kwargs, noise=noise, timesteps=timesteps, - # TODO(ryand): Add proper callback. - callback=lambda x: None, + callback=step_callback, ) result_latents = result_latents.to("cpu")