Stricter typing for the is_gradient_mask: bool.

This commit is contained in:
Ryan Dick 2024-06-12 15:23:40 -04:00
parent f82af7c22d
commit 73a8c55852
3 changed files with 6 additions and 7 deletions

View File

@ -819,7 +819,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed, seed=seed,
mask=mask, mask=mask,
masked_latents=masked_latents, masked_latents=masked_latents,
gradient_mask=gradient_mask, is_gradient_mask=gradient_mask,
scheduler_step_kwargs=scheduler_step_kwargs, scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=controlnet_data, control_data=controlnet_data,

View File

@ -340,7 +340,6 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
seed=seed, seed=seed,
mask=None, mask=None,
masked_latents=None, masked_latents=None,
gradient_mask=None,
scheduler_step_kwargs=scheduler_step_kwargs, scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=[controlnet_data], control_data=[controlnet_data],

View File

@ -82,7 +82,7 @@ class AddsMaskGuidance:
mask_latents: torch.Tensor mask_latents: torch.Tensor
scheduler: SchedulerMixin scheduler: SchedulerMixin
noise: torch.Tensor noise: torch.Tensor
gradient_mask: bool is_gradient_mask: bool
def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor: def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return self.apply_mask(latents, t) return self.apply_mask(latents, t)
@ -100,7 +100,7 @@ class AddsMaskGuidance:
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t) # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
if self.gradient_mask: if self.is_gradient_mask:
threshhold = (t.item()) / self.scheduler.config.num_train_timesteps threshhold = (t.item()) / self.scheduler.config.num_train_timesteps
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
masked_input = torch.where(mask_bool, latents, mask_latents) masked_input = torch.where(mask_bool, latents, mask_latents)
@ -295,7 +295,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
t2i_adapter_data: Optional[list[T2IAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None,
gradient_mask: Optional[bool] = False, is_gradient_mask: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if init_timestep.shape[0] == 0: if init_timestep.shape[0] == 0:
return latents return latents
@ -328,7 +328,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
generator=torch.Generator(device="cpu").manual_seed(seed), generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype) ).to(device=orig_latents.device, dtype=orig_latents.dtype)
mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask) mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
try: try:
latents = self.generate_latents_from_embeddings( latents = self.generate_latents_from_embeddings(
@ -348,7 +348,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# restore unmasked part after the last step is completed # restore unmasked part after the last step is completed
# in-process masking happens before each step # in-process masking happens before each step
if mask is not None: if mask is not None:
if gradient_mask: if is_gradient_mask:
latents = torch.where(mask > 0, latents, orig_latents) latents = torch.where(mask > 0, latents, orig_latents)
else: else:
latents = torch.lerp( latents = torch.lerp(