# Module: invokeai/backend/stable_diffusion/extensions/seamless.py
#
# Integration point from the same change set, in
# invokeai/app/invocations/denoise_latents.py (DenoiseLatentsInvocation):
#
#     from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
#     ...
#     ### seamless
#     if self.unet.seamless_axes:
#         ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))

from __future__ import annotations

from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRACompatibleConv

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase


class SeamlessExt(ExtensionBase):
    """Extension that temporarily patches a model's Conv2d layers for seamless output.

    While the patch is active, every ``torch.nn.Conv2d`` in the model pads its
    input explicitly: with ``"circular"`` mode on each axis listed in
    ``seamless_axes`` and ``"constant"`` mode on the others.
    """

    def __init__(
        self,
        seamless_axes: List[str],
    ):
        """Store the axes to make seamless.

        Args:
            seamless_axes: Axis names; ``"x"`` and/or ``"y"`` enable circular
                padding on the corresponding axis.
        """
        super().__init__()
        self._seamless_axes = seamless_axes

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        """Patch ``unet`` for the duration of the ``with`` block.

        Args:
            unet: The UNet whose convolution layers are patched.
            cached_weights: Not used by this extension; accepted so the
                signature matches the other extensions' ``patch_unet``
                (presumably defined on ``ExtensionBase`` -- confirm).
        """
        with self.static_patch_model(
            model=unet,
            seamless_axes=self._seamless_axes,
        ):
            yield

    @staticmethod
    @contextmanager
    def static_patch_model(
        model: torch.nn.Module,
        seamless_axes: List[str],
    ):
        """Context manager that swaps ``_conv_forward`` on every Conv2d in ``model``.

        On exit (normal or via exception) every change made to the layers is
        undone: the original ``_conv_forward`` is put back and any placeholder
        ``lora_layer`` installed here is reset to ``None``.

        Args:
            model: Any module; all of its ``torch.nn.Conv2d`` submodules are patched.
            seamless_axes: Axis names (``"x"``, ``"y"``). If empty, the model is
                left untouched.
        """
        if not seamless_axes:
            yield
            return

        # Captured by the replacement forward below.
        x_mode = "circular" if "x" in seamless_axes else "constant"
        y_mode = "circular" if "y" in seamless_axes else "constant"

        # override conv_forward
        # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
        def _conv_forward_asymmetric(
            self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
        ):
            # Kept as locals so no scratch attributes are left on the layer
            # after the patch is removed.
            reversed_padding = self._reversed_padding_repeated_twice
            # F.pad consumes the pad tuple starting from the last dimension, so
            # the first pair pads the last dim (width for NCHW input) and the
            # second pair pads the one before it (height).
            padding_x = (reversed_padding[0], reversed_padding[1], 0, 0)
            padding_y = (0, 0, reversed_padding[2], reversed_padding[3])
            working = torch.nn.functional.pad(input, padding_x, mode=x_mode)
            working = torch.nn.functional.pad(working, padding_y, mode=y_mode)
            # Padding was applied explicitly above, so the convolution itself uses none.
            return torch.nn.functional.conv2d(
                working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
            )

        def _no_lora(*x):
            # Stand-in for a missing LoRA layer; contributes nothing to the output.
            return 0

        original_layers: List[Tuple[nn.Conv2d, Callable]] = []
        placeholder_lora_layers: List[LoRACompatibleConv] = []

        try:
            # Materialize the list first so patching never happens mid-traversal.
            conv_layers: List[torch.nn.Conv2d] = [
                module for module in model.modules() if isinstance(module, torch.nn.Conv2d)
            ]

            for layer in conv_layers:
                if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
                    # NOTE(review): presumably LoRACompatibleConv.forward skips
                    # _conv_forward when lora_layer is None, so a no-op layer is
                    # installed to route through the patched forward -- confirm
                    # against the installed diffusers version.
                    layer.lora_layer = _no_lora
                    placeholder_lora_layers.append(layer)
                # Record before mutating so a failure part-way is still rolled back.
                original_layers.append((layer, layer._conv_forward))
                layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)

            yield

        finally:
            for layer, orig_conv_forward in original_layers:
                layer._conv_forward = orig_conv_forward
            for layer in placeholder_lora_layers:
                # Only undo our own placeholder; leave anything installed by
                # other code while the patch was active.
                if layer.lora_layer is _no_lora:
                    layer.lora_layer = None