mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): seamless.py minor cleanup
This commit is contained in:
parent
c3b2a8cb27
commit
3339ad4df8
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Union
|
||||
from typing import Callable, List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
@ -26,12 +27,9 @@ def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
|
||||
@contextmanager
|
||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
|
||||
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
|
||||
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
|
||||
try:
|
||||
to_restore = []
|
||||
skipped_layers = 0
|
||||
skip_second_resnet = True
|
||||
skip_conv2 = True
|
||||
|
||||
for m_name, m in model.named_modules():
|
||||
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
continue
|
||||
@ -42,14 +40,16 @@ def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axe
|
||||
block_num = int(block_num)
|
||||
resnet_num = int(resnet_num)
|
||||
|
||||
# if block_num >= seamless_down_blocks:
|
||||
if block_num >= len(model.down_blocks) - skipped_layers:
|
||||
# Could be configurable to allow skipping arbitrary numbers of down blocks
|
||||
if block_num >= len(model.down_blocks):
|
||||
continue
|
||||
|
||||
if resnet_num > 0 and skip_second_resnet:
|
||||
# Skip the second resnet (could be configurable)
|
||||
if resnet_num > 0:
|
||||
continue
|
||||
|
||||
if submodule_name == "conv2" and skip_conv2:
|
||||
# Skip Conv2d layers (could be configurable)
|
||||
if submodule_name == "conv2":
|
||||
continue
|
||||
|
||||
m.asymmetric_padding_mode = {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user