Add progress image callbacks to TiledMultiDiffusionDenoiseLatentsInvocation.

This commit is contained in:
Ryan Dick 2024-06-19 13:29:42 -04:00 committed by Kent Keirsey
parent fa40061eca
commit c5ee415607

View File

@ -24,7 +24,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import ( from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline, MultiDiffusionPipeline,
MultiDiffusionRegionConditioning, MultiDiffusionRegionConditioning,
@ -170,6 +170,12 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
min_overlap=self.tile_min_overlap, min_overlap=self.tile_min_overlap,
) )
# Get the unet's config so that we can pass the base to dispatch_progress().
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
# Prepare an iterator that yields the UNet's LoRA models and their weights. # Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras: for lora in self.unet.loras:
@ -250,8 +256,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
scheduler_step_kwargs=scheduler_step_kwargs, scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise, noise=noise,
timesteps=timesteps, timesteps=timesteps,
# TODO(ryand): Add proper callback. callback=step_callback,
callback=lambda x: None,
) )
result_latents = result_latents.to("cpu") result_latents = result_latents.to("cpu")