fix: fix seamless

blessedcoolant 2024-05-10 20:18:54 +05:30 committed by psychedelicious
parent 63e62c5720
commit 6c9fb617dc


@@ -1,89 +1,53 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import Callable, List, Union
+from typing import Callable, List, Optional, Tuple, Union
 
+import torch
 import torch.nn as nn
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
+from diffusers.models.lora import LoRACompatibleConv
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 
 
-def _conv_forward_asymmetric(self, input, weight, bias):
-    """
-    Patch for Conv2d._conv_forward that supports asymmetric padding
-    """
-    working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
-    working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
-    return nn.functional.conv2d(
-        working,
-        weight,
-        bias,
-        self.stride,
-        nn.modules.utils._pair(0),
-        self.dilation,
-        self.groups,
-    )
-
-
 @contextmanager
 def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
     if not seamless_axes:
         yield
         return
 
-    # Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
-    to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
+    # override conv_forward
+    # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
+    def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
+        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
+        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
+        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
+        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
+        return torch.nn.functional.conv2d(
+            working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
+        )
+
+    original_layers: List[Tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
+
     try:
-        # Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence
-        skipped_layers = 1
-
-        for m_name, m in model.named_modules():
-            if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-                continue
-
-            if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
-                # down_blocks.1.resnets.1.conv1
-                _, block_num, _, resnet_num, submodule_name = m_name.split(".")
-                block_num = int(block_num)
-                resnet_num = int(resnet_num)
-
-                if block_num >= len(model.down_blocks) - skipped_layers:
-                    continue
-
-                # Skip the second resnet (could be configurable)
-                if resnet_num > 0:
-                    continue
-
-                # Skip Conv2d layers (could be configurable)
-                if submodule_name == "conv2":
-                    continue
-
-            m.asymmetric_padding_mode = {}
-            m.asymmetric_padding = {}
-            m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
-            m.asymmetric_padding["x"] = (
-                m._reversed_padding_repeated_twice[0],
-                m._reversed_padding_repeated_twice[1],
-                0,
-                0,
-            )
-            m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
-            m.asymmetric_padding["y"] = (
-                0,
-                0,
-                m._reversed_padding_repeated_twice[2],
-                m._reversed_padding_repeated_twice[3],
-            )
-
-            to_restore.append((m, m._conv_forward))
-            m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
+        x_mode = "circular" if "x" in seamless_axes else "constant"
+        y_mode = "circular" if "y" in seamless_axes else "constant"
+
+        conv_layers: List[torch.nn.Conv2d] = []
+
+        for module in model.modules():
+            if isinstance(module, torch.nn.Conv2d):
+                conv_layers.append(module)
+
+        for layer in conv_layers:
+            if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
+                layer.lora_layer = lambda *x: 0
+            original_layers.append((layer, layer._conv_forward))
+            layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
 
         yield
 
     finally:
-        for module, orig_conv_forward in to_restore:
-            module._conv_forward = orig_conv_forward
-            if hasattr(module, "asymmetric_padding_mode"):
-                del module.asymmetric_padding_mode
-            if hasattr(module, "asymmetric_padding"):
-                del module.asymmetric_padding
+        for layer, orig_conv_forward in original_layers:
+            layer._conv_forward = orig_conv_forward
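
Both the old and new `_conv_forward_asymmetric` implement the same asymmetric-padding trick: pad the input manually on each spatial axis (mode "circular" on the seamless axes, "constant" i.e. zeros elsewhere), then run the convolution with its implicit padding disabled. A minimal self-contained sketch of that trick, separate from the diff (the tensor shape and seed are arbitrary demo choices):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for any Conv2d inside the UNet/VAE: 3x3 kernel, padding=1.
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
x = torch.randn(1, 1, 8, 8)

# Seamless on the x axis only: wrap left/right, zero-pad top/bottom,
# then convolve with no implicit padding (padding=0).
working = F.pad(x, (1, 1, 0, 0), mode="circular")        # width axis wraps
working = F.pad(working, (0, 0, 1, 1), mode="constant")  # height axis: zeros
out = F.conv2d(working, conv.weight, conv.bias, stride=1, padding=0)

# Spatial size matches what padding=1 would give, but the leftmost and
# rightmost columns now "see" their wrapped-around neighbors.
assert out.shape == x.shape

The new version closes over x_mode/y_mode from the enclosing set_seamless scope instead of stashing asymmetric_padding attributes on every layer, which is why the finally block no longer needs the del cleanup.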
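set_seamless keeps its context-manager shape, so call sites are unaffected by this fix. A hedged usage sketch follows; the import path and checkpoint id are assumptions, since this view does not show the file name:

from diffusers import AutoencoderKL, UNet2DConditionModel

# NOTE: import path is an assumption; the diff view does not show the file name.
from invokeai.backend.stable_diffusion.seamless import set_seamless

# Placeholder checkpoint; any diffusers SD model with unet/vae subfolders works.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

with set_seamless(unet, ["x"]), set_seamless(vae, ["x"]):
    # All Conv2d layers in both models now pad circularly along x,
    # so generated images tile horizontally without a visible seam.
    ...  # run denoising and decode here

# On exit, the original _conv_forward methods are restored.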