Consolidate _region_step() function - the separation wasn't really adding any value.

This commit is contained in:
Ryan Dick 2024-06-19 14:03:56 -04:00 committed by Kent Keirsey
parent 7c032ea604
commit 07ac292680

View File

@ -27,6 +27,15 @@ class MultiDiffusionRegionConditioning:
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
"""Validate that regional conditioning is not used."""
for region_conditioning in multi_diffusion_conditioning:
if (
region_conditioning.text_conditioning_data.cond_regions is not None
or region_conditioning.text_conditioning_data.uncond_regions is not None
):
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
def multi_diffusion_denoise(
self,
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
@ -36,6 +45,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
timesteps: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None],
) -> torch.Tensor:
self._check_regional_prompting(multi_diffusion_conditioning)
if timesteps.shape[0] == 0:
return latents
@ -90,14 +101,26 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
# Switch to the scheduler for the region batch.
self.scheduler = region_batch_schedulers[region_idx]
# Run a denoising step on the region.
step_output = self._region_step(
region_conditioning=region_conditioning,
# Crop the inputs to the region.
region_latents = latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
# Run the denoising step on the region.
step_output = self.step(
t=batched_t,
latents=latents,
latents=region_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)
# Store the results from the region.
@ -133,42 +156,3 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
)
return latents
@torch.inference_mode()
def _region_step(
self,
region_conditioning: MultiDiffusionRegionConditioning,
t: torch.Tensor,
latents: torch.Tensor,
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
):
use_regional_prompting = (
region_conditioning.text_conditioning_data.cond_regions is not None
or region_conditioning.text_conditioning_data.uncond_regions is not None
)
if use_regional_prompting:
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
# Crop the inputs to the region.
region_latents = latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
# Run the denoising step on the region.
return self.step(
t=t,
latents=region_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=step_index,
total_step_count=total_step_count,
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)