Seamless Patch from Stalker (#4372)

Last commit that didn't get merged in with #4370
This commit is contained in:
blessedcoolant 2023-08-29 08:57:06 +12:00 committed by GitHub
commit 59511783fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@ from contextlib import contextmanager
from typing import List, Union from typing import List, Union
import torch.nn as nn import torch.nn as nn
from diffusers.models import AutoencoderKL, UNet2DModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias): def _conv_forward_asymmetric(self, input, weight, bias):
@ -25,12 +25,53 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager @contextmanager
def set_seamless(model: Union[UNet2DModel, AutoencoderKL], seamless_axes: List[str]): def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
try: try:
to_restore = [] to_restore = []
for m in model.modules(): for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel):
if ".attentions." in m_name:
continue
if ".resnets." in m_name:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
continue
if False and ".downsamplers." in m_name:
continue
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
print(f"applied - {m_name}")
m.asymmetric_padding_mode = {} m.asymmetric_padding_mode = {}
m.asymmetric_padding = {} m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"