From 07ac292680d2ad71f605d60d5b9760968de35c64 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 19 Jun 2024 14:03:56 -0400 Subject: [PATCH] Consolidate _region_step() function - the separation wasn't really adding any value. --- .../multi_diffusion_pipeline.py | 70 +++++++------------ 1 file changed, 27 insertions(+), 43 deletions(-) diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py index e2934247ed..3fcabc615a 100644 --- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py +++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py @@ -27,6 +27,15 @@ class MultiDiffusionRegionConditioning: class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): """A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising.""" + def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]): + """Validate that regional conditioning is not used.""" + for region_conditioning in multi_diffusion_conditioning: + if ( + region_conditioning.text_conditioning_data.cond_regions is not None + or region_conditioning.text_conditioning_data.uncond_regions is not None + ): + raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.") + def multi_diffusion_denoise( self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning], @@ -36,6 +45,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): timesteps: torch.Tensor, callback: Callable[[PipelineIntermediateState], None], ) -> torch.Tensor: + self._check_regional_prompting(multi_diffusion_conditioning) + if timesteps.shape[0] == 0: return latents @@ -90,14 +101,26 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): # Switch to the scheduler for the region batch. self.scheduler = region_batch_schedulers[region_idx] - # Run a denoising step on the region. - step_output = self._region_step( - region_conditioning=region_conditioning, + # Crop the inputs to the region. + region_latents = latents[ + :, + :, + region_conditioning.region.top : region_conditioning.region.bottom, + region_conditioning.region.left : region_conditioning.region.right, + ] + + # Run the denoising step on the region. + step_output = self.step( t=batched_t, - latents=latents, + latents=region_latents, + conditioning_data=region_conditioning.text_conditioning_data, step_index=i, total_step_count=len(timesteps), scheduler_step_kwargs=scheduler_step_kwargs, + mask_guidance=None, + mask=None, + masked_latents=None, + control_data=region_conditioning.control_data, ) # Store the results from the region. @@ -133,42 +156,3 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): ) return latents - - @torch.inference_mode() - def _region_step( - self, - region_conditioning: MultiDiffusionRegionConditioning, - t: torch.Tensor, - latents: torch.Tensor, - step_index: int, - total_step_count: int, - scheduler_step_kwargs: dict[str, Any], - ): - use_regional_prompting = ( - region_conditioning.text_conditioning_data.cond_regions is not None - or region_conditioning.text_conditioning_data.uncond_regions is not None - ) - if use_regional_prompting: - raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.") - - # Crop the inputs to the region. - region_latents = latents[ - :, - :, - region_conditioning.region.top : region_conditioning.region.bottom, - region_conditioning.region.left : region_conditioning.region.right, - ] - - # Run the denoising step on the region. - return self.step( - t=t, - latents=region_latents, - conditioning_data=region_conditioning.text_conditioning_data, - step_index=step_index, - total_step_count=total_step_count, - scheduler_step_kwargs=scheduler_step_kwargs, - mask_guidance=None, - mask=None, - masked_latents=None, - control_data=region_conditioning.control_data, - )