Compare commits

...

2 Commits

3 changed files with 92 additions and 60 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import contextlib
from contextlib import ExitStack from contextlib import ExitStack
from functools import singledispatchmethod from functools import singledispatchmethod
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
@ -716,10 +717,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
**self.unet.unet.model_dump(), **self.unet.unet.model_dump(),
context=context, context=context,
) )
# Prepare seamless context, if configured.
seamless_context = contextlib.nullcontext()
seamless_config = self.unet.seamless
if seamless_config is not None:
seamless_context = set_seamless(
model=unet_info.context.model,
axes=seamless_config.axes,
skipped_layers=seamless_config.skipped_layers,
skip_second_resnet=seamless_config.skip_second_resnet,
skip_conv2=seamless_config.skip_conv2,
)
with ( with (
ExitStack() as exit_stack, ExitStack() as exit_stack,
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
set_seamless(unet_info.context.model, self.unet.seamless_axes), seamless_context,
unet_info as unet, unet_info as unet,
# Apply the LoRA after unet has been moved to its target device for faster patching. # Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()), ModelPatcher.apply_lora_unet(unet, _lora_loader()),
@ -826,7 +840,19 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
context=context, context=context,
) )
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: # Prepare seamless context, if configured.
seamless_context = contextlib.nullcontext()
seamless_config = self.vae.seamless
if seamless_config is not None:
seamless_context = set_seamless(
model=vae_info.context.model,
axes=seamless_config.axes,
skipped_layers=seamless_config.skipped_layers,
skip_second_resnet=seamless_config.skip_second_resnet,
skip_conv2=seamless_config.skip_conv2,
)
with seamless_context, vae_info as vae:
latents = latents.to(vae.device) latents = latents.to(vae.device)
if self.fp32: if self.fp32:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)

View File

@ -19,6 +19,13 @@ from .baseinvocation import (
) )
class SeamlessSettings(BaseModel):
axes: List[str] = Field(description="Axes('x' and 'y') to which apply seamless")
skipped_layers: int = Field(description="How much down layers skip when applying seamless")
skip_second_resnet: bool = Field(description="Skip or not second resnet in down blocks when applying seamless")
skip_conv2: bool = Field(description="Skip or not conv2 in down blocks when applying seamless")
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel") model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
@ -36,8 +43,8 @@ class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel") unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration") freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")
class ClipField(BaseModel): class ClipField(BaseModel):
@ -50,7 +57,7 @@ class ClipField(BaseModel):
class VaeField(BaseModel): class VaeField(BaseModel):
# TODO: better naming? # TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel") vae: ModelInfo = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")
@invocation_output("unet_output") @invocation_output("unet_output")
@ -451,6 +458,11 @@ class SeamlessModeInvocation(BaseInvocation):
) )
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
skipped_layers: int = InputField(default=0, input=Input.Any, description="How much model's down layers to skip")
skip_second_resnet: bool = InputField(
default=True, input=Input.Any, description="Skip or not second resnet in down layers"
)
skip_conv2: bool = InputField(default=True, input=Input.Any, description="Skip or not conv2 in down layers")
def invoke(self, context: InvocationContext) -> SeamlessModeOutput: def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y # Conditionally append 'x' and 'y' based on seamless_x and seamless_y
@ -465,9 +477,19 @@ class SeamlessModeInvocation(BaseInvocation):
seamless_axes_list.append("y") seamless_axes_list.append("y")
if unet is not None: if unet is not None:
unet.seamless_axes = seamless_axes_list unet.seamless = SeamlessSettings(
axes=seamless_axes_list,
skipped_layers=self.skipped_layers,
skip_second_resnet=self.skip_second_resnet,
skip_conv2=self.skip_conv2,
)
if vae is not None: if vae is not None:
vae.seamless_axes = seamless_axes_list vae.seamless = SeamlessSettings(
axes=seamless_axes_list,
skipped_layers=self.skipped_layers,
skip_second_resnet=self.skip_second_resnet,
skip_conv2=self.skip_conv2,
)
return SeamlessModeOutput(unet=unet, vae=vae) return SeamlessModeOutput(unet=unet, vae=vae)

View File

@ -25,71 +25,55 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager @contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): def set_seamless(
model: Union[UNet2DConditionModel, AutoencoderKL],
axes: List[str],
skipped_layers: int,
skip_second_resnet: bool,
skip_conv2: bool,
):
try: try:
to_restore = [] to_restore = []
for m_name, m in model.named_modules(): for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel): if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
if ".attentions." in m_name: continue
if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
# down_blocks.1.resnets.1.conv1
_, block_num, _, resnet_num, submodule_name = m_name.split(".")
block_num = int(block_num)
resnet_num = int(resnet_num)
# if block_num >= seamless_down_blocks:
if block_num >= len(model.down_blocks) - skipped_layers:
continue continue
if ".resnets." in m_name: if resnet_num > 0 and skip_second_resnet:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
continue continue
if False and ".downsamplers." in m_name: if submodule_name == "conv2" and skip_conv2:
continue continue
if True and ".resnets." in m_name: m.asymmetric_padding_mode = {}
if True and ".conv1" in m_name: m.asymmetric_padding = {}
if False and "down_blocks" in m_name: m.asymmetric_padding_mode["x"] = "circular" if ("x" in axes) else "constant"
continue m.asymmetric_padding["x"] = (
if False and "mid_block" in m_name: m._reversed_padding_repeated_twice[0],
continue m._reversed_padding_repeated_twice[1],
if False and "up_blocks" in m_name: 0,
continue 0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
if True and ".conv2" in m_name: to_restore.append((m, m._conv_forward))
continue m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield yield