From 95883c2efd7393356ac04cf7c277b307d320fcbe Mon Sep 17 00:00:00 2001
From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Date: Sun, 27 Aug 2023 12:29:11 -0400
Subject: [PATCH] Add Initial (non-working) Seamless Implementation

---
 invokeai/app/invocations/latent.py | 198 +++++++++++++++++++++--------
 1 file changed, 145 insertions(+), 53 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 314301663b..cd2dcc24a2 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -21,6 +21,8 @@ from torchvision.transforms.functional import resize as tv_resize
 
 from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.invocations.primitives import (
+    DenoiseMaskField,
+    DenoiseMaskOutput,
     ImageField,
     ImageOutput,
     LatentsField,
@@ -31,8 +33,8 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
 
-from ...backend.model_management.models import BaseModelType
 from ...backend.model_management.lora import ModelPatcher
+from ...backend.model_management.models import BaseModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
     ConditioningData,
@@ -44,16 +46,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
 from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import (
-    BaseInvocation,
-    FieldDescriptions,
-    Input,
-    InputField,
-    InvocationContext,
-    UIType,
-    tags,
-    title,
-)
+from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, UIType, tags, title
 from .compel import ConditioningField
 from .controlnet_image_processors import ControlField
 from .model import ModelInfo, UNetField, VaeField
@@ -64,6 +57,72 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
 SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
 
+
+@title("Create Denoise Mask")
+@tags("mask", "denoise")
+class CreateDenoiseMaskInvocation(BaseInvocation):
+    """Creates a mask for the denoising model run."""
+
+    # Metadata
+    type: Literal["create_denoise_mask"] = "create_denoise_mask"
+
+    # Inputs
+    vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
+    image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
+    mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
+    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32, ui_order=4)
+
+    def prep_mask_tensor(self, mask_image):
+        if mask_image.mode != "L":
+            mask_image = mask_image.convert("L")
+        mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        if mask_tensor.dim() == 3:
+            mask_tensor = mask_tensor.unsqueeze(0)
+        # if shape is not None:
+        #    mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
+        return mask_tensor
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
+        if self.image is not None:
+            image = context.services.images.get_pil_image(self.image.image_name)
+            image = image_resized_to_grid_as_tensor(image.convert("RGB"))
+            if image.dim() == 3:
+                image = image.unsqueeze(0)
+        else:
+            image = None
+
+        mask = self.prep_mask_tensor(
+            context.services.images.get_pil_image(self.mask.image_name),
+        )
+
+        if image is not None:
+            vae_info = context.services.model_manager.get_model(
+                **self.vae.vae.dict(),
+                context=context,
+            )
+
+            img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR)
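+            # torch.where() turns the resized mask into a hard 0/1 gate:
+            # pixels whose mask value falls below 0.5 are blanked to zero
+            # before the masked image is VAE-encoded below.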
+            masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
+            # TODO:
+            masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
+
+            masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
+            context.services.latents.save(masked_latents_name, masked_latents)
+        else:
+            masked_latents_name = None
+
+        mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
+        context.services.latents.save(mask_name, mask)
+
+        return DenoiseMaskOutput(
+            denoise_mask=DenoiseMaskField(
+                mask_name=mask_name,
+                masked_latents_name=masked_latents_name,
+            ),
+        )
+
+
 def get_scheduler(
     context: InvocationContext,
     scheduler_info: ModelInfo,
@@ -126,13 +185,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
     control: Union[ControlField, list[ControlField]] = InputField(
         default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
     )
-    latents: Optional[LatentsField] = InputField(
-        description=FieldDescriptions.latents, input=Input.Connection, ui_order=4
-    )
-    mask: Optional[ImageField] = InputField(
+    latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
+    denoise_mask: Optional[DenoiseMaskField] = InputField(
         default=None,
         description=FieldDescriptions.mask,
     )
+    seamless: bool = InputField(default=False, description="Enable or disable seamless padding")
+    seamless_axes: str = InputField(default="xy", description="Specify which axes are seamless: 'x', 'y', or 'xy'")
 
     @validator("cfg_scale")
     def ge_one(cls, v):
@@ -199,17 +258,46 @@ class DenoiseLatentsInvocation(BaseInvocation):
         )
 
         return conditioning_data
+
+    def configure_model_padding(self, model):
+        """
+        Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
+        """
+
+        def _conv_forward_asymmetric(self, input, weight, bias):
+            """
+            Patch for Conv2d._conv_forward that supports asymmetric padding
+            """
+            working = torch.nn.functional.pad(input, self.asymmetric_padding['x'], mode=self.asymmetric_padding_mode['x'])
+            working = torch.nn.functional.pad(working, self.asymmetric_padding['y'], mode=self.asymmetric_padding_mode['y'])
+            return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups)
+
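+        # Re-bind each conv's _conv_forward: with seamless on, an axis listed
+        # in `seamless_axes` pads in 'circular' mode (activations wrap around
+        # that edge) while the other axis keeps 'constant' (zero) padding;
+        # with seamless off, restore the stock forward and drop the extra
+        # attributes.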
+ """ + + def _conv_forward_asymmetric(self, input, weight, bias): + """ + Patch for Conv2d._conv_forward that supports asymmetric padding + """ + working = torch.nn.functional.pad(input, self.asymmetric_padding['x'], mode=self.asymmetric_padding_mode['x']) + working = torch.nn.functional.pad(working, self.asymmetric_padding['y'], mode=self.asymmetric_padding_mode['y']) + return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups) + + for m in model.modules(): + if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): + if self.seamless: + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode['x'] = 'circular' if ('x' in self.seamless_axes) else 'constant' + m.asymmetric_padding['x'] = (m._reversed_padding_repeated_twice[0], m._reversed_padding_repeated_twice[1], 0, 0) + m.asymmetric_padding_mode['y'] = 'circular' if ('y' in self.seamless_axes) else 'constant' + m.asymmetric_padding['y'] = (0, 0, m._reversed_padding_repeated_twice[2], m._reversed_padding_repeated_twice[3]) + m._conv_forward = _conv_forward_asymmetric.__get__(m, torch.nn.Conv2d) + else: + m._conv_forward = torch.nn.Conv2d._conv_forward.__get__(m, torch.nn.Conv2d) + if hasattr(m, 'asymmetric_padding_mode'): + del m.asymmetric_padding_mode + if hasattr(m, 'asymmetric_padding'): + del m.asymmetric_padding + def create_pipeline( self, unet, scheduler, ) -> StableDiffusionGeneratorPipeline: - # TODO: - # configure_model_padding( - # unet, - # self.seamless, - # self.seamless_axes, - # ) + + self.configure_model_padding( + unet + ) class FakeVae: class FakeVaeConfig: @@ -342,19 +430,18 @@ class DenoiseLatentsInvocation(BaseInvocation): return num_inference_steps, timesteps, init_timestep - def prep_mask_tensor(self, mask, context, lantents): - if mask is None: - return None + def prep_inpaint_mask(self, context, latents): + if self.denoise_mask is None: + return None, None - mask_image = context.services.images.get_pil_image(mask.image_name) - if mask_image.mode != "L": - # FIXME: why do we get passed an RGB image here? We can only use single-channel. 
- mask_image = mask_image.convert("L") - mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) - if mask_tensor.dim() == 3: - mask_tensor = mask_tensor.unsqueeze(0) - mask_tensor = tv_resize(mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR) - return 1 - mask_tensor + mask = context.services.latents.get(self.denoise_mask.mask_name) + mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) + if self.denoise_mask.masked_latents_name is not None: + masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name) + else: + masked_latents = None + + return 1 - mask, masked_latents @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: @@ -375,7 +462,7 @@ class DenoiseLatentsInvocation(BaseInvocation): if seed is None: seed = 0 - mask = self.prep_mask_tensor(self.mask, context, latents) + mask, masked_latents = self.prep_inpaint_mask(context, latents) # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) @@ -406,6 +493,8 @@ class DenoiseLatentsInvocation(BaseInvocation): noise = noise.to(device=unet.device, dtype=unet.dtype) if mask is not None: mask = mask.to(device=unet.device, dtype=unet.dtype) + if masked_latents is not None: + masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype) scheduler = get_scheduler( context=context, @@ -442,6 +531,7 @@ class DenoiseLatentsInvocation(BaseInvocation): noise=noise, seed=seed, mask=mask, + masked_latents=masked_latents, num_inference_steps=num_inference_steps, conditioning_data=conditioning_data, control_data=control_data, # list[ControlNetData] @@ -663,26 +753,11 @@ class ImageToLatentsInvocation(BaseInvocation): tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) - @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: - # image = context.services.images.get( - # self.image.image_type, self.image.image_name - # ) - image = context.services.images.get_pil_image(self.image.image_name) - - # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.dict(), - context=context, - ) - - image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) - if image_tensor.dim() == 3: - image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") - + @staticmethod + def vae_encode(vae_info, upcast, tiled, image_tensor): with vae_info as vae: orig_dtype = vae.dtype - if self.fp32: + if upcast: vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( @@ -707,7 +782,7 @@ class ImageToLatentsInvocation(BaseInvocation): vae.to(dtype=torch.float16) # latents = latents.half() - if self.tiled: + if tiled: vae.enable_tiling() else: vae.disable_tiling() @@ -721,6 +796,23 @@ class ImageToLatentsInvocation(BaseInvocation): latents = vae.config.scaling_factor * latents latents = latents.to(dtype=orig_dtype) + return latents + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> LatentsOutput: + image = context.services.images.get_pil_image(self.image.image_name) + + vae_info = context.services.model_manager.get_model( + **self.vae.vae.dict(), + context=context, + ) + + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = 
+    @staticmethod
+    def vae_encode(vae_info, upcast, tiled, image_tensor):
         with vae_info as vae:
             orig_dtype = vae.dtype
-            if self.fp32:
+            if upcast:
                 vae.to(dtype=torch.float32)
 
                 use_torch_2_0_or_xformers = isinstance(
@@ -707,7 +782,7 @@ class ImageToLatentsInvocation(BaseInvocation):
                 vae.to(dtype=torch.float16)
                 # latents = latents.half()
 
-            if self.tiled:
+            if tiled:
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
@@ -721,6 +796,23 @@ class ImageToLatentsInvocation(BaseInvocation):
             latents = vae.config.scaling_factor * latents
             latents = latents.to(dtype=orig_dtype)
 
+        return latents
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        image = context.services.images.get_pil_image(self.image.image_name)
+
+        vae_info = context.services.model_manager.get_model(
+            **self.vae.vae.dict(),
+            context=context,
+        )
+
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+
+        latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
+
         name = f"{context.graph_execution_state_id}__{self.id}"
         latents = latents.to("cpu")
         context.services.latents.save(name, latents)
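
Note on the seamless technique above: configure_model_padding() re-binds each
Conv2d's `_conv_forward` to a version that pre-pads the input in 'circular'
mode on the chosen axes, so activations wrap around the image edge and the
result tiles without a visible seam. A minimal standalone sketch of the same
idea in plain PyTorch (the `make_seamless` helper and the toy usage are
illustrative, not InvokeAI API):

    import torch

    def make_seamless(model: torch.nn.Module, axes: str = "xy") -> None:
        """Re-bind _conv_forward on every Conv2d so padding wraps around."""

        def _conv_forward_seamless(self, input, weight, bias):
            # _reversed_padding_repeated_twice is (left, right, top, bottom),
            # the layout torch.nn.functional.pad expects for the last 2 dims.
            px = list(self._reversed_padding_repeated_twice[:2])
            py = list(self._reversed_padding_repeated_twice[2:])
            x_mode = "circular" if "x" in axes else "constant"
            y_mode = "circular" if "y" in axes else "constant"
            # Pad each axis separately so they can use different modes, then
            # run the convolution itself with zero padding.
            working = torch.nn.functional.pad(input, px + [0, 0], mode=x_mode)
            working = torch.nn.functional.pad(working, [0, 0] + py, mode=y_mode)
            return torch.nn.functional.conv2d(
                working, weight, bias, self.stride, (0, 0), self.dilation, self.groups
            )

        for m in model.modules():
            if isinstance(m, torch.nn.Conv2d):
                m._conv_forward = _conv_forward_seamless.__get__(m, torch.nn.Conv2d)

    conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
    make_seamless(conv, axes="xy")
    out = conv(torch.randn(1, 3, 16, 16))  # left edge now "sees" the right edge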