mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: comments and ruff
This commit is contained in:
parent
6d7c8d5f57
commit
bc12d6654e
@ -202,14 +202,14 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
default=None,
|
default=None,
|
||||||
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
|
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
|
||||||
title="[OPTIONAL] Image",
|
title="[OPTIONAL] Image",
|
||||||
ui_order=6
|
ui_order=6,
|
||||||
)
|
)
|
||||||
vae: Optional[VAEField] = InputField(
|
vae: Optional[VAEField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
|
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
|
||||||
title="[OPTIONAL] VAE",
|
title="[OPTIONAL] VAE",
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
ui_order=7
|
ui_order=7,
|
||||||
)
|
)
|
||||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
|
||||||
fp32: bool = InputField(
|
fp32: bool = InputField(
|
||||||
@ -218,7 +218,6 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
ui_order=9,
|
ui_order=9,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||||
@ -254,9 +253,8 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||||
|
|
||||||
masked_latents_name = None
|
masked_latents_name = None
|
||||||
# Check for Inpaint model and generate masked_latents
|
|
||||||
if self.vae is not None and self.image is not None:
|
if self.vae is not None and self.image is not None:
|
||||||
#both fields must be present at the same time
|
# both fields must be present at the same time
|
||||||
mask = blur_tensor
|
mask = blur_tensor
|
||||||
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
||||||
image = context.images.get_pil(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
@ -264,11 +262,10 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
if image_tensor.dim() == 3:
|
if image_tensor.dim() == 3:
|
||||||
image_tensor = image_tensor.unsqueeze(0)
|
image_tensor = image_tensor.unsqueeze(0)
|
||||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) # <1 to include gradient area
|
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||||
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||||
|
|
||||||
|
|
||||||
return GradientMaskOutput(
|
return GradientMaskOutput(
|
||||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
|
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
|
||||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||||
|
Loading…
Reference in New Issue
Block a user