From ea40a7844a5f2227a6fdece9c425d51ecb2bc362 Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Sun, 27 Aug 2023 14:53:57 -0400 Subject: [PATCH] add VAE --- invokeai/app/invocations/latent.py | 4 +- invokeai/app/invocations/model.py | 11 +++++- invokeai/backend/model_management/seamless.py | 38 ++++++++++++++++++- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 708dfe81b9..80f90157df 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -33,7 +33,7 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings from ...backend.model_management.models import BaseModelType from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.seamless import set_unet_seamless +from ...backend.model_management.seamless import set_unet_seamless, set_vae_seamless from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, @@ -491,7 +491,7 @@ class LatentsToImageInvocation(BaseInvocation): context=context, ) - with vae_info as vae: + with set_vae_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 1bb67b8c91..eac3638ad3 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -46,6 +46,7 @@ class ClipField(BaseModel): class VaeField(BaseModel): # TODO: better naming? vae: ModelInfo = Field(description="Info to load vae submodel") + seamless_axes: List[str] = Field(default_factory=list, description="Axes(\"x\" and \"y\") to which apply seamless") class ModelLoaderOutput(BaseInvocationOutput): @@ -398,6 +399,7 @@ class SeamlessModeOutput(BaseInvocationOutput): # Outputs unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") + vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") @title("Seamless") @tags("seamless", "model") @@ -410,7 +412,9 @@ class SeamlessModeInvocation(BaseInvocation): unet: UNetField = InputField( description=FieldDescriptions.unet, input=Input.Connection, title="UNet" ) - + vae_model: VAEModelField = InputField( + description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE" + ) seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") @@ -418,6 +422,7 @@ class SeamlessModeInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> SeamlessModeOutput: # Conditionally append 'x' and 'y' based on seamless_x and seamless_y unet = copy.deepcopy(self.unet) + vae = copy.deepcopy(self.vae) seamless_axes_list = [] @@ -427,7 +432,9 @@ class SeamlessModeInvocation(BaseInvocation): seamless_axes_list.append('y') unet.seamless_axes = seamless_axes_list - + vae.seamless_axes = seamless_axes_list + return SeamlessModeOutput( unet=unet, + vae=vae ) \ No newline at end of file diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management/seamless.py index ea4037077e..49770b4281 100644 --- a/invokeai/backend/model_management/seamless.py +++ b/invokeai/backend/model_management/seamless.py @@ -3,7 +3,7 @@ from __future__ import annotations from contextlib import contextmanager import torch.nn as nn -from diffusers.models import UNet2DModel +from diffusers.models import UNet2DModel, AutoencoderKL def _conv_forward_asymmetric(self, input, weight, bias): """ @@ -51,6 +51,42 @@ def set_unet_seamless(model: UNet2DModel, seamless_axes): yield + finally: + for module, orig_conv_forward in to_restore: + module._conv_forward = orig_conv_forward + if hasattr(m, "asymmetric_padding_mode"): + del m.asymmetric_padding_mode + if hasattr(m, "asymmetric_padding"): + del m.asymmetric_padding + +def set_vae_seamless(model: AutoencoderKL, seamless_axes): + try: + to_restore = [] + + for m in model.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" + m.asymmetric_padding["x"] = ( + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], + 0, + 0, + ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" + m.asymmetric_padding["y"] = ( + 0, + 0, + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], + ) + + to_restore.append((m, m._conv_forward)) + m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) + + yield + finally: for module, orig_conv_forward in to_restore: module._conv_forward = orig_conv_forward