Seamless Patch from Stalker

This commit is contained in:
Kent Keirsey 2023-08-28 15:46:49 -04:00
parent d1efabaf2f
commit 2a1d7342a7

View File

@ -4,7 +4,7 @@ from contextlib import contextmanager
from typing import List, Union from typing import List, Union
import torch.nn as nn import torch.nn as nn
from diffusers.models import AutoencoderKL, UNet2DModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias): def _conv_forward_asymmetric(self, input, weight, bias):
@ -25,12 +25,56 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager @contextmanager
def set_seamless(model: Union[UNet2DModel, AutoencoderKL], seamless_axes: List[str]): def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
try: try:
to_restore = [] to_restore = []
for m in model.modules(): #print("try seamless")
for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel):
if ".attentions." in m_name:
continue
if ".resnets." in m_name:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
continue
if False and ".downsamplers." in m_name:
continue
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
print(f"applied - {m_name}")
m.asymmetric_padding_mode = {} m.asymmetric_padding_mode = {}
m.asymmetric_padding = {} m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"