mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove regional conditioning logic from MultiDiffusionPipeline - it is not yet supported.
This commit is contained in:
parent
49ca42f84a
commit
865c2335de
@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from contextlib import nullcontext
|
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -12,7 +11,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
)
|
)
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
||||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
|
||||||
from invokeai.backend.tiles.utils import Tile
|
from invokeai.backend.tiles.utils import Tile
|
||||||
|
|
||||||
|
|
||||||
@ -79,12 +77,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
use_regional_prompting = (
|
use_regional_prompting = (
|
||||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||||
)
|
)
|
||||||
unet_attention_patcher = None
|
|
||||||
attn_ctx = nullcontext()
|
|
||||||
|
|
||||||
if use_regional_prompting:
|
if use_regional_prompting:
|
||||||
unet_attention_patcher = UNetAttentionPatcher(ip_adapter_data=None)
|
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
|
||||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
|
||||||
|
|
||||||
# Populate a weighted mask that will be used to combine the results from each region after every step.
|
# Populate a weighted mask that will be used to combine the results from each region after every step.
|
||||||
# For now, we assume that each regions has the same weight (1.0).
|
# For now, we assume that each regions has the same weight (1.0).
|
||||||
@ -96,66 +90,65 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||||
] += 1.0
|
] += 1.0
|
||||||
|
|
||||||
with attn_ctx:
|
callback(
|
||||||
|
PipelineIntermediateState(
|
||||||
|
step=-1,
|
||||||
|
order=self.scheduler.order,
|
||||||
|
total_steps=len(timesteps),
|
||||||
|
timestep=self.scheduler.config.num_train_timesteps,
|
||||||
|
latents=latents,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
|
batched_t = t.expand(batch_size)
|
||||||
|
|
||||||
|
prev_samples_by_region: list[torch.Tensor] = []
|
||||||
|
pred_original_by_region: list[torch.Tensor | None] = []
|
||||||
|
for region in regions:
|
||||||
|
# Run a denoising step on the region.
|
||||||
|
step_output = self._region_step(
|
||||||
|
region=region,
|
||||||
|
t=batched_t,
|
||||||
|
latents=latents,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
step_index=i,
|
||||||
|
total_step_count=len(timesteps),
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
|
control_data=control_data,
|
||||||
|
)
|
||||||
|
prev_samples_by_region.append(step_output.prev_sample)
|
||||||
|
pred_original_by_region.append(getattr(step_output, "pred_original_sample", None))
|
||||||
|
|
||||||
|
# Merge the prev_sample results from each region.
|
||||||
|
merged_latents = torch.zeros_like(latents)
|
||||||
|
for region_idx, region in enumerate(regions):
|
||||||
|
merged_latents[
|
||||||
|
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||||
|
] += prev_samples_by_region[region_idx]
|
||||||
|
latents = merged_latents / region_weight_mask
|
||||||
|
|
||||||
|
# Merge the predicted_original results from each region.
|
||||||
|
predicted_original = None
|
||||||
|
if all(pred_original_by_region):
|
||||||
|
merged_pred_original = torch.zeros_like(latents)
|
||||||
|
for region_idx, region in enumerate(regions):
|
||||||
|
merged_pred_original[
|
||||||
|
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||||
|
] += pred_original_by_region[region_idx]
|
||||||
|
predicted_original = merged_pred_original / region_weight_mask
|
||||||
|
|
||||||
callback(
|
callback(
|
||||||
PipelineIntermediateState(
|
PipelineIntermediateState(
|
||||||
step=-1,
|
step=i,
|
||||||
order=self.scheduler.order,
|
order=self.scheduler.order,
|
||||||
total_steps=len(timesteps),
|
total_steps=len(timesteps),
|
||||||
timestep=self.scheduler.config.num_train_timesteps,
|
timestep=int(t),
|
||||||
latents=latents,
|
latents=latents,
|
||||||
|
predicted_original=predicted_original,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
|
||||||
batched_t = t.expand(batch_size)
|
|
||||||
|
|
||||||
prev_samples_by_region: list[torch.Tensor] = []
|
|
||||||
pred_original_by_region: list[torch.Tensor | None] = []
|
|
||||||
for region in regions:
|
|
||||||
# Run a denoising step on the region.
|
|
||||||
step_output = self._region_step(
|
|
||||||
region=region,
|
|
||||||
t=batched_t,
|
|
||||||
latents=latents,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
step_index=i,
|
|
||||||
total_step_count=len(timesteps),
|
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
|
||||||
control_data=control_data,
|
|
||||||
)
|
|
||||||
prev_samples_by_region.append(step_output.prev_sample)
|
|
||||||
pred_original_by_region.append(getattr(step_output, "pred_original_sample", None))
|
|
||||||
|
|
||||||
# Merge the prev_sample results from each region.
|
|
||||||
merged_latents = torch.zeros_like(latents)
|
|
||||||
for region_idx, region in enumerate(regions):
|
|
||||||
merged_latents[
|
|
||||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
|
||||||
] += prev_samples_by_region[region_idx]
|
|
||||||
latents = merged_latents / region_weight_mask
|
|
||||||
|
|
||||||
# Merge the predicted_original results from each region.
|
|
||||||
predicted_original = None
|
|
||||||
if all(pred_original_by_region):
|
|
||||||
merged_pred_original = torch.zeros_like(latents)
|
|
||||||
for region_idx, region in enumerate(regions):
|
|
||||||
merged_pred_original[
|
|
||||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
|
||||||
] += pred_original_by_region[region_idx]
|
|
||||||
predicted_original = merged_pred_original / region_weight_mask
|
|
||||||
|
|
||||||
callback(
|
|
||||||
PipelineIntermediateState(
|
|
||||||
step=i,
|
|
||||||
order=self.scheduler.order,
|
|
||||||
total_steps=len(timesteps),
|
|
||||||
timestep=int(t),
|
|
||||||
latents=latents,
|
|
||||||
predicted_original=predicted_original,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
Loading…
Reference in New Issue
Block a user