Seamless fixes

Kent Keirsey 2023-08-28 00:10:46 -04:00
parent 5fdd25501b
commit 1f476692da
2 changed files with 9 additions and 5 deletions

View File

@@ -404,7 +404,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
 @title("Seamless")
 @tags("seamless", "model")
 class SeamlessModeInvocation(BaseInvocation):
-    """Apply seamless mode to unet."""
+    """Applies the seamless transformation to the Model UNet and VAE."""
     type: Literal["seamless"] = "seamless"
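The new docstring covers both sub-models. As a hypothetical usage sketch (not code from this commit: run_seamless, denoise, unet, and vae are stand-in names), the invocation's behavior amounts to wrapping both models in the set_seamless context manager defined in the second file below:

def run_seamless(unet, vae, seamless_axes=("x", "y")):
    # Nested context managers: both models carry the circular-padding patch
    # for the duration of the call and are restored afterwards, even on error.
    with set_seamless(unet, seamless_axes), set_seamless(vae, seamless_axes):
        return denoise(unet, vae)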

View File

@@ -1,7 +1,8 @@
 from __future__ import annotations
 from contextlib import contextmanager
-from typing import TypeVar, Union
+from typing import TypeVar
+import diffusers
 import torch.nn as nn
 from diffusers.models import UNet2DModel, AutoencoderKL
@@ -22,10 +23,9 @@ def _conv_forward_asymmetric(self, input, weight, bias):
     )
-@contextmanager
 ModelType = TypeVar('ModelType', UNet2DModel, AutoencoderKL)
+@contextmanager
 def set_seamless(model: ModelType, seamless_axes):
     try:
         to_restore = []
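This hunk is the core fix: a decorator applies to the statement that immediately follows it, and @contextmanager sitting above the ModelType = TypeVar(...) assignment is not even valid Python (decorators may only precede a def or class). Moving it directly above def set_seamless restores the standard patch/restore context-manager shape. A minimal self-contained version of that pattern, for reference:

from contextlib import contextmanager

@contextmanager  # must immediately precede the function it decorates
def patched(obj, name, value):
    # Same try/yield/finally pattern as set_seamless: install a patch,
    # hand control to the with-block, always restore on the way out.
    old = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, old)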
@@ -51,6 +51,8 @@ def set_seamless(model: ModelType, seamless_axes):
                 to_restore.append((m, m._conv_forward))
                 m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
+                if isinstance(m, diffusers.models.lora.LoRACompatibleConv) and m.lora_layer is None:
+                    m.forward = nn.Conv2d.forward.__get__(m, nn.Conv2d)
 
         yield
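The two added lines work around diffusers' LoRACompatibleConv: at the time of this commit, when no LoRA layer was attached its forward called F.conv2d directly (a torch.compile-friendly shortcut) rather than going through self._conv_forward, so the asymmetric-padding patch installed just above would be silently bypassed. Rebinding plain nn.Conv2d.forward onto the instance routes calls back through _conv_forward. The __get__ call is the descriptor protocol, which turns a bare function into a method bound to m. A standalone sketch of the trick (LoggingConv is a stand-in, not diffusers code):

import torch
import torch.nn as nn

class LoggingConv(nn.Conv2d):  # stand-in for LoRACompatibleConv
    def forward(self, x):
        print("subclass forward")  # imagine a direct F.conv2d call here
        return super().forward(x)

m = LoggingConv(3, 3, kernel_size=3, padding=1)
m.forward = nn.Conv2d.forward.__get__(m, nn.Conv2d)  # per-instance override
_ = m(torch.randn(1, 3, 8, 8))  # takes the base-class path; nothing is printed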
@@ -61,3 +63,5 @@ def set_seamless(model: ModelType, seamless_axes):
                 del m.asymmetric_padding_mode
             if hasattr(m, "asymmetric_padding"):
                 del m.asymmetric_padding
+            if isinstance(m, diffusers.models.lora.LoRACompatibleConv):
+                m.forward = diffusers.models.lora.LoRACompatibleConv.forward.__get__(m, diffusers.models.lora.LoRACompatibleConv)
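Teardown mirrors setup: because the patch lives in a per-instance forward attribute, restoring means rebinding the subclass's own forward onto the instance (deleting the instance attribute so lookup falls back to the class would arguably achieve the same effect). Continuing the sketch above:

m.forward = LoggingConv.forward.__get__(m, LoggingConv)  # rebind subclass forward
_ = m(torch.randn(1, 3, 8, 8))  # prints "subclass forward" again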