mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Connect TiledMultiDiffusionDenoiseLatents to the MultiDiffusionPipeline backend.
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
@ -11,7 +11,15 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
||||
from invokeai.backend.tiles.utils import Tile
|
||||
from invokeai.backend.tiles.utils import TBLR
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiDiffusionRegionConditioning:
|
||||
# Region coords in latent space.
|
||||
region: TBLR
|
||||
text_conditioning_data: TextConditioningData
|
||||
control_data: list[ControlNetData]
|
||||
|
||||
|
||||
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
@ -45,15 +53,13 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
# - May need a cleaner AddsMaskGuidance implementation to handle this plan... we'll see.
|
||||
def multi_diffusion_denoise(
|
||||
self,
|
||||
regions: list[Tile],
|
||||
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
|
||||
latents: torch.Tensor,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
conditioning_data: TextConditioningData,
|
||||
noise: Optional[torch.Tensor],
|
||||
timesteps: torch.Tensor,
|
||||
init_timestep: torch.Tensor,
|
||||
callback: Callable[[PipelineIntermediateState], None],
|
||||
control_data: list[ControlNetData] | None = None,
|
||||
) -> torch.Tensor:
|
||||
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
|
||||
# cases where densoisings_start and denoising_end are set such that there are no timesteps.
|
||||
@ -74,21 +80,14 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
# cropping into regions.
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
|
||||
use_regional_prompting = (
|
||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||
)
|
||||
if use_regional_prompting:
|
||||
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
|
||||
|
||||
# Populate a weighted mask that will be used to combine the results from each region after every step.
|
||||
# For now, we assume that each regions has the same weight (1.0).
|
||||
region_weight_mask = torch.zeros(
|
||||
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
|
||||
)
|
||||
for region in regions:
|
||||
region_weight_mask[
|
||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||
] += 1.0
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
region = region_conditioning.region
|
||||
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
@ -103,39 +102,36 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t = t.expand(batch_size)
|
||||
|
||||
prev_samples_by_region: list[torch.Tensor] = []
|
||||
pred_original_by_region: list[torch.Tensor | None] = []
|
||||
for region in regions:
|
||||
merged_latents = torch.zeros_like(latents)
|
||||
merged_pred_original: torch.Tensor | None = None
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
# Run a denoising step on the region.
|
||||
step_output = self._region_step(
|
||||
region=region,
|
||||
region_conditioning=region_conditioning,
|
||||
t=batched_t,
|
||||
latents=latents,
|
||||
conditioning_data=conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
control_data=control_data,
|
||||
)
|
||||
prev_samples_by_region.append(step_output.prev_sample)
|
||||
pred_original_by_region.append(getattr(step_output, "pred_original_sample", None))
|
||||
|
||||
# Merge the prev_sample results from each region.
|
||||
merged_latents = torch.zeros_like(latents)
|
||||
for region_idx, region in enumerate(regions):
|
||||
merged_latents[
|
||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||
] += prev_samples_by_region[region_idx]
|
||||
# Store the results from the region.
|
||||
region = region_conditioning.region
|
||||
merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
|
||||
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
|
||||
if pred_orig_sample is not None:
|
||||
# If one region has pred_original_sample, then we can assume that all regions will have it, because
|
||||
# they all use the same scheduler.
|
||||
if merged_pred_original is None:
|
||||
merged_pred_original = torch.zeros_like(latents)
|
||||
merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
|
||||
pred_orig_sample
|
||||
)
|
||||
|
||||
# Normalize the merged results.
|
||||
latents = merged_latents / region_weight_mask
|
||||
|
||||
# Merge the predicted_original results from each region.
|
||||
predicted_original = None
|
||||
if all(pred_original_by_region):
|
||||
merged_pred_original = torch.zeros_like(latents)
|
||||
for region_idx, region in enumerate(regions):
|
||||
merged_pred_original[
|
||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||
] += pred_original_by_region[region_idx]
|
||||
if merged_pred_original is not None:
|
||||
predicted_original = merged_pred_original / region_weight_mask
|
||||
|
||||
callback(
|
||||
@ -154,44 +150,38 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
@torch.inference_mode()
|
||||
def _region_step(
|
||||
self,
|
||||
region: Tile,
|
||||
region_conditioning: MultiDiffusionRegionConditioning,
|
||||
t: torch.Tensor,
|
||||
latents: torch.Tensor,
|
||||
conditioning_data: TextConditioningData,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
control_data: list[ControlNetData] | None = None,
|
||||
):
|
||||
use_regional_prompting = (
|
||||
region_conditioning.text_conditioning_data.cond_regions is not None
|
||||
or region_conditioning.text_conditioning_data.uncond_regions is not None
|
||||
)
|
||||
if use_regional_prompting:
|
||||
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
|
||||
|
||||
# Crop the inputs to the region.
|
||||
region_latents = latents[
|
||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||
:,
|
||||
:,
|
||||
region_conditioning.region.top : region_conditioning.region.bottom,
|
||||
region_conditioning.region.left : region_conditioning.region.right,
|
||||
]
|
||||
|
||||
region_control_data: list[ControlNetData] | None = None
|
||||
if control_data is not None:
|
||||
region_control_data = [self._crop_controlnet_data(c, region) for c in control_data]
|
||||
|
||||
# Run the denoising step on the region.
|
||||
return self.step(
|
||||
t=t,
|
||||
latents=region_latents,
|
||||
conditioning_data=conditioning_data,
|
||||
conditioning_data=region_conditioning.text_conditioning_data,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
mask_guidance=None,
|
||||
mask=None,
|
||||
masked_latents=None,
|
||||
control_data=region_control_data,
|
||||
control_data=region_conditioning.control_data,
|
||||
)
|
||||
|
||||
def _crop_controlnet_data(self, control_data: ControlNetData, region: Tile) -> ControlNetData:
|
||||
"""Crop a ControlNetData object to a region."""
|
||||
# Create a shallow copy of the control_data object.
|
||||
control_data_copy = copy.copy(control_data)
|
||||
# The ControlNet reference image is the only attribute that needs to be cropped.
|
||||
control_data_copy.image_tensor = control_data.image_tensor[
|
||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||
]
|
||||
return control_data_copy
|
||||
|
Reference in New Issue
Block a user