diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 2787074265..560bc9003c 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -37,9 +37,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
@@ -60,8 +60,12 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
 from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
+from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
+from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
+from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
+from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -498,6 +502,33 @@ class DenoiseLatentsInvocation(BaseInvocation):
             )
         )
 
+    @staticmethod
+    def parse_t2i_adapter_field(
+        exit_stack: ExitStack,
+        context: InvocationContext,
+        t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        if t2i_adapters is None:
+            return
+
+        # Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
+        if isinstance(t2i_adapters, T2IAdapterField):
+            t2i_adapters = [t2i_adapters]
+
+        for t2i_adapter_field in t2i_adapters:
+            ext_manager.add_extension(
+                T2IAdapterExt(
+                    node_context=context,
+                    model_id=t2i_adapter_field.t2i_adapter_model,
+                    image=context.images.get_pil(t2i_adapter_field.image.image_name),
+                    weight=t2i_adapter_field.weight,
+                    begin_step_percent=t2i_adapter_field.begin_step_percent,
+                    end_step_percent=t2i_adapter_field.end_step_percent,
+                    resize_mode=t2i_adapter_field.resize_mode,
+                )
+            )
+
     def prep_ip_adapter_image_prompts(
         self,
         context: InvocationContext,
@@ -707,7 +738,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         else:
             masked_latents = torch.where(mask < 0.5, 0.0, latents)
 
-        return 1 - mask, masked_latents, self.denoise_mask.gradient
+        return mask, masked_latents, self.denoise_mask.gradient
 
     @staticmethod
     def prepare_noise_and_latents(
@@ -765,10 +796,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
         dtype = TorchDevice.choose_torch_dtype()
 
         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
-        latents = latents.to(device=device, dtype=dtype)
-        if noise is not None:
-            noise = noise.to(device=device, dtype=dtype)
-
         _, _, latent_height, latent_width = latents.shape
 
         conditioning_data = self.get_conditioning_data(
@@ -801,21 +828,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
             denoising_end=self.denoising_end,
         )
 
-        denoise_ctx = DenoiseContext(
-            inputs=DenoiseInputs(
-                orig_latents=latents,
-                timesteps=timesteps,
-                init_timestep=init_timestep,
-                noise=noise,
-                seed=seed,
-                scheduler_step_kwargs=scheduler_step_kwargs,
-                conditioning_data=conditioning_data,
-                attention_processor_cls=CustomAttnProcessor2_0,
-            ),
-            unet=None,
-            scheduler=scheduler,
-        )
-
         # get the unet's config so that we can pass the base to sd_step_callback()
         unet_config = context.models.get_config(self.unet.unet.key)
@@ -833,6 +845,40 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
 
+        ### seamless
+        if self.unet.seamless_axes:
+            ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
+
+        ### inpaint
+        mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
+        # NOTE: We used to identify inpainting models by inspecting the shape of the loaded UNet model weights. Now we
+        # use the ModelVariantType config. During testing, there was a report of a user with models that had an
+        # incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
+        # prevalent, we will have to revisit how we initialize the inpainting extensions.
+        if unet_config.variant == ModelVariantType.Inpaint:
+            ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
+        elif mask is not None:
+            ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
+
+        # Initialize context for modular denoise
+        latents = latents.to(device=device, dtype=dtype)
+        if noise is not None:
+            noise = noise.to(device=device, dtype=dtype)
+        denoise_ctx = DenoiseContext(
+            inputs=DenoiseInputs(
+                orig_latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                attention_processor_cls=CustomAttnProcessor2_0,
+            ),
+            unet=None,
+            scheduler=scheduler,
+        )
+
         # context for loading additional models
         with ExitStack() as exit_stack:
             # later should be smth like:
@@ -840,6 +886,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             # ext = extension_field.to_extension(exit_stack, context, ext_manager)
             # ext_manager.add_extension(ext)
             self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+            self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
 
             # ext: t2i/ip adapter
             ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
@@ -871,6 +918,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
 
         mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+        # At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
+        # We invert the mask here for compatibility with the old backend implementation.
+        if mask is not None:
+            mask = 1 - mask
 
         # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
         # below. Investigate whether this is appropriate.
@@ -915,7 +966,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ExitStack() as exit_stack,
             unet_info.model_on_device() as (model_state_dict, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-            set_seamless(unet, self.unet.seamless_axes),  # FIXME
+            SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(
                 unet,
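Worth pausing on the two mask conventions in play above: `prep_inpaint_mask` now returns the mask un-inverted, and only the legacy (non-modular) code path flips it back. A minimal sketch of the convention change (illustrative only, not code from this PR):

```python
import torch

# Convention returned by prep_inpaint_mask after this change, per the comment
# in the legacy path above: 0.0 = leave unchanged, 1.0 = inpaint.
mask = torch.tensor([[0.0, 0.25, 1.0]])

# The old backend implementation expects the opposite convention, hence the
# `mask = 1 - mask` shim before the mask is handed to it.
legacy_mask = 1 - mask  # tensor([[1.0000, 0.7500, 0.0000]])
```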
diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py
index cc8a9c44a3..35b8483f2c 100644
--- a/invokeai/app/invocations/latents_to_image.py
+++ b/invokeai/app/invocations/latents_to_image.py
@@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
 
@@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
 
-        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+        with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             latents = latents.to(vae.device)
             if self.fp32:
diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py
index 440cb4410b..6a6f2ebc49 100644
--- a/invokeai/backend/stable_diffusion/__init__.py
+++ b/invokeai/backend/stable_diffusion/__init__.py
@@ -7,11 +7,9 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (  # noqa: F401
     StableDiffusionGeneratorPipeline,
 )
 from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent  # noqa: F401
-from invokeai.backend.stable_diffusion.seamless import set_seamless  # noqa: F401
 
 __all__ = [
     "PipelineIntermediateState",
     "StableDiffusionGeneratorPipeline",
     "InvokeAIDiffuserComponent",
-    "set_seamless",
 ]
diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py
new file mode 100644
index 0000000000..0079359155
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import einops
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class InpaintExt(ExtensionBase):
+    """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
+    models.
+    """
+
+    def __init__(
+        self,
+        mask: torch.Tensor,
+        is_gradient_mask: bool,
+    ):
+        """Initialize InpaintExt.
+        Args:
+            mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
+                expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
+                inpainted.
+            is_gradient_mask (bool): If True, the mask is interpreted as a gradient mask, meaning that its values
+                range from 0 to 1. If False, the mask is interpreted as a binary mask, meaning that its values are
+                either 0 or 1.
+        """
+        super().__init__()
+        self._mask = mask
+        self._is_gradient_mask = is_gradient_mask
+
+        # Noise used to re-noise the unmasked part of the image.
+        # If noise is provided in the denoise context, it is used directly; otherwise, noise is generated from the
+        # seed.
+        self._noise: Optional[torch.Tensor] = None
+
+    @staticmethod
+    def _is_normal_model(unet: UNet2DConditionModel):
+        """Checks if the provided UNet belongs to a regular model.
+        The `in_channels` of a UNet vary depending on model type:
+        - normal - 4
+        - depth - 5
+        - inpaint - 9
+        """
+        return unet.conv_in.in_channels == 4
+
+    def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        batch_size = latents.size(0)
+        mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
+        if t.dim() == 0:
+            # some schedulers expect t to be one-dimensional.
+            # TODO: file diffusers bug about inconsistency?
+            t = einops.repeat(t, "-> batch", batch=batch_size)
+        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
+        # get very confused about what is happening from step to step when we do that.
+        mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
+        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
+        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
+        mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
+        if self._is_gradient_mask:
+            threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
+            mask_bool = mask < 1 - threshold
+            masked_input = torch.where(mask_bool, latents, mask_latents)
+        else:
+            masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
+        return masked_input
+
+    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
+    def init_tensors(self, ctx: DenoiseContext):
+        if not self._is_normal_model(ctx.unet):
+            raise ValueError(
+                "InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an "
+                "inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be "
+                "fixed by removing and re-adding the model (so that it gets re-probed)."
+            )
+
+        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+        self._noise = ctx.inputs.noise
+        # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
+        # We still need noise for inpainting, so we generate it from the seed here.
+        if self._noise is None:
+            self._noise = torch.randn(
+                ctx.latents.shape,
+                dtype=torch.float32,
+                device="cpu",
+                generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
+            ).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+    # Use negative order to make extensions with default order work with patched latents
+    @callback(ExtensionCallbackType.PRE_STEP, order=-100)
+    def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
+        ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
+
+    # TODO: redo this with preview events rewrite
+    # Use negative order to make extensions with default order work with patched latents
+    @callback(ExtensionCallbackType.POST_STEP, order=-100)
+    def apply_mask_to_step_output(self, ctx: DenoiseContext):
+        timestep = ctx.scheduler.timesteps[-1]
+        if hasattr(ctx.step_output, "denoised"):
+            ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
+        elif hasattr(ctx.step_output, "pred_original_sample"):
+            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
+        else:
+            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
+
+    # Restore unmasked part after the last step is completed
+    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
+    def restore_unmasked(self, ctx: DenoiseContext):
+        if self._is_gradient_mask:
+            ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
+        else:
+            ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)
diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py
new file mode 100644
index 0000000000..6ee8ef6311
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class InpaintModelExt(ExtensionBase):
+    """An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
+    models.
+    """
+
+    def __init__(
+        self,
+        mask: Optional[torch.Tensor],
+        masked_latents: Optional[torch.Tensor],
+        is_gradient_mask: bool,
+    ):
+        """Initialize InpaintModelExt.
+        Args:
+            mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
+                expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
+                inpainted.
+            masked_latents (Optional[torch.Tensor]): Latents of the initial image, with the inpainted region masked
+                out. If mask is provided, then masked_latents must be provided too.
+                Shape: (1, 4, latent_height, latent_width)
+            is_gradient_mask (bool): If True, the mask is interpreted as a gradient mask, meaning that its values
+                range from 0 to 1. If False, the mask is interpreted as a binary mask, meaning that its values are
+                either 0 or 1.
+        """
+        super().__init__()
+        if mask is not None and masked_latents is None:
+            raise ValueError("Source image required for inpaint mask when inpaint model used!")
+
+        # Invert the mask, because inpainting models treat the mask as: 0 - remain the same, 1 - inpaint.
+        self._mask = None
+        if mask is not None:
+            self._mask = 1 - mask
+        self._masked_latents = masked_latents
+        self._is_gradient_mask = is_gradient_mask
+
+    @staticmethod
+    def _is_inpaint_model(unet: UNet2DConditionModel):
+        """Checks if the provided UNet belongs to an inpainting model.
+        The `in_channels` of a UNet vary depending on model type:
+        - normal - 4
+        - depth - 5
+        - inpaint - 9
+        """
+        return unet.conv_in.in_channels == 9
+
+    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
+    def init_tensors(self, ctx: DenoiseContext):
+        if not self._is_inpaint_model(ctx.unet):
+            raise ValueError("InpaintModelExt should be used only on inpaint models!")
+
+        if self._mask is None:
+            self._mask = torch.ones_like(ctx.latents[:1, :1])
+        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+        if self._masked_latents is None:
+            self._masked_latents = torch.zeros_like(ctx.latents[:1])
+        self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+    # Run last so that other extensions work with normal latents
+    @callback(ExtensionCallbackType.PRE_UNET, order=1000)
+    def append_inpaint_layers(self, ctx: DenoiseContext):
+        batch_size = ctx.unet_kwargs.sample.shape[0]
+        b_mask = torch.cat([self._mask] * batch_size)
+        b_masked_latents = torch.cat([self._masked_latents] * batch_size)
+        ctx.unet_kwargs.sample = torch.cat(
+            [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
+            dim=1,
+        )
+
+    # Restore the unmasked part, as the inpainting model can change it slightly
+    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
+    def restore_unmasked(self, ctx: DenoiseContext):
+        if self._is_gradient_mask:
+            ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
+        else:
+            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
diff --git a/invokeai/backend/stable_diffusion/extensions/seamless.py b/invokeai/backend/stable_diffusion/extensions/seamless.py
new file mode 100644
index 0000000000..a96ea6e4d2
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/seamless.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from diffusers import UNet2DConditionModel
+from diffusers.models.lora import LoRACompatibleConv
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+
+
+class SeamlessExt(ExtensionBase):
+    def __init__(
+        self,
+        seamless_axes: List[str],
+    ):
+        super().__init__()
+        self._seamless_axes = seamless_axes
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        with self.static_patch_model(
+            model=unet,
+            seamless_axes=self._seamless_axes,
+        ):
+            yield
+
+    @staticmethod
+    @contextmanager
+    def static_patch_model(
+        model: torch.nn.Module,
+        seamless_axes: List[str],
+    ):
+        if not seamless_axes:
+            yield
+            return
+
+        x_mode = "circular" if "x" in seamless_axes else "constant"
+        y_mode = "circular" if "y" in seamless_axes else "constant"
+
+        # override conv_forward
+        # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
+        def _conv_forward_asymmetric(
+            self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+        ):
+            self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
+            self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
+            working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
+            working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
+            return torch.nn.functional.conv2d(
+                working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
+            )
+
+        original_layers: List[Tuple[nn.Conv2d, Callable]] = []
+        try:
+            for layer in model.modules():
+                if not isinstance(layer, torch.nn.Conv2d):
+                    continue
+
+                if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
+                    layer.lora_layer = lambda *x: 0
+
+                original_layers.append((layer, layer._conv_forward))
+                layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
+
+            yield
+
+        finally:
+            for layer, orig_conv_forward in original_layers:
+                layer._conv_forward = orig_conv_forward
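The `_conv_forward_asymmetric` override above works because `mode="circular"` padding wraps values around the padded axis, so a convolution at the left edge sees the right edge and the output tiles seamlessly. A quick standalone demonstration of that padding mode:

```python
import torch
import torch.nn.functional as F

x = torch.arange(4.0).reshape(1, 1, 1, 4)         # one row: [0, 1, 2, 3]
padded = F.pad(x, (1, 1, 0, 0), mode="circular")  # one column of padding per side
print(padded.flatten().tolist())                  # [3.0, 0.0, 1.0, 2.0, 3.0, 0.0]
```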
diff --git a/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
new file mode 100644
index 0000000000..5c290ea4e7
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import torch
+from diffusers import T2IAdapter
+from PIL.Image import Image
+
+from invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+    from invokeai.app.invocations.model import ModelIdentifierField
+    from invokeai.app.services.shared.invocation_context import InvocationContext
+    from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class T2IAdapterExt(ExtensionBase):
+    def __init__(
+        self,
+        node_context: InvocationContext,
+        model_id: ModelIdentifierField,
+        image: Image,
+        weight: Union[float, List[float]],
+        begin_step_percent: float,
+        end_step_percent: float,
+        resize_mode: CONTROLNET_RESIZE_VALUES,
+    ):
+        super().__init__()
+        self._node_context = node_context
+        self._model_id = model_id
+        self._image = image
+        self._weight = weight
+        self._resize_mode = resize_mode
+        self._begin_step_percent = begin_step_percent
+        self._end_step_percent = end_step_percent
+
+        self._adapter_state: Optional[List[torch.Tensor]] = None
+
+        # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
+        model_config = self._node_context.models.get_config(self._model_id.key)
+        if model_config.base == BaseModelType.StableDiffusion1:
+            self._max_unet_downscale = 8
+        elif model_config.base == BaseModelType.StableDiffusionXL:
+            self._max_unet_downscale = 4
+        else:
+            raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.")
+
+    @callback(ExtensionCallbackType.SETUP)
+    def setup(self, ctx: DenoiseContext):
+        t2i_model: T2IAdapter
+        with self._node_context.models.load(self._model_id) as t2i_model:
+            _, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
+
+            self._adapter_state = self._run_model(
+                model=t2i_model,
+                image=self._image,
+                latents_height=latents_height,
+                latents_width=latents_width,
+            )
+
+    def _run_model(
+        self,
+        model: T2IAdapter,
+        image: Image,
+        latents_height: int,
+        latents_width: int,
+    ):
+        # Resize the T2I-Adapter input image.
+        # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
+        # result will match the latent image's dimensions after max_unet_downscale is applied.
+        input_height = latents_height // self._max_unet_downscale * model.total_downscale_factor
+        input_width = latents_width // self._max_unet_downscale * model.total_downscale_factor
+
+        # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
+        # a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
+        # T2I-Adapter model.
+        #
+        # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
+        # of the same requirements (e.g. preserving binary masks during resize).
+        t2i_image = prepare_control_image(
+            image=image,
+            do_classifier_free_guidance=False,
+            width=input_width,
+            height=input_height,
+            num_channels=model.config["in_channels"],
+            device=model.device,
+            dtype=model.dtype,
+            resize_mode=self._resize_mode,
+        )
+
+        return model(t2i_image)
+
+    @callback(ExtensionCallbackType.PRE_UNET)
+    def pre_unet_step(self, ctx: DenoiseContext):
+        # skip if model not active in current step
+        total_steps = len(ctx.inputs.timesteps)
+        first_step = math.floor(self._begin_step_percent * total_steps)
+        last_step = math.ceil(self._end_step_percent * total_steps)
+        if ctx.step_index < first_step or ctx.step_index > last_step:
+            return
+
+        weight = self._weight
+        if isinstance(weight, list):
+            weight = weight[ctx.step_index]
+
+        adapter_state = self._adapter_state
+        if ctx.conditioning_mode == ConditioningMode.Both:
+            adapter_state = [torch.cat([v] * 2) for v in adapter_state]
+
+        if ctx.unet_kwargs.down_intrablock_additional_residuals is None:
+            ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state]
+        else:
+            for i, value in enumerate(adapter_state):
+                ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight
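The sizing arithmetic in `_run_model` above, in concrete numbers. `max_unet_downscale` comes from the constructor (8 for SD1.5); the `total_downscale_factor` of 64 used here is an assumed value for a typical SD1.5 T2I-Adapter, purely for illustration:

```python
# A 512x768 image has 64x96 latents (the VAE downscales by 8).
latents_height, latents_width = 64, 96
max_unet_downscale = 8       # SD1.5, per T2IAdapterExt.__init__
total_downscale_factor = 64  # assumed adapter property, for illustration only

# Snap the latent dims to a multiple of max_unet_downscale, then scale up so the
# adapter's deepest features spatially match the UNet's deepest feature map:
#   input / total_downscale_factor == latents / max_unet_downscale
input_height = latents_height // max_unet_downscale * total_downscale_factor  # 512
input_width = latents_width // max_unet_downscale * total_downscale_factor    # 768
```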
diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py
deleted file mode 100644
index 23ed978c6d..0000000000
--- a/invokeai/backend/stable_diffusion/seamless.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from contextlib import contextmanager
-from typing import Callable, List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
-from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
-from diffusers.models.lora import LoRACompatibleConv
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-
-
-@contextmanager
-def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
-    if not seamless_axes:
-        yield
-        return
-
-    # override conv_forward
-    # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
-    def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
-        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
-        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
-        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
-        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
-        return torch.nn.functional.conv2d(
-            working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
-        )
-
-    original_layers: List[Tuple[nn.Conv2d, Callable]] = []
-
-    try:
-        x_mode = "circular" if "x" in seamless_axes else "constant"
-        y_mode = "circular" if "y" in seamless_axes else "constant"
-
-        conv_layers: List[torch.nn.Conv2d] = []
-
-        for module in model.modules():
-            if isinstance(module, torch.nn.Conv2d):
-                conv_layers.append(module)
-
-        for layer in conv_layers:
-            if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
-                layer.lora_layer = lambda *x: 0
-            original_layers.append((layer, layer._conv_forward))
-            layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
-
-        yield
-
-    finally:
-        for layer, orig_conv_forward in original_layers:
-            layer._conv_forward = orig_conv_forward
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 659df78d9b..3300f7c7fa 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -31,7 +31,8 @@
     "deleteBoard": "Delete Board",
     "deleteBoardAndImages": "Delete Board and Images",
     "deleteBoardOnly": "Delete Board Only",
-    "deletedBoardsCannotbeRestored": "Deleted boards cannot be restored",
+    "deletedBoardsCannotbeRestored": "Deleted boards cannot be restored. Selecting 'Delete Board Only' will move images to an uncategorized state.",
+    "deletedPrivateBoardsCannotbeRestored": "Deleted boards cannot be restored. Selecting 'Delete Board Only' will move images to a private uncategorized state for the image's creator.",
     "hideBoards": "Hide Boards",
     "loading": "Loading...",
     "menuItemAutoAdd": "Auto-add to this Board",
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener.ts
index 1581da9b37..23d3cbc9af 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener.ts
@@ -10,32 +10,32 @@ import {
 import { boardsApi } from 'services/api/endpoints/boards';
 import { imagesApi } from 'services/api/endpoints/images';
 
+// Type inference doesn't work for this if you inline it in the listener for some reason
+const matchAnyBoardDeleted = isAnyOf(
+  imagesApi.endpoints.deleteBoard.matchFulfilled,
+  imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
+);
+
 export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
   /**
    * The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
    * a board, or change a the archived board visibility flag, we may need to reset the auto-add board.
    */
   startAppListening({
-    matcher: isAnyOf(
-      // If a board is deleted, we'll need to reset the auto-add board
-      imagesApi.endpoints.deleteBoard.matchFulfilled,
-      imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
-    ),
+    matcher: matchAnyBoardDeleted,
     effect: async (action, { dispatch, getState }) => {
       const state = getState();
-      const queryArgs = selectListBoardsQueryArgs(state);
-      const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
+      const deletedBoardId = action.meta.arg.originalArgs;
       const { autoAddBoardId, selectedBoardId } = state.gallery;
 
-      if (!queryResult.data) {
-        return;
-      }
-
-      if (!queryResult.data.find((board) => board.board_id === selectedBoardId)) {
+      // If the deleted board was currently selected, we should reset the selected board to uncategorized
+      if (deletedBoardId === selectedBoardId) {
         dispatch(boardIdSelected({ boardId: 'none' }));
         dispatch(galleryViewChanged('images'));
       }
-      if (!queryResult.data.find((board) => board.board_id === autoAddBoardId)) {
+
+      // If the deleted board was selected for auto-add, we should reset the auto-add board to uncategorized
+      if (deletedBoardId === autoAddBoardId) {
         dispatch(autoAddBoardIdChanged('none'));
       }
     },
@@ -46,14 +46,8 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
     matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
     effect: async (action, { dispatch, getState }) => {
       const state = getState();
-      const queryArgs = selectListBoardsQueryArgs(state);
-      const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
       const { shouldShowArchivedBoards } = state.gallery;
 
-      if (!queryResult.data) {
-        return;
-      }
-
       const wasArchived = action.meta.arg.originalArgs.changes.archived === true;
 
       if (wasArchived && !shouldShowArchivedBoards) {
@@ -71,7 +65,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
       const shouldShowArchivedBoards = action.payload;
 
       // We only need to take action if we have just hidden archived boards.
-      if (!shouldShowArchivedBoards) {
+      if (shouldShowArchivedBoards) {
         return;
       }
@@ -86,14 +80,16 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
       // Handle the case where selected board is archived
       const selectedBoard = queryResult.data.find((b) => b.board_id === selectedBoardId);
-      if (selectedBoard && selectedBoard.archived) {
+      if (!selectedBoard || selectedBoard.archived) {
+        // If we can't find the selected board or it's archived, we should reset the selected board to uncategorized
         dispatch(boardIdSelected({ boardId: 'none' }));
         dispatch(galleryViewChanged('images'));
       }
 
       // Handle the case where auto-add board is archived
       const autoAddBoard = queryResult.data.find((b) => b.board_id === autoAddBoardId);
-      if (autoAddBoard && autoAddBoard.archived) {
+      if (!autoAddBoard || autoAddBoard.archived) {
+        // If we can't find the auto-add board or it's archived, we should reset the auto-add board to uncategorized
         dispatch(autoAddBoardIdChanged('none'));
       }
     },
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx
index 377636d0d0..3707c24440 100644
--- a/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx
@@ -120,7 +120,11 @@ const DeleteBoardModal = (props: Props) => {
               bottomMessage={t('boards.bottomMessage')}
             />
           )}
-          <Text>{t('boards.deletedBoardsCannotbeRestored')}</Text>
+          <Text>
+            {boardToDelete.is_private
+              ? t('boards.deletedPrivateBoardsCannotbeRestored')
+              : t('boards.deletedBoardsCannotbeRestored')}
+          </Text>
           <Text>{canRestoreDeletedImagesFromBin ? t('gallery.deleteImageBin') : t('gallery.deleteImagePermanent')}</Text>