Quick Seamless Fixes (#5685)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ X ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [ X ] No, because: It's small

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ X ] No


## Description
This pulls out some of the updates from the WIP Seamless branch that has
yet to be completed, and hardcodes values that are exposed in that
branch. Given that seamless currently does not generate seamless
textures, and this fix results in seamless outputs, it's an improvement
even if it doesn't resolve this in a "perfect" way that exposes all
variables to the end user.

better over perfect.


![f07b7e49-80c2-4659-bb36-d50ec80b1f8b](https://github.com/invoke-ai/InvokeAI/assets/31807370/36a40bd9-8fc4-41d5-bd1e-209fc828987e)
This commit is contained in:
Millun Atluri 2024-02-13 11:08:07 -07:00 committed by GitHub
commit 8bd65be8c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,10 +1,11 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import List, Union
from typing import Callable, List, Union
import torch.nn as nn
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias):
@ -26,70 +27,50 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
try:
to_restore = []
for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel):
if ".attentions." in m_name:
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
continue
if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
# down_blocks.1.resnets.1.conv1
_, block_num, _, resnet_num, submodule_name = m_name.split(".")
block_num = int(block_num)
resnet_num = int(resnet_num)
# Could be configurable to allow skipping arbitrary numbers of down blocks
if block_num >= len(model.down_blocks):
continue
if ".resnets." in m_name:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
# Skip the second resnet (could be configurable)
if resnet_num > 0:
continue
if False and ".downsamplers." in m_name:
# Skip Conv2d layers (could be configurable)
if submodule_name == "conv2":
continue
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield