chore: comments and ruff

This commit is contained in:
dunkeroni 2024-04-11 23:25:04 -04:00 committed by Kent Keirsey
parent 6d7c8d5f57
commit bc12d6654e

View File

@ -202,14 +202,14 @@ class CreateGradientMaskInvocation(BaseInvocation):
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] Image",
ui_order=6
ui_order=6,
)
vae: Optional[VAEField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] VAE",
input=Input.Connection,
ui_order=7
ui_order=7,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
fp32: bool = InputField(
@ -218,7 +218,6 @@ class CreateGradientMaskInvocation(BaseInvocation):
ui_order=9,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
@ -254,7 +253,6 @@ class CreateGradientMaskInvocation(BaseInvocation):
expanded_image_dto = context.images.save(expanded_mask_image)
masked_latents_name = None
# Check for Inpaint model and generate masked_latents
if self.vae is not None and self.image is not None:
# both fields must be present at the same time
mask = blur_tensor
@ -264,11 +262,10 @@ class CreateGradientMaskInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) # <1 to include gradient area
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.tensors.save(tensor=masked_latents)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),