diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index e794423fa1..89b292d223 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -33,7 +33,7 @@ class UNetField(BaseModel): unet: ModelInfo = Field(description="Info to load unet submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") - seamless_axes: List[str] = Field(default_factory=list, description="Axes(\"x\" and \"y\") to which apply seamless") + seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') class ClipField(BaseModel): @@ -46,7 +46,7 @@ class ClipField(BaseModel): class VaeField(BaseModel): # TODO: better naming? vae: ModelInfo = Field(description="Info to load vae submodel") - seamless_axes: List[str] = Field(default_factory=list, description="Axes(\"x\" and \"y\") to which apply seamless") + seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') class ModelLoaderOutput(BaseInvocationOutput): @@ -401,6 +401,7 @@ class SeamlessModeOutput(BaseInvocationOutput): unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet") vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE") + @title("Seamless") @tags("seamless", "model") class SeamlessModeInvocation(BaseInvocation): @@ -409,32 +410,24 @@ class SeamlessModeInvocation(BaseInvocation): type: Literal["seamless"] = "seamless" # Inputs - unet: UNetField = InputField( - description=FieldDescriptions.unet, input=Input.Connection, title="UNet" - ) - vae: VaeField = InputField( - description=FieldDescriptions.vae_model, input=Input.Any, title="VAE" - ) + unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet") + vae: VaeField = InputField(description=FieldDescriptions.vae_model, input=Input.Any, title="VAE") seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") - def invoke(self, context: InvocationContext) -> SeamlessModeOutput: - # Conditionally append 'x' and 'y' based on seamless_x and seamless_y + # Conditionally append 'x' and 'y' based on seamless_x and seamless_y unet = copy.deepcopy(self.unet) vae = copy.deepcopy(self.vae) - + seamless_axes_list = [] if self.seamless_x: - seamless_axes_list.append('x') + seamless_axes_list.append("x") if self.seamless_y: - seamless_axes_list.append('y') + seamless_axes_list.append("y") unet.seamless_axes = seamless_axes_list vae.seamless_axes = seamless_axes_list - - return SeamlessModeOutput( - unet=unet, - vae=vae - ) \ No newline at end of file + + return SeamlessModeOutput(unet=unet, vae=vae) diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management/seamless.py index b56b64f1de..1801c6e057 100644 --- a/invokeai/backend/model_management/seamless.py +++ b/invokeai/backend/model_management/seamless.py @@ -6,6 +6,7 @@ import diffusers import torch.nn as nn from diffusers.models import UNet2DModel, AutoencoderKL + def _conv_forward_asymmetric(self, input, weight, bias): """ Patch for Conv2d._conv_forward that supports asymmetric padding @@ -23,13 +24,14 @@ def _conv_forward_asymmetric(self, input, weight, bias): ) -ModelType = TypeVar('ModelType', UNet2DModel, AutoencoderKL) +ModelType = TypeVar("ModelType", UNet2DModel, AutoencoderKL) + @contextmanager def set_seamless(model: ModelType, seamless_axes): try: - to_restore = [] - + to_restore = [] + for m in model.modules(): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): m.asymmetric_padding_mode = {} @@ -64,4 +66,6 @@ def set_seamless(model: ModelType, seamless_axes): if hasattr(m, "asymmetric_padding"): del m.asymmetric_padding if isinstance(m, diffusers.models.lora.LoRACompatibleConv): - m.forward = diffusers.models.lora.LoRACompatibleConv.forward.__get__(m,diffusers.models.lora.LoRACompatibleConv) \ No newline at end of file + m.forward = diffusers.models.lora.LoRACompatibleConv.forward.__get__( + m, diffusers.models.lora.LoRACompatibleConv + )