diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index e75637bd58..9d26d0a196 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -25,6 +25,7 @@ from invokeai.app.invocations.primitives import (
     LatentsField,
     LatentsOutput,
     InpaintMaskField,
+    InpaintMaskOutput,
     build_latents_output,
 )
 from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -66,6 +67,76 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
 SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
 
 
+@title("Create inpaint mask")
+@tags("mask", "inpaint")
+class CreateInpaintMaskInvocation(BaseInvocation):
+    """Creates a mask for an inpainting model run."""
+
+    # Metadata
+    type: Literal["create_inpaint_mask"] = "create_inpaint_mask"
+
+    # Inputs
+    image: Optional[ImageField] = InputField(default=None, description="Image which will be inpainted")
+    mask: ImageField = InputField(description="The mask to apply to the image")
+    vae: VaeField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+
+    def prep_mask_tensor(self, mask_image):
+        if mask_image.mode != "L":
+            # FIXME: why do we get passed an RGB image here? We can only use single-channel.
+            mask_image = mask_image.convert("L")
+        mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        if mask_tensor.dim() == 3:
+            mask_tensor = mask_tensor.unsqueeze(0)
+        #if shape is not None:
+        #    mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
+        return mask_tensor
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> InpaintMaskOutput:
+        if self.image is not None:
+            image = context.services.images.get_pil_image(self.image.image_name)
+            image = image_resized_to_grid_as_tensor(image.convert("RGB"))
+            if image.dim() == 3:
+                image = image.unsqueeze(0)
+        else:
+            image = None
+
+        mask = self.prep_mask_tensor(
+            context.services.images.get_pil_image(self.mask.image_name),
+        )
+
+        if image is not None:
+            vae_info = context.services.model_manager.get_model(
+                **self.vae.vae.dict(),
+                context=context,
+            )
+
+            img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR)
+            masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
+            # TODO:
+            masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
+
+            masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
+            context.services.latents.save(masked_latents_name, masked_latents)
+        else:
+            masked_latents_name = None
+
+        mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
+        context.services.latents.save(mask_name, mask)
+
+        return InpaintMaskOutput(
+            inpaint_mask=InpaintMaskField(
+                mask_name=mask_name,
+                masked_latents_name=masked_latents_name,
+            ),
+        )
+
+
 def get_scheduler(
     context: InvocationContext,
     scheduler_info: ModelInfo,
@@ -340,19 +411,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return num_inference_steps, timesteps, init_timestep
 
-    def prep_mask_tensor(self, mask, context, lantents):
-        if mask is None:
-            return None
+    def prep_inpaint_mask(self, context, latents):
+        if self.mask is None:
+            return None, None
 
-        mask_image = context.services.images.get_pil_image(mask.image_name)
-        if mask_image.mode != "L":
-            # FIXME: why do we get passed an RGB image here? We can only use single-channel.
-            mask_image = mask_image.convert("L")
-        mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
-        if mask_tensor.dim() == 3:
-            mask_tensor = mask_tensor.unsqueeze(0)
-        mask_tensor = tv_resize(mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR)
-        return 1 - mask_tensor
+        mask = context.services.latents.get(self.mask.mask_name)
+        mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
+        if self.mask.masked_latents_name is not None:
+            masked_latents = context.services.latents.get(self.mask.masked_latents_name)
+        else:
+            masked_latents = None
+
+        return 1 - mask, masked_latents
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -373,7 +443,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if seed is None:
             seed = 0
 
-        mask = self.prep_mask_tensor(self.mask, context, latents)
+        mask, masked_latents = self.prep_inpaint_mask(context, latents)
 
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@@ -404,6 +474,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
             noise = noise.to(device=unet.device, dtype=unet.dtype)
         if mask is not None:
             mask = mask.to(device=unet.device, dtype=unet.dtype)
+        if masked_latents is not None:
+            masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
 
         scheduler = get_scheduler(
             context=context,
@@ -440,6 +512,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 noise=noise,
                 seed=seed,
                 mask=mask,
+                masked_latents=masked_latents,
                 num_inference_steps=num_inference_steps,
                 conditioning_data=conditioning_data,
                 control_data=control_data,  # list[ControlNetData]
@@ -661,26 +734,11 @@ class ImageToLatentsInvocation(BaseInvocation):
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
     fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
 
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
-        # image = context.services.images.get(
-        #     self.image.image_type, self.image.image_name
-        # )
-        image = context.services.images.get_pil_image(self.image.image_name)
-
-        # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
-        vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
-            context=context,
-        )
-
-        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
-        if image_tensor.dim() == 3:
-            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
-
+    @staticmethod
+    def vae_encode(vae_info, upcast, tiled, image_tensor):
         with vae_info as vae:
             orig_dtype = vae.dtype
-            if self.fp32:
+            if upcast:
                 vae.to(dtype=torch.float32)
 
                 use_torch_2_0_or_xformers = isinstance(
@@ -705,7 +763,7 @@ class ImageToLatentsInvocation(BaseInvocation):
                 vae.to(dtype=torch.float16)
                 # latents = latents.half()
 
-            if self.tiled:
+            if tiled:
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
@@ -719,6 +777,27 @@ class ImageToLatentsInvocation(BaseInvocation):
             latents = vae.config.scaling_factor * latents
             latents = latents.to(dtype=orig_dtype)
 
+        return latents
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        # image = context.services.images.get(
+        #     self.image.image_type, self.image.image_name
+        # )
+        image = context.services.images.get_pil_image(self.image.image_name)
+
+        # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+        vae_info = context.services.model_manager.get_model(
+            **self.vae.vae.dict(),
+            context=context,
+        )
+
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+
+        latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
+
         name = f"{context.graph_execution_state_id}__{self.id}"
         latents = latents.to("cpu")
         context.services.latents.save(name, latents)
diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py
index 63738b349f..e9656271ac 100644
--- a/invokeai/app/invocations/primitives.py
+++ b/invokeai/app/invocations/primitives.py
@@ -314,7 +314,14 @@ class InpaintMaskField(BaseModel):
     """An inpaint mask field"""
 
     mask_name: str = Field(description="The name of the mask image")
-    masked_latens_name: Optional[str] = Field(description="The name of the masked image latents")
+    masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
+
+
+class InpaintMaskOutput(BaseInvocationOutput):
+    """An inpaint mask output"""
+
+    type: Literal["inpaint_mask_output"] = "inpaint_mask_output"
+    inpaint_mask: InpaintMaskField = OutputField(description="Mask for an inpainting model run")
 
 
 # endregion
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 3796bbbec7..fb1ceb5b1c 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -342,6 +342,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
         mask: Optional[torch.Tensor] = None,
+        masked_latents: Optional[torch.Tensor] = None,
         seed: Optional[int] = None,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if init_timestep.shape[0] == 0:
@@ -375,11 +376,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         )
 
         if is_inpainting_model(self.unet):
-            # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
-            # (that's why there's a mask!) but it seems to really want that blanked out.
-            masked_latents = orig_latents * torch.where(mask < 0.5, 1, 0)
+            if masked_latents is None:
+                raise Exception("Source image required for inpaint mask when an inpainting model is used")
 
-            # TODO: we should probably pass this in so we don't have to try/finally around setting it.
             self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
                 self._unet_forward, mask, masked_latents
             )
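
Reviewer note (not part of the diff): a minimal sketch of the mask convention the changed nodes now share, run on dummy tensors. The shapes, random inputs, and the `repaint_mask` name are illustrative assumptions; the VAE encode and the denoise loop themselves are elided.

import torch
import torchvision.transforms as T
from torchvision.transforms.functional import resize as tv_resize

# Stand-ins for the tensors the nodes operate on (NCHW).
image = torch.rand(1, 3, 512, 512)                 # decoded source image
mask = (torch.rand(1, 1, 512, 512) > 0.5).float()  # single-channel; ~1.0 = keep, ~0.0 = repaint

# CreateInpaintMaskInvocation: blank out the repaint region before VAE-encoding,
# so the inpainting UNet never sees the original pixels there.
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR)
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)

# DenoiseLatentsInvocation.prep_inpaint_mask: resize the stored mask to latent
# resolution (1/8 of pixel space for SD) and flip its polarity, so that ~1.0
# marks the repaint region, the polarity AddsMaskLatents appears to expect.
latents = torch.rand(1, 4, 64, 64)                 # hypothetical latent tensor
latent_mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
repaint_mask = 1 - latent_mask

Net effect of the refactor: the blanked-out image is VAE-encoded once in the new node and the resulting latents are passed to the pipeline by name, rather than being recomputed from `orig_latents` inside the pipeline as the removed code did.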