mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove inpainting support from MultiDiffusionPipeline.
This commit is contained in:
parent
20322d781e
commit
493fcd8660
@ -6,15 +6,13 @@ from typing import Any, Callable, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||||
AddsMaskGuidance,
|
|
||||||
ControlNetData,
|
ControlNetData,
|
||||||
PipelineIntermediateState,
|
PipelineIntermediateState,
|
||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
T2IAdapterData,
|
|
||||||
is_inpainting_model,
|
|
||||||
)
|
)
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
||||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
||||||
|
from invokeai.backend.tiles.utils import Tile
|
||||||
|
|
||||||
|
|
||||||
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||||
@ -46,35 +44,23 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
# - TBD, need to think about this more
|
# - TBD, need to think about this more
|
||||||
# - step(...) remains mostly unmodified, is not overriden in this sub-class.
|
# - step(...) remains mostly unmodified, is not overriden in this sub-class.
|
||||||
# - May need a cleaner AddsMaskGuidance implementation to handle this plan... we'll see.
|
# - May need a cleaner AddsMaskGuidance implementation to handle this plan... we'll see.
|
||||||
def latents_from_embeddings(
|
def multi_diffusion_denoise(
|
||||||
self,
|
self,
|
||||||
|
regions: list[Tile],
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
conditioning_data: TextConditioningData,
|
conditioning_data: TextConditioningData,
|
||||||
noise: Optional[torch.Tensor],
|
noise: Optional[torch.Tensor],
|
||||||
seed: int,
|
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
init_timestep: torch.Tensor,
|
init_timestep: torch.Tensor,
|
||||||
callback: Callable[[PipelineIntermediateState], None],
|
callback: Callable[[PipelineIntermediateState], None],
|
||||||
control_data: list[ControlNetData] | None = None,
|
control_data: list[ControlNetData] | None = None,
|
||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
|
||||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
|
||||||
mask: Optional[torch.Tensor] = None,
|
|
||||||
masked_latents: Optional[torch.Tensor] = None,
|
|
||||||
is_gradient_mask: bool = False,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if ip_adapter_data is not None:
|
|
||||||
raise NotImplementedError("ip_adapter_data is not supported in MultiDiffusionPipeline")
|
|
||||||
if t2i_adapter_data is not None:
|
|
||||||
raise NotImplementedError("t2i_adapter_data is not supported in MultiDiffusionPipeline")
|
|
||||||
|
|
||||||
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
|
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
|
||||||
# cases where densoisings_start and denoising_end are set such that there are no timesteps.
|
# cases where densoisings_start and denoising_end are set such that there are no timesteps.
|
||||||
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
|
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
orig_latents = latents.clone()
|
|
||||||
|
|
||||||
batch_size = latents.shape[0]
|
batch_size = latents.shape[0]
|
||||||
batched_init_timestep = init_timestep.expand(batch_size)
|
batched_init_timestep = init_timestep.expand(batch_size)
|
||||||
|
|
||||||
@ -85,32 +71,10 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||||
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
||||||
|
|
||||||
|
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
|
||||||
|
# cropping into regions.
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
|
|
||||||
# Handle mask guidance (a.k.a. inpainting).
|
|
||||||
mask_guidance: AddsMaskGuidance | None = None
|
|
||||||
if mask is not None and not is_inpainting_model(self.unet):
|
|
||||||
# We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
|
|
||||||
# apply mask guidance to the latents.
|
|
||||||
|
|
||||||
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
|
||||||
# We still need noise for inpainting, so we generate it from the seed here.
|
|
||||||
if noise is None:
|
|
||||||
noise = torch.randn(
|
|
||||||
orig_latents.shape,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cpu",
|
|
||||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
|
||||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
|
||||||
|
|
||||||
mask_guidance = AddsMaskGuidance(
|
|
||||||
mask=mask,
|
|
||||||
mask_latents=orig_latents,
|
|
||||||
scheduler=self.scheduler,
|
|
||||||
noise=noise,
|
|
||||||
is_gradient_mask=is_gradient_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
use_regional_prompting = (
|
use_regional_prompting = (
|
||||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||||
)
|
)
|
||||||
@ -141,9 +105,9 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
step_index=i,
|
step_index=i,
|
||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
mask_guidance=mask_guidance,
|
mask_guidance=None,
|
||||||
mask=mask,
|
mask=None,
|
||||||
masked_latents=masked_latents,
|
masked_latents=None,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
@ -160,14 +124,4 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# restore unmasked part after the last step is completed
|
|
||||||
# in-process masking happens before each step
|
|
||||||
if mask is not None:
|
|
||||||
if is_gradient_mask:
|
|
||||||
latents = torch.where(mask > 0, latents, orig_latents)
|
|
||||||
else:
|
|
||||||
latents = torch.lerp(
|
|
||||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
|
Loading…
Reference in New Issue
Block a user