diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management/seamless.py
index 3ab2db1d90..e145c6f481 100644
--- a/invokeai/backend/model_management/seamless.py
+++ b/invokeai/backend/model_management/seamless.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import List, Union
+from typing import Callable, List, Union
 
 import torch.nn as nn
-from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 
 
 def _conv_forward_asymmetric(self, input, weight, bias):
@@ -26,12 +27,9 @@ def _conv_forward_asymmetric(self, input, weight, bias):
 
 @contextmanager
 def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
+    # Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
+    to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
     try:
-        to_restore = []
-        skipped_layers = 0
-        skip_second_resnet = True
-        skip_conv2 = True
-
         for m_name, m in model.named_modules():
             if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                 continue
@@ -42,14 +40,16 @@ def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axe
                 block_num = int(block_num)
                 resnet_num = int(resnet_num)
 
-                # if block_num >= seamless_down_blocks:
-                if block_num >= len(model.down_blocks) - skipped_layers:
+                # Could be configurable to allow skipping arbitrary numbers of down blocks
+                if block_num >= len(model.down_blocks):
                     continue
 
-                if resnet_num > 0 and skip_second_resnet:
+                # Skip the second resnet (could be configurable)
+                if resnet_num > 0:
                     continue
 
-                if submodule_name == "conv2" and skip_conv2:
+                # Skip Conv2d layers (could be configurable)
+                if submodule_name == "conv2":
                     continue
 
                 m.asymmetric_padding_mode = {}
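
For reviewers, here is a minimal sketch of the save/patch/restore pattern the newly typed `to_restore` list supports. The actual patch and restore logic of `set_seamless` lies outside the hunks shown above, so this is an assumed illustration of the mechanism, not code from this PR; the `_conv_forward_asymmetric` stand-in below just delegates instead of applying asymmetric padding.

```python
# Sketch (assumed, not from this PR): how a typed to_restore list drives
# patching and restoring of Conv2d._conv_forward.
from typing import Callable, Optional

import torch
import torch.nn as nn
from torch import Tensor


def _conv_forward_asymmetric(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
    # Stand-in for the real patch in seamless.py; here it simply delegates
    # to the original implementation.
    return nn.Conv2d._conv_forward(self, input, weight, bias)


# Each saved Callable has the signature noted in the diff comment:
# (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
conv = nn.Conv2d(3, 3, 3, padding=1)

# Patch: remember the original bound method, then rebind the replacement
# onto this instance.
to_restore.append((conv, conv._conv_forward))
conv._conv_forward = _conv_forward_asymmetric.__get__(conv, nn.Conv2d)

conv(torch.randn(1, 3, 8, 8))  # runs through the patched forward

# Restore: set_seamless presumably does this on exit (not shown in the hunks).
for module, orig_forward in to_restore:
    module._conv_forward = orig_forward
```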