feat(nodes): seamless.py minor cleanup

This commit is contained in:
psychedelicious 2024-02-13 13:34:06 +11:00
parent c3b2a8cb27
commit 3339ad4df8

View File

@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from typing import List, Union from typing import Callable, List, Union
import torch.nn as nn import torch.nn as nn
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias): def _conv_forward_asymmetric(self, input, weight, bias):
@ -26,12 +27,9 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager @contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
try: try:
to_restore = []
skipped_layers = 0
skip_second_resnet = True
skip_conv2 = True
for m_name, m in model.named_modules(): for m_name, m in model.named_modules():
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
continue continue
@ -42,14 +40,16 @@ def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axe
block_num = int(block_num) block_num = int(block_num)
resnet_num = int(resnet_num) resnet_num = int(resnet_num)
# if block_num >= seamless_down_blocks: # Could be configurable to allow skipping arbitrary numbers of down blocks
if block_num >= len(model.down_blocks) - skipped_layers: if block_num >= len(model.down_blocks):
continue continue
if resnet_num > 0 and skip_second_resnet: # Skip the second resnet (could be configurable)
if resnet_num > 0:
continue continue
if submodule_name == "conv2" and skip_conv2: # Skip Conv2d layers (could be configurable)
if submodule_name == "conv2":
continue continue
m.asymmetric_padding_mode = {} m.asymmetric_padding_mode = {}