From 7f3be627c28e0af655d02cdb7a851ef1d92d9798 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 20 Sep 2023 01:10:37 +0300 Subject: [PATCH] Add more seamless configuration options. --- invokeai/app/invocations/latent.py | 4 +- invokeai/app/invocations/model.py | 30 +++++- invokeai/backend/model_management/seamless.py | 92 ++++++++----------- 3 files changed, 66 insertions(+), 60 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 9b93cf0a3d..3428fd1c2e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -719,7 +719,7 @@ class DenoiseLatentsInvocation(BaseInvocation): with ( ExitStack() as exit_stack, ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), - set_seamless(unet_info.context.model, self.unet.seamless_axes), + set_seamless(unet_info.context.model, **self.unet.seamless.dict()), unet_info as unet, # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet(unet, _lora_loader()), @@ -826,7 +826,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): context=context, ) - with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: + with set_seamless(vae_info.context.model, **self.vae.seamless.dict()), vae_info as vae: latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 99dcc72999..420e0439e0 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -19,6 +19,13 @@ from .baseinvocation import ( ) +class SeamlessSettings(BaseModel): + axes: List[str] = Field(description="Axes('x' and 'y') to which apply seamless") + skipped_layers: int = Field(description="How much down layers skip when applying seamless") + skip_second_resnet: bool = Field(description="Skip or not second resnet in down blocks when applying seamless") + skip_conv2: bool = Field(description="Skip or not conv2 in down blocks when applying seamless") + + class ModelInfo(BaseModel): model_name: str = Field(description="Info to load submodel") base_model: BaseModelType = Field(description="Base model") @@ -36,8 +43,8 @@ class UNetField(BaseModel): unet: ModelInfo = Field(description="Info to load unet submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") - seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration") + seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model") class ClipField(BaseModel): @@ -50,7 +57,7 @@ class ClipField(BaseModel): class VaeField(BaseModel): # TODO: better naming? vae: ModelInfo = Field(description="Info to load vae submodel") - seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') + seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model") @invocation_output("unet_output") @@ -451,6 +458,11 @@ class SeamlessModeInvocation(BaseInvocation): ) seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") + skipped_layers: int = InputField(default=0, input=Input.Any, description="How much model's down layers to skip") + skip_second_resnet: bool = InputField( + default=True, input=Input.Any, description="Skip or not second resnet in down layers" + ) + skip_conv2: bool = InputField(default=True, input=Input.Any, description="Skip or not conv2 in down layers") def invoke(self, context: InvocationContext) -> SeamlessModeOutput: # Conditionally append 'x' and 'y' based on seamless_x and seamless_y @@ -465,9 +477,19 @@ class SeamlessModeInvocation(BaseInvocation): seamless_axes_list.append("y") if unet is not None: - unet.seamless_axes = seamless_axes_list + unet.seamless = SeamlessSettings( + axes=seamless_axes_list, + skipped_layers=self.skipped_layers, + skip_second_resnet=self.skip_second_resnet, + skip_conv2=self.skip_conv2, + ) if vae is not None: - vae.seamless_axes = seamless_axes_list + vae.seamless = SeamlessSettings( + axes=seamless_axes_list, + skipped_layers=self.skipped_layers, + skip_second_resnet=self.skip_second_resnet, + skip_conv2=self.skip_conv2, + ) return SeamlessModeOutput(unet=unet, vae=vae) diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management/seamless.py index bfdf9e0c53..7c7cab136f 100644 --- a/invokeai/backend/model_management/seamless.py +++ b/invokeai/backend/model_management/seamless.py @@ -25,71 +25,55 @@ def _conv_forward_asymmetric(self, input, weight, bias): @contextmanager -def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): +def set_seamless( + model: Union[UNet2DConditionModel, AutoencoderKL], + axes: List[str], + skipped_layers: int, + skip_second_resnet: bool, + skip_conv2: bool, +): try: to_restore = [] for m_name, m in model.named_modules(): - if isinstance(model, UNet2DConditionModel): - if ".attentions." in m_name: + if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + continue + + if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name: + # down_blocks.1.resnets.1.conv1 + _, block_num, _, resnet_num, submodule_name = m_name.split(".") + block_num = int(block_num) + resnet_num = int(resnet_num) + + # if block_num >= seamless_down_blocks: + if block_num >= len(model.down_blocks) - skipped_layers: continue - if ".resnets." in m_name: - if ".conv2" in m_name: - continue - if ".conv_shortcut" in m_name: - continue - - """ - if isinstance(model, UNet2DConditionModel): - if False and ".upsamplers." in m_name: + if resnet_num > 0 and skip_second_resnet: continue - if False and ".downsamplers." in m_name: + if submodule_name == "conv2" and skip_conv2: continue - if True and ".resnets." in m_name: - if True and ".conv1" in m_name: - if False and "down_blocks" in m_name: - continue - if False and "mid_block" in m_name: - continue - if False and "up_blocks" in m_name: - continue + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode["x"] = "circular" if ("x" in axes) else "constant" + m.asymmetric_padding["x"] = ( + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], + 0, + 0, + ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in axes) else "constant" + m.asymmetric_padding["y"] = ( + 0, + 0, + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], + ) - if True and ".conv2" in m_name: - continue - - if True and ".conv_shortcut" in m_name: - continue - - if True and ".attentions." in m_name: - continue - - if False and m_name in ["conv_in", "conv_out"]: - continue - """ - - if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - m.asymmetric_padding_mode = {} - m.asymmetric_padding = {} - m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" - m.asymmetric_padding["x"] = ( - m._reversed_padding_repeated_twice[0], - m._reversed_padding_repeated_twice[1], - 0, - 0, - ) - m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" - m.asymmetric_padding["y"] = ( - 0, - 0, - m._reversed_padding_repeated_twice[2], - m._reversed_padding_repeated_twice[3], - ) - - to_restore.append((m, m._conv_forward)) - m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) + to_restore.append((m, m._conv_forward)) + m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) yield