From 2a1d7342a768758149bdf2a51646e44b7be93ee0 Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Mon, 28 Aug 2023 15:46:49 -0400 Subject: [PATCH] Seamless Patch from Stalker --- invokeai/backend/model_management/seamless.py | 50 +++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management/seamless.py index 997ff3563f..b2215e41fc 100644 --- a/invokeai/backend/model_management/seamless.py +++ b/invokeai/backend/model_management/seamless.py @@ -4,7 +4,7 @@ from contextlib import contextmanager from typing import List, Union import torch.nn as nn -from diffusers.models import AutoencoderKL, UNet2DModel +from diffusers.models import AutoencoderKL, UNet2DConditionModel def _conv_forward_asymmetric(self, input, weight, bias): @@ -25,12 +25,56 @@ def _conv_forward_asymmetric(self, input, weight, bias): @contextmanager -def set_seamless(model: Union[UNet2DModel, AutoencoderKL], seamless_axes: List[str]): +def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): try: to_restore = [] - for m in model.modules(): + #print("try seamless") + + for m_name, m in model.named_modules(): + + if isinstance(model, UNet2DConditionModel): + if ".attentions." in m_name: + continue + + if ".resnets." in m_name: + if ".conv2" in m_name: + continue + if ".conv_shortcut" in m_name: + continue + + """ + if isinstance(model, UNet2DConditionModel): + if False and ".upsamplers." in m_name: + continue + + if False and ".downsamplers." in m_name: + continue + + if True and ".resnets." in m_name: + if True and ".conv1" in m_name: + if False and "down_blocks" in m_name: + continue + if False and "mid_block" in m_name: + continue + if False and "up_blocks" in m_name: + continue + + if True and ".conv2" in m_name: + continue + + if True and ".conv_shortcut" in m_name: + continue + + if True and ".attentions." in m_name: + continue + + if False and m_name in ["conv_in", "conv_out"]: + continue + """ + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + print(f"applied - {m_name}") m.asymmetric_padding_mode = {} m.asymmetric_padding = {} m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"