diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 2787074265..560bc9003c 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -37,9 +37,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
@@ -60,8 +60,12 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
+from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
+from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
+from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
+from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -498,6 +502,33 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
)
+ @staticmethod
+ def parse_t2i_adapter_field(
+ exit_stack: ExitStack,
+ context: InvocationContext,
+ t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
+ ext_manager: ExtensionsManager,
+ ) -> None:
+ if t2i_adapters is None:
+ return
+
+ # Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
+ if isinstance(t2i_adapters, T2IAdapterField):
+ t2i_adapters = [t2i_adapters]
+
+ for t2i_adapter_field in t2i_adapters:
+ ext_manager.add_extension(
+ T2IAdapterExt(
+ node_context=context,
+ model_id=t2i_adapter_field.t2i_adapter_model,
+ image=context.images.get_pil(t2i_adapter_field.image.image_name),
+ weight=t2i_adapter_field.weight,
+ begin_step_percent=t2i_adapter_field.begin_step_percent,
+ end_step_percent=t2i_adapter_field.end_step_percent,
+ resize_mode=t2i_adapter_field.resize_mode,
+ )
+ )
+
def prep_ip_adapter_image_prompts(
self,
context: InvocationContext,
@@ -707,7 +738,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
else:
masked_latents = torch.where(mask < 0.5, 0.0, latents)
- return 1 - mask, masked_latents, self.denoise_mask.gradient
+ return mask, masked_latents, self.denoise_mask.gradient
@staticmethod
def prepare_noise_and_latents(
@@ -765,10 +796,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype = TorchDevice.choose_torch_dtype()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
- latents = latents.to(device=device, dtype=dtype)
- if noise is not None:
- noise = noise.to(device=device, dtype=dtype)
-
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
@@ -801,21 +828,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)
- denoise_ctx = DenoiseContext(
- inputs=DenoiseInputs(
- orig_latents=latents,
- timesteps=timesteps,
- init_timestep=init_timestep,
- noise=noise,
- seed=seed,
- scheduler_step_kwargs=scheduler_step_kwargs,
- conditioning_data=conditioning_data,
- attention_processor_cls=CustomAttnProcessor2_0,
- ),
- unet=None,
- scheduler=scheduler,
- )
-
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
@@ -833,6 +845,40 @@ class DenoiseLatentsInvocation(BaseInvocation):
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
+ ### seamless
+ if self.unet.seamless_axes:
+ ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
+
+ ### inpaint
+ mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
+ # NOTE: We used to identify inpainting models by inspecting the shape of the loaded UNet model weights. Now we
+ # use the ModelVariantType config. During testing, there was a report of a user with models that had an
+ # incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
+ # prevalent, we will have to revisit how we initialize the inpainting extensions.
+ if unet_config.variant == ModelVariantType.Inpaint:
+ ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
+ elif mask is not None:
+ ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
+
+ # Initialize context for modular denoise
+ latents = latents.to(device=device, dtype=dtype)
+ if noise is not None:
+ noise = noise.to(device=device, dtype=dtype)
+ denoise_ctx = DenoiseContext(
+ inputs=DenoiseInputs(
+ orig_latents=latents,
+ timesteps=timesteps,
+ init_timestep=init_timestep,
+ noise=noise,
+ seed=seed,
+ scheduler_step_kwargs=scheduler_step_kwargs,
+ conditioning_data=conditioning_data,
+ attention_processor_cls=CustomAttnProcessor2_0,
+ ),
+ unet=None,
+ scheduler=scheduler,
+ )
+
# context for loading additional models
with ExitStack() as exit_stack:
# later should be smth like:
@@ -840,6 +886,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
# ext_manager.add_extension(ext)
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+ self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
@@ -871,6 +918,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+ # At this point, the mask ranges from 0 (inpaint) to 1 (leave unchanged).
+ # We invert the mask here for compatibility with the old backend implementation.
+ if mask is not None:
+ mask = 1 - mask
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
@@ -915,7 +966,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
- set_seamless(unet, self.unet.seamless_axes), # FIXME
+ SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py
index cc8a9c44a3..35b8483f2c 100644
--- a/invokeai/app/invocations/latents_to_image.py
+++ b/invokeai/app/invocations/latents_to_image.py
@@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
- with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+ with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:
diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py
index 440cb4410b..6a6f2ebc49 100644
--- a/invokeai/backend/stable_diffusion/__init__.py
+++ b/invokeai/backend/stable_diffusion/__init__.py
@@ -7,11 +7,9 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( # noqa: F401
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
-from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401
__all__ = [
"PipelineIntermediateState",
"StableDiffusionGeneratorPipeline",
"InvokeAIDiffuserComponent",
- "set_seamless",
]
diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py
new file mode 100644
index 0000000000..0079359155
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import einops
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+ from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class InpaintExt(ExtensionBase):
+ """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
+ models.
+ """
+
+ def __init__(
+ self,
+ mask: torch.Tensor,
+ is_gradient_mask: bool,
+ ):
+ """Initialize InpaintExt.
+ Args:
+ mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
+ expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
+ inpainted.
+ is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
+ from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
+ 1.
+ """
+ super().__init__()
+ self._mask = mask
+ self._is_gradient_mask = is_gradient_mask
+
+ # Noise used to re-noise the unmasked part of the image.
+ # If noise was provided in the denoise context it will be used; otherwise noise is generated from the seed.
+ self._noise: Optional[torch.Tensor] = None
+
+ @staticmethod
+ def _is_normal_model(unet: UNet2DConditionModel):
+ """Checks if the provided UNet belongs to a regular model.
+ The `in_channels` of a UNet vary depending on model type:
+ - normal - 4
+ - depth - 5
+ - inpaint - 9
+ """
+ return unet.conv_in.in_channels == 4
+
+ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ batch_size = latents.size(0)
+ mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
+ if t.dim() == 0:
+ # some schedulers expect t to be one-dimensional.
+ # TODO: file diffusers bug about inconsistency?
+ t = einops.repeat(t, "-> batch", batch=batch_size)
+ # Noise shouldn't be re-randomized between steps here. The multistep schedulers
+ # get very confused about what is happening from step to step when we do that.
+ mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
+ # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
+ # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
+ mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
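+ # Blend the in-progress latents with the original latents noised to the current timestep: where the mask is 1
+ # the original (noised) content is kept, and where the mask is 0 the denoised (inpainted) content is kept.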
+ if self._is_gradient_mask:
+ threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
+ mask_bool = mask < 1 - threshold
+ masked_input = torch.where(mask_bool, latents, mask_latents)
+ else:
+ masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
+ return masked_input
+
+ @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
+ def init_tensors(self, ctx: DenoiseContext):
+ if not self._is_normal_model(ctx.unet):
+ raise ValueError(
+ "InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an "
+ "inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be "
+ "fixed by removing and re-adding the model (so that it gets re-probed)."
+ )
+
+ self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+ self._noise = ctx.inputs.noise
+ # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
+ # We still need noise for inpainting, so we generate it from the seed here.
+ if self._noise is None:
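+ # Generate the noise on the CPU with a seeded generator for reproducibility, then move it to the
+ # latents' device and dtype.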
+ self._noise = torch.randn(
+ ctx.latents.shape,
+ dtype=torch.float32,
+ device="cpu",
+ generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
+ ).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+ # Use negative order to make extensions with default order work with patched latents
+ @callback(ExtensionCallbackType.PRE_STEP, order=-100)
+ def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
+ ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
+
+ # TODO: redo this with preview events rewrite
+ # Use negative order to make extensions with default order work with patched latents
+ @callback(ExtensionCallbackType.POST_STEP, order=-100)
+ def apply_mask_to_step_output(self, ctx: DenoiseContext):
+ timestep = ctx.scheduler.timesteps[-1]
+ if hasattr(ctx.step_output, "denoised"):
+ ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
+ elif hasattr(ctx.step_output, "pred_original_sample"):
+ ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
+ else:
+ ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
+
+ # Restore unmasked part after the last step is completed
+ @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
+ def restore_unmasked(self, ctx: DenoiseContext):
+ if self._is_gradient_mask:
+ ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
+ else:
+ ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)
diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py
new file mode 100644
index 0000000000..6ee8ef6311
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+ from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class InpaintModelExt(ExtensionBase):
+ """An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
+ models.
+ """
+
+ def __init__(
+ self,
+ mask: Optional[torch.Tensor],
+ masked_latents: Optional[torch.Tensor],
+ is_gradient_mask: bool,
+ ):
+ """Initialize InpaintModelExt.
+ Args:
+ mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
+ expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
+ inpainted.
+ masked_latents (Optional[torch.Tensor]): Latents of the initial image, with the area to be inpainted blacked
+ out. Must be provided whenever `mask` is provided. Shape: (1, 4, latent_height, latent_width)
+ is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
+ from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
+ 1.
+ """
+ super().__init__()
+ if mask is not None and masked_latents is None:
+ raise ValueError("Source image required for inpaint mask when inpaint model used!")
+
+ # Invert the mask, because inpaint models treat the mask as: 0 - keep unchanged, 1 - inpaint
+ self._mask = None
+ if mask is not None:
+ self._mask = 1 - mask
+ self._masked_latents = masked_latents
+ self._is_gradient_mask = is_gradient_mask
+
+ @staticmethod
+ def _is_inpaint_model(unet: UNet2DConditionModel):
+ """Checks if the provided UNet belongs to a regular model.
+ The `in_channels` of a UNet vary depending on model type:
+ - normal - 4
+ - depth - 5
+ - inpaint - 9
+ """
+ return unet.conv_in.in_channels == 9
+
+ @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
+ def init_tensors(self, ctx: DenoiseContext):
+ if not self._is_inpaint_model(ctx.unet):
+ raise ValueError("InpaintModelExt should be used only on inpaint models!")
+
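+ # If no mask was provided, inpaint the whole image: an all-ones (already inverted) mask with fully
+ # zeroed masked-image latents.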
+ if self._mask is None:
+ self._mask = torch.ones_like(ctx.latents[:1, :1])
+ self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+ if self._masked_latents is None:
+ self._masked_latents = torch.zeros_like(ctx.latents[:1])
+ self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+ # Run last so that other extensions work with the normal (unexpanded) latents
+ @callback(ExtensionCallbackType.PRE_UNET, order=1000)
+ def append_inpaint_layers(self, ctx: DenoiseContext):
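+ # Inpainting UNets expect 9 input channels: 4 latent channels, 1 mask channel, and 4 masked-image latent channels.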
+ batch_size = ctx.unet_kwargs.sample.shape[0]
+ b_mask = torch.cat([self._mask] * batch_size)
+ b_masked_latents = torch.cat([self._masked_latents] * batch_size)
+ ctx.unet_kwargs.sample = torch.cat(
+ [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
+ dim=1,
+ )
+
+ # Restore the unmasked part, as the inpaint model can change it slightly
+ @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
+ def restore_unmasked(self, ctx: DenoiseContext):
+ if self._is_gradient_mask:
+ ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
+ else:
+ ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
diff --git a/invokeai/backend/stable_diffusion/extensions/seamless.py b/invokeai/backend/stable_diffusion/extensions/seamless.py
new file mode 100644
index 0000000000..a96ea6e4d2
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/seamless.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from diffusers import UNet2DConditionModel
+from diffusers.models.lora import LoRACompatibleConv
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+
+
+class SeamlessExt(ExtensionBase):
+ def __init__(
+ self,
+ seamless_axes: List[str],
+ ):
+ super().__init__()
+ self._seamless_axes = seamless_axes
+
+ @contextmanager
+ def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+ with self.static_patch_model(
+ model=unet,
+ seamless_axes=self._seamless_axes,
+ ):
+ yield
+
+ @staticmethod
+ @contextmanager
+ def static_patch_model(
+ model: torch.nn.Module,
+ seamless_axes: List[str],
+ ):
+ if not seamless_axes:
+ yield
+ return
+
+ x_mode = "circular" if "x" in seamless_axes else "constant"
+ y_mode = "circular" if "y" in seamless_axes else "constant"
+
+ # override conv_forward
+ # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
+ def _conv_forward_asymmetric(
+ self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+ ):
+ self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
+ self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
+ working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
+ working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
+ return torch.nn.functional.conv2d(
+ working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
+ )
+
+ original_layers: List[Tuple[nn.Conv2d, Callable]] = []
+ try:
+ for layer in model.modules():
+ if not isinstance(layer, torch.nn.Conv2d):
+ continue
+
+ if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
+ layer.lora_layer = lambda *x: 0
+ original_layers.append((layer, layer._conv_forward))
+ layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
+
+ yield
+
+ finally:
+ for layer, orig_conv_forward in original_layers:
+ layer._conv_forward = orig_conv_forward
diff --git a/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
new file mode 100644
index 0000000000..5c290ea4e7
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import torch
+from diffusers import T2IAdapter
+from PIL.Image import Image
+
+from invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+ from invokeai.app.invocations.model import ModelIdentifierField
+ from invokeai.app.services.shared.invocation_context import InvocationContext
+ from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
+ from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class T2IAdapterExt(ExtensionBase):
+ def __init__(
+ self,
+ node_context: InvocationContext,
+ model_id: ModelIdentifierField,
+ image: Image,
+ weight: Union[float, List[float]],
+ begin_step_percent: float,
+ end_step_percent: float,
+ resize_mode: CONTROLNET_RESIZE_VALUES,
+ ):
+ super().__init__()
+ self._node_context = node_context
+ self._model_id = model_id
+ self._image = image
+ self._weight = weight
+ self._resize_mode = resize_mode
+ self._begin_step_percent = begin_step_percent
+ self._end_step_percent = end_step_percent
+
+ self._adapter_state: Optional[List[torch.Tensor]] = None
+
+ # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
+ model_config = self._node_context.models.get_config(self._model_id.key)
+ if model_config.base == BaseModelType.StableDiffusion1:
+ self._max_unet_downscale = 8
+ elif model_config.base == BaseModelType.StableDiffusionXL:
+ self._max_unet_downscale = 4
+ else:
+ raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.")
+
+ @callback(ExtensionCallbackType.SETUP)
+ def setup(self, ctx: DenoiseContext):
+ t2i_model: T2IAdapter
+ with self._node_context.models.load(self._model_id) as t2i_model:
+ _, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
+
+ self._adapter_state = self._run_model(
+ model=t2i_model,
+ image=self._image,
+ latents_height=latents_height,
+ latents_width=latents_width,
+ )
+
+ def _run_model(
+ self,
+ model: T2IAdapter,
+ image: Image,
+ latents_height: int,
+ latents_width: int,
+ ):
+ # Resize the T2I-Adapter input image.
+ # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
+ # result will match the latent image's dimensions after max_unet_downscale is applied.
+ input_height = latents_height // self._max_unet_downscale * model.total_downscale_factor
+ input_width = latents_width // self._max_unet_downscale * model.total_downscale_factor
+
+ # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
+ # a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
+ # T2I-Adapter model.
+ #
+ # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
+ # of the same requirements (e.g. preserving binary masks during resize).
+ t2i_image = prepare_control_image(
+ image=image,
+ do_classifier_free_guidance=False,
+ width=input_width,
+ height=input_height,
+ num_channels=model.config["in_channels"],
+ device=model.device,
+ dtype=model.dtype,
+ resize_mode=self._resize_mode,
+ )
+
+ return model(t2i_image)
+
+ @callback(ExtensionCallbackType.PRE_UNET)
+ def pre_unet_step(self, ctx: DenoiseContext):
+ # skip if model not active in current step
+ total_steps = len(ctx.inputs.timesteps)
+ first_step = math.floor(self._begin_step_percent * total_steps)
+ last_step = math.ceil(self._end_step_percent * total_steps)
+ if ctx.step_index < first_step or ctx.step_index > last_step:
+ return
+
+ weight = self._weight
+ if isinstance(weight, list):
+ weight = weight[ctx.step_index]
+
+ adapter_state = self._adapter_state
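+ # When both conditioning branches run in a single batched UNet call, duplicate the adapter residuals so they
+ # are applied to both the unconditional and conditional halves of the batch.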
+ if ctx.conditioning_mode == ConditioningMode.Both:
+ adapter_state = [torch.cat([v] * 2) for v in adapter_state]
+
+ if ctx.unet_kwargs.down_intrablock_additional_residuals is None:
+ ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state]
+ else:
+ for i, value in enumerate(adapter_state):
+ ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight
diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py
deleted file mode 100644
index 23ed978c6d..0000000000
--- a/invokeai/backend/stable_diffusion/seamless.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from contextlib import contextmanager
-from typing import Callable, List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
-from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
-from diffusers.models.lora import LoRACompatibleConv
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-
-
-@contextmanager
-def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
- if not seamless_axes:
- yield
- return
-
- # override conv_forward
- # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
- def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
- self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
- self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
- working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
- working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
- return torch.nn.functional.conv2d(
- working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
- )
-
- original_layers: List[Tuple[nn.Conv2d, Callable]] = []
-
- try:
- x_mode = "circular" if "x" in seamless_axes else "constant"
- y_mode = "circular" if "y" in seamless_axes else "constant"
-
- conv_layers: List[torch.nn.Conv2d] = []
-
- for module in model.modules():
- if isinstance(module, torch.nn.Conv2d):
- conv_layers.append(module)
-
- for layer in conv_layers:
- if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
- layer.lora_layer = lambda *x: 0
- original_layers.append((layer, layer._conv_forward))
- layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
-
- yield
-
- finally:
- for layer, orig_conv_forward in original_layers:
- layer._conv_forward = orig_conv_forward
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 659df78d9b..3300f7c7fa 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -31,7 +31,8 @@
"deleteBoard": "Delete Board",
"deleteBoardAndImages": "Delete Board and Images",
"deleteBoardOnly": "Delete Board Only",
- "deletedBoardsCannotbeRestored": "Deleted boards cannot be restored",
+ "deletedBoardsCannotbeRestored": "Deleted boards cannot be restored. Selecting 'Delete Board Only' will move images to an uncategorized state.",
+ "deletedPrivateBoardsCannotbeRestored": "Deleted boards cannot be restored. Selecting 'Delete Board Only' will move images to a private uncategorized state for the image's creator.",
"hideBoards": "Hide Boards",
"loading": "Loading...",
"menuItemAutoAdd": "Auto-add to this Board",
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener.ts
index 1581da9b37..23d3cbc9af 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener.ts
@@ -10,32 +10,32 @@ import {
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
+// Type inference doesn't work for this if you inline it in the listener for some reason
+const matchAnyBoardDeleted = isAnyOf(
+ imagesApi.endpoints.deleteBoard.matchFulfilled,
+ imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
+);
+
export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
/**
* The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
* a board, or change the archived board visibility flag, we may need to reset the auto-add board.
*/
startAppListening({
- matcher: isAnyOf(
- // If a board is deleted, we'll need to reset the auto-add board
- imagesApi.endpoints.deleteBoard.matchFulfilled,
- imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
- ),
+ matcher: matchAnyBoardDeleted,
effect: async (action, { dispatch, getState }) => {
const state = getState();
- const queryArgs = selectListBoardsQueryArgs(state);
- const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
+ const deletedBoardId = action.meta.arg.originalArgs;
const { autoAddBoardId, selectedBoardId } = state.gallery;
- if (!queryResult.data) {
- return;
- }
-
- if (!queryResult.data.find((board) => board.board_id === selectedBoardId)) {
+ // If the deleted board was currently selected, we should reset the selected board to uncategorized
+ if (deletedBoardId === selectedBoardId) {
dispatch(boardIdSelected({ boardId: 'none' }));
dispatch(galleryViewChanged('images'));
}
- if (!queryResult.data.find((board) => board.board_id === autoAddBoardId)) {
+
+ // If the deleted board was selected for auto-add, we should reset the auto-add board to uncategorized
+ if (deletedBoardId === autoAddBoardId) {
dispatch(autoAddBoardIdChanged('none'));
}
},
@@ -46,14 +46,8 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
const state = getState();
- const queryArgs = selectListBoardsQueryArgs(state);
- const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
const { shouldShowArchivedBoards } = state.gallery;
- if (!queryResult.data) {
- return;
- }
-
const wasArchived = action.meta.arg.originalArgs.changes.archived === true;
if (wasArchived && !shouldShowArchivedBoards) {
@@ -71,7 +65,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
const shouldShowArchivedBoards = action.payload;
// We only need to take action if we have just hidden archived boards.
- if (!shouldShowArchivedBoards) {
+ if (shouldShowArchivedBoards) {
return;
}
@@ -86,14 +80,16 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
// Handle the case where selected board is archived
const selectedBoard = queryResult.data.find((b) => b.board_id === selectedBoardId);
- if (selectedBoard && selectedBoard.archived) {
+ if (!selectedBoard || selectedBoard.archived) {
+ // If we can't find the selected board or it's archived, we should reset the selected board to uncategorized
dispatch(boardIdSelected({ boardId: 'none' }));
dispatch(galleryViewChanged('images'));
}
// Handle the case where auto-add board is archived
const autoAddBoard = queryResult.data.find((b) => b.board_id === autoAddBoardId);
- if (autoAddBoard && autoAddBoard.archived) {
+ if (!autoAddBoard || autoAddBoard.archived) {
+ // If we can't find the auto-add board or it's archived, we should reset the auto-add board to uncategorized
dispatch(autoAddBoardIdChanged('none'));
}
},
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx
index 377636d0d0..3707c24440 100644
--- a/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx
@@ -120,7 +120,11 @@ const DeleteBoardModal = (props: Props) => {
bottomMessage={t('boards.bottomMessage')}
/>
)}
- {t('boards.deletedBoardsCannotbeRestored')}
+
+ {boardToDelete.is_private
+ ? t('boards.deletedPrivateBoardsCannotbeRestored')
+ : t('boards.deletedBoardsCannotbeRestored')}
+
{canRestoreDeletedImagesFromBin ? t('gallery.deleteImageBin') : t('gallery.deleteImagePermanent')}