feat(nodes): added gradient mask node

This commit is contained in:
dunkeroni
2024-02-20 21:13:19 -05:00
committed by psychedelicious
parent f7fc20459a
commit 06cc57d82a
4 changed files with 80 additions and 11 deletions

View File

@ -86,6 +86,7 @@ class AddsMaskGuidance:
mask_latents: torch.FloatTensor
scheduler: SchedulerMixin
noise: torch.Tensor
gradient_mask: bool
def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data.
@ -121,7 +122,12 @@ class AddsMaskGuidance:
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
if self.gradient_mask:
threshhold = (t.item()) / self.scheduler.config.num_train_timesteps
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
masked_input = torch.where(mask_bool, latents, mask_latents)
else:
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
return masked_input
@ -335,6 +341,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
gradient_mask: Optional[bool] = False,
seed: Optional[int] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if init_timestep.shape[0] == 0:
@ -375,7 +382,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._unet_forward, mask, masked_latents
)
else:
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
try:
latents, attention_map_saver = self.generate_latents_from_embeddings(
@ -392,7 +399,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.invokeai_diffuser.model_forward_callback = self._unet_forward
# restore unmasked part
if mask is not None:
if mask is not None and not gradient_mask:
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
return latents, attention_map_saver