mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Consolidate _region_step() function - the separation wasn't really adding any value.
This commit is contained in:
parent
7c032ea604
commit
07ac292680
@ -27,6 +27,15 @@ class MultiDiffusionRegionConditioning:
|
|||||||
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||||
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
|
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
|
||||||
|
|
||||||
|
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
|
||||||
|
"""Validate that regional conditioning is not used."""
|
||||||
|
for region_conditioning in multi_diffusion_conditioning:
|
||||||
|
if (
|
||||||
|
region_conditioning.text_conditioning_data.cond_regions is not None
|
||||||
|
or region_conditioning.text_conditioning_data.uncond_regions is not None
|
||||||
|
):
|
||||||
|
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
|
||||||
|
|
||||||
def multi_diffusion_denoise(
|
def multi_diffusion_denoise(
|
||||||
self,
|
self,
|
||||||
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
|
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
|
||||||
@ -36,6 +45,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
callback: Callable[[PipelineIntermediateState], None],
|
callback: Callable[[PipelineIntermediateState], None],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
self._check_regional_prompting(multi_diffusion_conditioning)
|
||||||
|
|
||||||
if timesteps.shape[0] == 0:
|
if timesteps.shape[0] == 0:
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
@ -90,14 +101,26 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
# Switch to the scheduler for the region batch.
|
# Switch to the scheduler for the region batch.
|
||||||
self.scheduler = region_batch_schedulers[region_idx]
|
self.scheduler = region_batch_schedulers[region_idx]
|
||||||
|
|
||||||
# Run a denoising step on the region.
|
# Crop the inputs to the region.
|
||||||
step_output = self._region_step(
|
region_latents = latents[
|
||||||
region_conditioning=region_conditioning,
|
:,
|
||||||
|
:,
|
||||||
|
region_conditioning.region.top : region_conditioning.region.bottom,
|
||||||
|
region_conditioning.region.left : region_conditioning.region.right,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Run the denoising step on the region.
|
||||||
|
step_output = self.step(
|
||||||
t=batched_t,
|
t=batched_t,
|
||||||
latents=latents,
|
latents=region_latents,
|
||||||
|
conditioning_data=region_conditioning.text_conditioning_data,
|
||||||
step_index=i,
|
step_index=i,
|
||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
|
mask_guidance=None,
|
||||||
|
mask=None,
|
||||||
|
masked_latents=None,
|
||||||
|
control_data=region_conditioning.control_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store the results from the region.
|
# Store the results from the region.
|
||||||
@ -133,42 +156,3 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def _region_step(
|
|
||||||
self,
|
|
||||||
region_conditioning: MultiDiffusionRegionConditioning,
|
|
||||||
t: torch.Tensor,
|
|
||||||
latents: torch.Tensor,
|
|
||||||
step_index: int,
|
|
||||||
total_step_count: int,
|
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
|
||||||
):
|
|
||||||
use_regional_prompting = (
|
|
||||||
region_conditioning.text_conditioning_data.cond_regions is not None
|
|
||||||
or region_conditioning.text_conditioning_data.uncond_regions is not None
|
|
||||||
)
|
|
||||||
if use_regional_prompting:
|
|
||||||
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
|
|
||||||
|
|
||||||
# Crop the inputs to the region.
|
|
||||||
region_latents = latents[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
region_conditioning.region.top : region_conditioning.region.bottom,
|
|
||||||
region_conditioning.region.left : region_conditioning.region.right,
|
|
||||||
]
|
|
||||||
|
|
||||||
# Run the denoising step on the region.
|
|
||||||
return self.step(
|
|
||||||
t=t,
|
|
||||||
latents=region_latents,
|
|
||||||
conditioning_data=region_conditioning.text_conditioning_data,
|
|
||||||
step_index=step_index,
|
|
||||||
total_step_count=total_step_count,
|
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
|
||||||
mask_guidance=None,
|
|
||||||
mask=None,
|
|
||||||
masked_latents=None,
|
|
||||||
control_data=region_conditioning.control_data,
|
|
||||||
)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user