mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Connect TiledMultiDiffusionDenoiseLatents to the MultiDiffusionPipeline backend.
This commit is contained in:
parent
865c2335de
commit
35adaf1c17
@ -1,10 +1,10 @@
|
|||||||
|
import copy
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from typing import Iterator, Tuple
|
from typing import Iterator, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
from PIL import Image
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
@ -19,21 +19,38 @@ from invokeai.app.invocations.fields import (
|
|||||||
LatentsField,
|
LatentsField,
|
||||||
UIType,
|
UIType,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
|
||||||
from invokeai.app.invocations.model import UNetField
|
from invokeai.app.invocations.model import UNetField
|
||||||
from invokeai.app.invocations.noise import get_noise
|
from invokeai.app.invocations.noise import get_noise
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import LatentsOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
||||||
|
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||||
|
MultiDiffusionPipeline,
|
||||||
|
MultiDiffusionRegionConditioning,
|
||||||
|
)
|
||||||
from invokeai.backend.tiles.tiles import (
|
from invokeai.backend.tiles.tiles import (
|
||||||
calc_tiles_min_overlap,
|
calc_tiles_min_overlap,
|
||||||
merge_tiles_with_linear_blending,
|
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.tiles.utils import TBLR
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
|
||||||
|
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
|
||||||
|
"""Crop a ControlNetData object to a region."""
|
||||||
|
# Create a shallow copy of the control_data object.
|
||||||
|
control_data_copy = copy.copy(control_data)
|
||||||
|
# The ControlNet reference image is the only attribute that needs to be cropped.
|
||||||
|
control_data_copy.image_tensor = control_data.image_tensor[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
|
||||||
|
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
|
||||||
|
]
|
||||||
|
return control_data_copy
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"tiled_multi_diffusion_denoise_latents",
|
"tiled_multi_diffusion_denoise_latents",
|
||||||
title="Tiled Multi-Diffusion Denoise Latents",
|
title="Tiled Multi-Diffusion Denoise Latents",
|
||||||
@ -119,8 +136,33 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
|||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_pipeline(
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
scheduler: SchedulerMixin,
|
||||||
|
) -> MultiDiffusionPipeline:
|
||||||
|
# TODO(ryand): Get rid of this FakeVae hack.
|
||||||
|
class FakeVae:
|
||||||
|
class FakeVaeConfig:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.block_out_channels = [0]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.config = FakeVae.FakeVaeConfig()
|
||||||
|
|
||||||
|
return MultiDiffusionPipeline(
|
||||||
|
vae=FakeVae(), # TODO: oh...
|
||||||
|
text_encoder=None,
|
||||||
|
tokenizer=None,
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
|
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
_, _, latent_height, latent_width = latents.shape
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
|
||||||
@ -149,15 +191,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
|||||||
min_overlap=self.tile_min_overlap,
|
min_overlap=self.tile_min_overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Split the noise and latents into tiles.
|
|
||||||
noise_tiles: list[torch.Tensor] = []
|
|
||||||
latent_tiles: list[torch.Tensor] = []
|
|
||||||
for tile in tiles:
|
|
||||||
noise_tile = noise[..., tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
|
|
||||||
latent_tile = latents[..., tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
|
|
||||||
noise_tiles.append(noise_tile)
|
|
||||||
latent_tiles.append(latent_tile)
|
|
||||||
|
|
||||||
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
@ -169,7 +202,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
|||||||
# Load the UNet model.
|
# Load the UNet model.
|
||||||
unet_info = context.models.load(self.unet.unet)
|
unet_info = context.models.load(self.unet.unet)
|
||||||
|
|
||||||
refined_latent_tiles: list[torch.Tensor] = []
|
|
||||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||||
assert isinstance(unet, UNet2DConditionModel)
|
assert isinstance(unet, UNet2DConditionModel)
|
||||||
scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
@ -178,7 +210,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
|||||||
scheduler_name=self.scheduler,
|
scheduler_name=self.scheduler,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
)
|
)
|
||||||
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
|
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
|
||||||
|
|
||||||
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
||||||
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
||||||
@ -203,95 +235,47 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Split the controlnet_data into tiles.
|
# Split the controlnet_data into tiles.
|
||||||
if controlnet_data is not None:
|
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
|
||||||
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
|
controlnet_data_tiles: list[list[ControlNetData]] = []
|
||||||
controlnet_data_tiles: list[list[ControlNetData]] = []
|
for tile in tiles:
|
||||||
for tile in tiles:
|
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
|
||||||
# To split the controlnet_data into tiles, we simply need to crop each image_tensor. All other
|
controlnet_data_tiles.append(tile_controlnet_data)
|
||||||
# params can be copied unmodified.
|
|
||||||
tile_controlnet_data = [
|
|
||||||
ControlNetData(
|
|
||||||
model=cn.model,
|
|
||||||
image_tensor=cn.image_tensor[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
tile.coords.top * LATENT_SCALE_FACTOR : tile.coords.bottom * LATENT_SCALE_FACTOR,
|
|
||||||
tile.coords.left * LATENT_SCALE_FACTOR : tile.coords.right * LATENT_SCALE_FACTOR,
|
|
||||||
],
|
|
||||||
weight=cn.weight,
|
|
||||||
begin_step_percent=cn.begin_step_percent,
|
|
||||||
end_step_percent=cn.end_step_percent,
|
|
||||||
control_mode=cn.control_mode,
|
|
||||||
resize_mode=cn.resize_mode,
|
|
||||||
)
|
|
||||||
for cn in controlnet_data
|
|
||||||
]
|
|
||||||
controlnet_data_tiles.append(tile_controlnet_data)
|
|
||||||
|
|
||||||
# Denoise (i.e. "refine") each tile independently.
|
# Prepare the MultiDiffusionRegionConditioning list.
|
||||||
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
|
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
|
||||||
assert latent_tile.shape == noise_tile.shape
|
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
|
||||||
|
multi_diffusion_conditioning.append(
|
||||||
# Prepare a PIL Image for ControlNet processing.
|
MultiDiffusionRegionConditioning(
|
||||||
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
|
region=tile.coords,
|
||||||
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
|
text_conditioning_data=conditioning_data,
|
||||||
image_tile_pil = Image.fromarray(image_tile_np)
|
control_data=tile_controlnet_data,
|
||||||
|
)
|
||||||
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
|
||||||
scheduler,
|
|
||||||
device=unet.device,
|
|
||||||
steps=self.steps,
|
|
||||||
denoising_start=self.denoising_start,
|
|
||||||
denoising_end=self.denoising_end,
|
|
||||||
seed=seed,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
|
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
||||||
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
|
scheduler,
|
||||||
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
|
device=unet.device,
|
||||||
refined_latent_tile = pipeline.latents_from_embeddings(
|
steps=self.steps,
|
||||||
latents=latent_tile,
|
denoising_start=self.denoising_start,
|
||||||
timesteps=timesteps,
|
denoising_end=self.denoising_end,
|
||||||
init_timestep=init_timestep,
|
seed=seed,
|
||||||
noise=noise_tile,
|
)
|
||||||
seed=seed,
|
|
||||||
mask=None,
|
# Run Multi-Diffusion denoising.
|
||||||
masked_latents=None,
|
result_latents = pipeline.multi_diffusion_denoise(
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
multi_diffusion_conditioning=multi_diffusion_conditioning,
|
||||||
conditioning_data=conditioning_data,
|
latents=latents,
|
||||||
control_data=[controlnet_data],
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
ip_adapter_data=None,
|
noise=noise,
|
||||||
t2i_adapter_data=None,
|
timesteps=timesteps,
|
||||||
callback=lambda x: None,
|
init_timestep=init_timestep,
|
||||||
)
|
# TODO(ryand): Add proper callback.
|
||||||
refined_latent_tiles.append(refined_latent_tile)
|
callback=lambda x: None,
|
||||||
|
|
||||||
# VAE-decode each refined latent tile independently.
|
|
||||||
refined_image_tiles: list[Image.Image] = []
|
|
||||||
for refined_latent_tile in refined_latent_tiles:
|
|
||||||
refined_image_tile = LatentsToImageInvocation.vae_decode(
|
|
||||||
context=context,
|
|
||||||
vae_info=vae_info,
|
|
||||||
seamless_axes=self.vae.seamless_axes,
|
|
||||||
latents=refined_latent_tile,
|
|
||||||
use_fp32=self.vae_fp32,
|
|
||||||
use_tiling=False,
|
|
||||||
)
|
)
|
||||||
refined_image_tiles.append(refined_image_tile)
|
|
||||||
|
|
||||||
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
||||||
|
result_latents = result_latents.to("cpu")
|
||||||
TorchDevice.empty_cache()
|
TorchDevice.empty_cache()
|
||||||
|
|
||||||
# Merge the refined image tiles back into a single image.
|
name = context.tensors.save(tensor=result_latents)
|
||||||
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
|
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||||
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
|
|
||||||
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
|
|
||||||
merge_tiles_with_linear_blending(
|
|
||||||
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save the refined image and return its reference.
|
|
||||||
merged_image_pil = Image.fromarray(merged_image_np)
|
|
||||||
image_dto = context.images.save(image=merged_image_pil)
|
|
||||||
|
|
||||||
return ImageOutput.build(image_dto)
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -11,7 +11,15 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
)
|
)
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
||||||
from invokeai.backend.tiles.utils import Tile
|
from invokeai.backend.tiles.utils import TBLR
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MultiDiffusionRegionConditioning:
|
||||||
|
# Region coords in latent space.
|
||||||
|
region: TBLR
|
||||||
|
text_conditioning_data: TextConditioningData
|
||||||
|
control_data: list[ControlNetData]
|
||||||
|
|
||||||
|
|
||||||
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||||
@ -45,15 +53,13 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
# - May need a cleaner AddsMaskGuidance implementation to handle this plan... we'll see.
|
# - May need a cleaner AddsMaskGuidance implementation to handle this plan... we'll see.
|
||||||
def multi_diffusion_denoise(
|
def multi_diffusion_denoise(
|
||||||
self,
|
self,
|
||||||
regions: list[Tile],
|
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
conditioning_data: TextConditioningData,
|
|
||||||
noise: Optional[torch.Tensor],
|
noise: Optional[torch.Tensor],
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
init_timestep: torch.Tensor,
|
init_timestep: torch.Tensor,
|
||||||
callback: Callable[[PipelineIntermediateState], None],
|
callback: Callable[[PipelineIntermediateState], None],
|
||||||
control_data: list[ControlNetData] | None = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
|
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
|
||||||
# cases where densoisings_start and denoising_end are set such that there are no timesteps.
|
# cases where densoisings_start and denoising_end are set such that there are no timesteps.
|
||||||
@ -74,21 +80,14 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
# cropping into regions.
|
# cropping into regions.
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
|
|
||||||
use_regional_prompting = (
|
|
||||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
|
||||||
)
|
|
||||||
if use_regional_prompting:
|
|
||||||
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
|
|
||||||
|
|
||||||
# Populate a weighted mask that will be used to combine the results from each region after every step.
|
# Populate a weighted mask that will be used to combine the results from each region after every step.
|
||||||
# For now, we assume that each regions has the same weight (1.0).
|
# For now, we assume that each regions has the same weight (1.0).
|
||||||
region_weight_mask = torch.zeros(
|
region_weight_mask = torch.zeros(
|
||||||
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
|
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
|
||||||
)
|
)
|
||||||
for region in regions:
|
for region_conditioning in multi_diffusion_conditioning:
|
||||||
region_weight_mask[
|
region = region_conditioning.region
|
||||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
|
||||||
] += 1.0
|
|
||||||
|
|
||||||
callback(
|
callback(
|
||||||
PipelineIntermediateState(
|
PipelineIntermediateState(
|
||||||
@ -103,39 +102,36 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
batched_t = t.expand(batch_size)
|
batched_t = t.expand(batch_size)
|
||||||
|
|
||||||
prev_samples_by_region: list[torch.Tensor] = []
|
merged_latents = torch.zeros_like(latents)
|
||||||
pred_original_by_region: list[torch.Tensor | None] = []
|
merged_pred_original: torch.Tensor | None = None
|
||||||
for region in regions:
|
for region_conditioning in multi_diffusion_conditioning:
|
||||||
# Run a denoising step on the region.
|
# Run a denoising step on the region.
|
||||||
step_output = self._region_step(
|
step_output = self._region_step(
|
||||||
region=region,
|
region_conditioning=region_conditioning,
|
||||||
t=batched_t,
|
t=batched_t,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
step_index=i,
|
step_index=i,
|
||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
control_data=control_data,
|
|
||||||
)
|
)
|
||||||
prev_samples_by_region.append(step_output.prev_sample)
|
|
||||||
pred_original_by_region.append(getattr(step_output, "pred_original_sample", None))
|
|
||||||
|
|
||||||
# Merge the prev_sample results from each region.
|
# Store the results from the region.
|
||||||
merged_latents = torch.zeros_like(latents)
|
region = region_conditioning.region
|
||||||
for region_idx, region in enumerate(regions):
|
merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
|
||||||
merged_latents[
|
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
|
||||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
if pred_orig_sample is not None:
|
||||||
] += prev_samples_by_region[region_idx]
|
# If one region has pred_original_sample, then we can assume that all regions will have it, because
|
||||||
|
# they all use the same scheduler.
|
||||||
|
if merged_pred_original is None:
|
||||||
|
merged_pred_original = torch.zeros_like(latents)
|
||||||
|
merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
|
||||||
|
pred_orig_sample
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize the merged results.
|
||||||
latents = merged_latents / region_weight_mask
|
latents = merged_latents / region_weight_mask
|
||||||
|
|
||||||
# Merge the predicted_original results from each region.
|
|
||||||
predicted_original = None
|
predicted_original = None
|
||||||
if all(pred_original_by_region):
|
if merged_pred_original is not None:
|
||||||
merged_pred_original = torch.zeros_like(latents)
|
|
||||||
for region_idx, region in enumerate(regions):
|
|
||||||
merged_pred_original[
|
|
||||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
|
||||||
] += pred_original_by_region[region_idx]
|
|
||||||
predicted_original = merged_pred_original / region_weight_mask
|
predicted_original = merged_pred_original / region_weight_mask
|
||||||
|
|
||||||
callback(
|
callback(
|
||||||
@ -154,44 +150,38 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _region_step(
|
def _region_step(
|
||||||
self,
|
self,
|
||||||
region: Tile,
|
region_conditioning: MultiDiffusionRegionConditioning,
|
||||||
t: torch.Tensor,
|
t: torch.Tensor,
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
conditioning_data: TextConditioningData,
|
|
||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
control_data: list[ControlNetData] | None = None,
|
|
||||||
):
|
):
|
||||||
|
use_regional_prompting = (
|
||||||
|
region_conditioning.text_conditioning_data.cond_regions is not None
|
||||||
|
or region_conditioning.text_conditioning_data.uncond_regions is not None
|
||||||
|
)
|
||||||
|
if use_regional_prompting:
|
||||||
|
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
|
||||||
|
|
||||||
# Crop the inputs to the region.
|
# Crop the inputs to the region.
|
||||||
region_latents = latents[
|
region_latents = latents[
|
||||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
:,
|
||||||
|
:,
|
||||||
|
region_conditioning.region.top : region_conditioning.region.bottom,
|
||||||
|
region_conditioning.region.left : region_conditioning.region.right,
|
||||||
]
|
]
|
||||||
|
|
||||||
region_control_data: list[ControlNetData] | None = None
|
|
||||||
if control_data is not None:
|
|
||||||
region_control_data = [self._crop_controlnet_data(c, region) for c in control_data]
|
|
||||||
|
|
||||||
# Run the denoising step on the region.
|
# Run the denoising step on the region.
|
||||||
return self.step(
|
return self.step(
|
||||||
t=t,
|
t=t,
|
||||||
latents=region_latents,
|
latents=region_latents,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=region_conditioning.text_conditioning_data,
|
||||||
step_index=step_index,
|
step_index=step_index,
|
||||||
total_step_count=total_step_count,
|
total_step_count=total_step_count,
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
mask_guidance=None,
|
mask_guidance=None,
|
||||||
mask=None,
|
mask=None,
|
||||||
masked_latents=None,
|
masked_latents=None,
|
||||||
control_data=region_control_data,
|
control_data=region_conditioning.control_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _crop_controlnet_data(self, control_data: ControlNetData, region: Tile) -> ControlNetData:
|
|
||||||
"""Crop a ControlNetData object to a region."""
|
|
||||||
# Create a shallow copy of the control_data object.
|
|
||||||
control_data_copy = copy.copy(control_data)
|
|
||||||
# The ControlNet reference image is the only attribute that needs to be cropped.
|
|
||||||
control_data_copy.image_tensor = control_data.image_tensor[
|
|
||||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
|
||||||
]
|
|
||||||
return control_data_copy
|
|
||||||
|
Loading…
Reference in New Issue
Block a user