Add more seamless configuration options.

commit 7f3be627c2 (parent cb7e56a9a3)
Author: Sergey Borisov, 2023-09-20 01:10:37 +03:00
Committed by: Ryan Dick
3 changed files with 66 additions and 60 deletions

View File

@@ -719,7 +719,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         with (
             ExitStack() as exit_stack,
             ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
-            set_seamless(unet_info.context.model, self.unet.seamless_axes),
+            set_seamless(unet_info.context.model, **self.unet.seamless.dict()),
             unet_info as unet,
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(unet, _lora_loader()),
@@ -826,7 +826,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
             context=context,
         )
-        with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
+        with set_seamless(vae_info.context.model, **self.vae.seamless.dict()), vae_info as vae:
             latents = latents.to(vae.device)
             if self.fp32:
                 vae.to(dtype=torch.float32)
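
For orientation: both call sites above switch from passing a bare axes list to unpacking a settings object into keyword arguments. Below is a minimal, self-contained sketch of that equivalence; set_seamless_stub is a hypothetical stand-in for the real context manager, and SeamlessSettings mirrors the model added in the next file of this commit.

    # Sketch only: shows how **settings.dict() expands into the new keyword
    # parameters. The stub is hypothetical; the values are illustrative.
    from typing import List

    from pydantic import BaseModel, Field

    class SeamlessSettings(BaseModel):
        axes: List[str] = Field(description="Axes('x' and 'y') to which apply seamless")
        skipped_layers: int = Field(description="How much down layers skip when applying seamless")
        skip_second_resnet: bool = Field(description="Skip or not second resnet in down blocks when applying seamless")
        skip_conv2: bool = Field(description="Skip or not conv2 in down blocks when applying seamless")

    def set_seamless_stub(model, axes, skipped_layers, skip_second_resnet, skip_conv2):
        # Same parameter names as the reworked set_seamless() later in this commit.
        return dict(axes=axes, skipped_layers=skipped_layers,
                    skip_second_resnet=skip_second_resnet, skip_conv2=skip_conv2)

    settings = SeamlessSettings(axes=["x"], skipped_layers=0, skip_second_resnet=True, skip_conv2=True)
    # set_seamless(model, **settings.dict()) passes each field by name, i.e.:
    assert set_seamless_stub(None, **settings.dict()) == set_seamless_stub(
        None, axes=["x"], skipped_layers=0, skip_second_resnet=True, skip_conv2=True
    )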

View File

@@ -19,6 +19,13 @@ from .baseinvocation import (
 )


+class SeamlessSettings(BaseModel):
+    axes: List[str] = Field(description="Axes('x' and 'y') to which apply seamless")
+    skipped_layers: int = Field(description="How much down layers skip when applying seamless")
+    skip_second_resnet: bool = Field(description="Skip or not second resnet in down blocks when applying seamless")
+    skip_conv2: bool = Field(description="Skip or not conv2 in down blocks when applying seamless")
+
+
 class ModelInfo(BaseModel):
     model_name: str = Field(description="Info to load submodel")
     base_model: BaseModelType = Field(description="Base model")
@@ -36,8 +43,8 @@ class UNetField(BaseModel):
     unet: ModelInfo = Field(description="Info to load unet submodel")
     scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
     loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
-    seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
     freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
+    seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")


 class ClipField(BaseModel):
@@ -50,7 +57,7 @@ class ClipField(BaseModel):
 class VaeField(BaseModel):
     # TODO: better naming?
     vae: ModelInfo = Field(description="Info to load vae submodel")
-    seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
+    seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")


 @invocation_output("unet_output")
@@ -451,6 +458,11 @@ class SeamlessModeInvocation(BaseInvocation):
     )
     seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
     seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
+    skipped_layers: int = InputField(default=0, input=Input.Any, description="How much model's down layers to skip")
+    skip_second_resnet: bool = InputField(
+        default=True, input=Input.Any, description="Skip or not second resnet in down layers"
+    )
+    skip_conv2: bool = InputField(default=True, input=Input.Any, description="Skip or not conv2 in down layers")

     def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
         # Conditionally append 'x' and 'y' based on seamless_x and seamless_y
@@ -465,9 +477,19 @@ class SeamlessModeInvocation(BaseInvocation):
             seamless_axes_list.append("y")

         if unet is not None:
-            unet.seamless_axes = seamless_axes_list
+            unet.seamless = SeamlessSettings(
+                axes=seamless_axes_list,
+                skipped_layers=self.skipped_layers,
+                skip_second_resnet=self.skip_second_resnet,
+                skip_conv2=self.skip_conv2,
+            )

         if vae is not None:
-            vae.seamless_axes = seamless_axes_list
+            vae.seamless = SeamlessSettings(
+                axes=seamless_axes_list,
+                skipped_layers=self.skipped_layers,
+                skip_second_resnet=self.skip_second_resnet,
+                skip_conv2=self.skip_conv2,
+            )

         return SeamlessModeOutput(unet=unet, vae=vae)
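
Roughly, the node now folds its five inputs into one settings object and attaches it to both the UNet and VAE fields. A self-contained sketch of that assembly follows; build_seamless_settings is a hypothetical helper, and the SeamlessSettings copy below is trimmed of its Field descriptions.

    # Mirrors the axes-list construction and settings assembly in invoke().
    from typing import List

    from pydantic import BaseModel

    class SeamlessSettings(BaseModel):
        axes: List[str]
        skipped_layers: int
        skip_second_resnet: bool
        skip_conv2: bool

    def build_seamless_settings(seamless_x: bool, seamless_y: bool, skipped_layers: int,
                                skip_second_resnet: bool, skip_conv2: bool) -> SeamlessSettings:
        seamless_axes_list = []
        if seamless_x:
            seamless_axes_list.append("x")
        if seamless_y:
            seamless_axes_list.append("y")
        return SeamlessSettings(
            axes=seamless_axes_list,
            skipped_layers=skipped_layers,
            skip_second_resnet=skip_second_resnet,
            skip_conv2=skip_conv2,
        )

    # With the node defaults (seamless_x=True, seamless_y=True, skipped_layers=0,
    # skip_second_resnet=True, skip_conv2=True):
    print(build_seamless_settings(True, True, 0, True, True).dict())
    # {'axes': ['x', 'y'], 'skipped_layers': 0, 'skip_second_resnet': True, 'skip_conv2': True}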

View File

@@ -25,62 +25,46 @@ def _conv_forward_asymmetric(self, input, weight, bias):
 @contextmanager
-def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
+def set_seamless(
+    model: Union[UNet2DConditionModel, AutoencoderKL],
+    axes: List[str],
+    skipped_layers: int,
+    skip_second_resnet: bool,
+    skip_conv2: bool,
+):
     try:
         to_restore = []

         for m_name, m in model.named_modules():
-            if isinstance(model, UNet2DConditionModel):
-                if ".attentions." in m_name:
-                    continue
-
-                if ".resnets." in m_name:
-                    if ".conv2" in m_name:
-                        continue
-                    if ".conv_shortcut" in m_name:
-                        continue
-
-            """
-            if isinstance(model, UNet2DConditionModel):
-                if False and ".upsamplers." in m_name:
-                    continue
-
-                if False and ".downsamplers." in m_name:
-                    continue
-
-                if True and ".resnets." in m_name:
-                    if True and ".conv1" in m_name:
-                        if False and "down_blocks" in m_name:
-                            continue
-
-                        if False and "mid_block" in m_name:
-                            continue
-
-                        if False and "up_blocks" in m_name:
-                            continue
-
-                    if True and ".conv2" in m_name:
-                        continue
-
-                    if True and ".conv_shortcut" in m_name:
-                        continue
-
-                if True and ".attentions." in m_name:
-                    continue
-
-                if False and m_name in ["conv_in", "conv_out"]:
-                    continue
-            """
-
-            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-                m.asymmetric_padding_mode = {}
-                m.asymmetric_padding = {}
-                m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
-                m.asymmetric_padding["x"] = (
-                    m._reversed_padding_repeated_twice[0],
-                    m._reversed_padding_repeated_twice[1],
-                    0,
-                    0,
-                )
-                m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
-                m.asymmetric_padding["y"] = (
-                    0,
-                    0,
+            if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+                continue
+
+            if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
+                # down_blocks.1.resnets.1.conv1
+                _, block_num, _, resnet_num, submodule_name = m_name.split(".")
+                block_num = int(block_num)
+                resnet_num = int(resnet_num)
+
+                # if block_num >= seamless_down_blocks:
+                if block_num >= len(model.down_blocks) - skipped_layers:
+                    continue
+
+                if resnet_num > 0 and skip_second_resnet:
+                    continue
+
+                if submodule_name == "conv2" and skip_conv2:
+                    continue
+
+            m.asymmetric_padding_mode = {}
+            m.asymmetric_padding = {}
+            m.asymmetric_padding_mode["x"] = "circular" if ("x" in axes) else "constant"
+            m.asymmetric_padding["x"] = (
+                m._reversed_padding_repeated_twice[0],
+                m._reversed_padding_repeated_twice[1],
+                0,
+                0,
+            )
+            m.asymmetric_padding_mode["y"] = "circular" if ("y" in axes) else "constant"
+            m.asymmetric_padding["y"] = (
+                0,
+                0,
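
For reference, the rewritten set_seamless patches every Conv2d/ConvTranspose2d for the requested axes, except the UNet down-block resnet convolutions excluded by the three new options; module names such as down_blocks.1.resnets.1.conv1 encode the block index, resnet index, and submodule. A standalone sketch of that filter follows; should_patch is a hypothetical helper, and num_down_blocks=4 is an assumption (typical for a Stable Diffusion UNet2DConditionModel), not something taken from the diff.

    # Standalone sketch of the down-block filter in the reworked set_seamless().
    def should_patch(m_name: str, skipped_layers: int, skip_second_resnet: bool,
                     skip_conv2: bool, num_down_blocks: int = 4) -> bool:
        if m_name.startswith("down_blocks.") and ".resnets." in m_name:
            # e.g. "down_blocks.1.resnets.1.conv1"
            _, block_num, _, resnet_num, submodule_name = m_name.split(".")
            block_num = int(block_num)
            resnet_num = int(resnet_num)
            if block_num >= num_down_blocks - skipped_layers:
                return False  # down block excluded via skipped_layers
            if resnet_num > 0 and skip_second_resnet:
                return False  # second resnet in the block is skipped
            if submodule_name == "conv2" and skip_conv2:
                return False  # conv2 in the resnet is skipped
        return True  # everything else (VAE convs, up blocks, conv1, ...) is patched

    # With skipped_layers=1, skip_second_resnet=True, skip_conv2=True:
    print(should_patch("down_blocks.0.resnets.0.conv1", 1, True, True))  # True
    print(should_patch("down_blocks.0.resnets.1.conv1", 1, True, True))  # False (second resnet)
    print(should_patch("down_blocks.0.resnets.0.conv2", 1, True, True))  # False (conv2)
    print(should_patch("down_blocks.3.resnets.0.conv1", 1, True, True))  # False (block skipped)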