From 3de45af73424f537b67736d892bdd49e27ed0b54 Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Sun, 27 Aug 2023 14:13:00 -0400 Subject: [PATCH] updates --- invokeai/app/invocations/model.py | 43 +++++++++++++ invokeai/backend/model_management/seamless.py | 60 +++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 invokeai/backend/model_management/seamless.py diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 3cae4b3383..1bb67b8c91 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -33,6 +33,7 @@ class UNetField(BaseModel): unet: ModelInfo = Field(description="Info to load unet submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + seamless_axes: List[str] = Field(default_factory=list, description="Axes(\"x\" and \"y\") to which apply seamless") class ClipField(BaseModel): @@ -388,3 +389,45 @@ class VaeLoaderInvocation(BaseInvocation): ) ) ) + + +class SeamlessModeOutput(BaseInvocationOutput): + """Modified Seamless Model output""" + + type: Literal["seamless_output"] = "seamless_output" + + # Outputs + unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") + +@title("Seamless") +@tags("seamless", "model") +class SeamlessModeInvocation(BaseInvocation): + """Apply seamless mode to unet.""" + + type: Literal["seamless"] = "seamless" + + # Inputs + unet: UNetField = InputField( + description=FieldDescriptions.unet, input=Input.Connection, title="UNet" + ) + + seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") + seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") + + + def invoke(self, context: InvocationContext) -> SeamlessModeOutput: + # Conditionally append 'x' and 'y' based on seamless_x and seamless_y + unet = copy.deepcopy(self.unet) + + seamless_axes_list = [] + + if self.seamless_x: + seamless_axes_list.append('x') + if self.seamless_y: + seamless_axes_list.append('y') + + unet.seamless_axes = seamless_axes_list + + return SeamlessModeOutput( + unet=unet, + ) \ No newline at end of file diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management/seamless.py new file mode 100644 index 0000000000..0568c68ac2 --- /dev/null +++ b/invokeai/backend/model_management/seamless.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from contextlib import contextmanager + +import torch.nn as nn +from diffusers.models import UNet2DModel + +def _conv_forward_asymmetric(self, input, weight, bias): + """ + Patch for Conv2d._conv_forward that supports asymmetric padding + """ + working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]) + working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]) + return nn.functional.conv2d( + working, + weight, + bias, + self.stride, + nn.modules.utils._pair(0), + self.dilation, + self.groups, + ) + + +@contextmanager +def set_unet_seamless(model: UNet2DModel, seamless: bool, seamless_axes): + try: + to_restore = dict() + if seamless: + for m in model.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" + m.asymmetric_padding["x"] = ( + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], + 0, + 0, + ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" + m.asymmetric_padding["y"] = ( + 0, + 0, + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], + ) + + to_restore.append((m, m._conv_forward)) + m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) + + yield + + finally: + for module, orig_conv_forward in to_restore: + module._conv_forward = orig_conv_forward + if hasattr(m, "asymmetric_padding_mode"): + del m.asymmetric_padding_mode + if hasattr(m, "asymmetric_padding"): + del m.asymmetric_padding \ No newline at end of file