mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add more seamless configuration options.
This commit is contained in:
committed by
Ryan Dick
parent
cb7e56a9a3
commit
7f3be627c2
@ -719,7 +719,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
with (
|
with (
|
||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
|
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
|
||||||
set_seamless(unet_info.context.model, self.unet.seamless_axes),
|
set_seamless(unet_info.context.model, **self.unet.seamless.dict()),
|
||||||
unet_info as unet,
|
unet_info as unet,
|
||||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||||
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||||
@ -826,7 +826,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
with set_seamless(vae_info.context.model, **self.vae.seamless.dict()), vae_info as vae:
|
||||||
latents = latents.to(vae.device)
|
latents = latents.to(vae.device)
|
||||||
if self.fp32:
|
if self.fp32:
|
||||||
vae.to(dtype=torch.float32)
|
vae.to(dtype=torch.float32)
|
||||||
|
@ -19,6 +19,13 @@ from .baseinvocation import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SeamlessSettings(BaseModel):
|
||||||
|
axes: List[str] = Field(description="Axes('x' and 'y') to which apply seamless")
|
||||||
|
skipped_layers: int = Field(description="How much down layers skip when applying seamless")
|
||||||
|
skip_second_resnet: bool = Field(description="Skip or not second resnet in down blocks when applying seamless")
|
||||||
|
skip_conv2: bool = Field(description="Skip or not conv2 in down blocks when applying seamless")
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
model_name: str = Field(description="Info to load submodel")
|
model_name: str = Field(description="Info to load submodel")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
@ -36,8 +43,8 @@ class UNetField(BaseModel):
|
|||||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
|
||||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
||||||
|
seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")
|
||||||
|
|
||||||
|
|
||||||
class ClipField(BaseModel):
|
class ClipField(BaseModel):
|
||||||
@ -50,7 +57,7 @@ class ClipField(BaseModel):
|
|||||||
class VaeField(BaseModel):
|
class VaeField(BaseModel):
|
||||||
# TODO: better naming?
|
# TODO: better naming?
|
||||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("unet_output")
|
@invocation_output("unet_output")
|
||||||
@ -451,6 +458,11 @@ class SeamlessModeInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
||||||
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||||
|
skipped_layers: int = InputField(default=0, input=Input.Any, description="How much model's down layers to skip")
|
||||||
|
skip_second_resnet: bool = InputField(
|
||||||
|
default=True, input=Input.Any, description="Skip or not second resnet in down layers"
|
||||||
|
)
|
||||||
|
skip_conv2: bool = InputField(default=True, input=Input.Any, description="Skip or not conv2 in down layers")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
|
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
|
||||||
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
|
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
|
||||||
@ -465,9 +477,19 @@ class SeamlessModeInvocation(BaseInvocation):
|
|||||||
seamless_axes_list.append("y")
|
seamless_axes_list.append("y")
|
||||||
|
|
||||||
if unet is not None:
|
if unet is not None:
|
||||||
unet.seamless_axes = seamless_axes_list
|
unet.seamless = SeamlessSettings(
|
||||||
|
axes=seamless_axes_list,
|
||||||
|
skipped_layers=self.skipped_layers,
|
||||||
|
skip_second_resnet=self.skip_second_resnet,
|
||||||
|
skip_conv2=self.skip_conv2,
|
||||||
|
)
|
||||||
if vae is not None:
|
if vae is not None:
|
||||||
vae.seamless_axes = seamless_axes_list
|
vae.seamless = SeamlessSettings(
|
||||||
|
axes=seamless_axes_list,
|
||||||
|
skipped_layers=self.skipped_layers,
|
||||||
|
skip_second_resnet=self.skip_second_resnet,
|
||||||
|
skip_conv2=self.skip_conv2,
|
||||||
|
)
|
||||||
|
|
||||||
return SeamlessModeOutput(unet=unet, vae=vae)
|
return SeamlessModeOutput(unet=unet, vae=vae)
|
||||||
|
|
||||||
|
@ -25,62 +25,46 @@ def _conv_forward_asymmetric(self, input, weight, bias):
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
|
def set_seamless(
|
||||||
|
model: Union[UNet2DConditionModel, AutoencoderKL],
|
||||||
|
axes: List[str],
|
||||||
|
skipped_layers: int,
|
||||||
|
skip_second_resnet: bool,
|
||||||
|
skip_conv2: bool,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
to_restore = []
|
to_restore = []
|
||||||
|
|
||||||
for m_name, m in model.named_modules():
|
for m_name, m in model.named_modules():
|
||||||
if isinstance(model, UNet2DConditionModel):
|
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||||
if ".attentions." in m_name:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if ".resnets." in m_name:
|
if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
|
||||||
if ".conv2" in m_name:
|
# down_blocks.1.resnets.1.conv1
|
||||||
continue
|
_, block_num, _, resnet_num, submodule_name = m_name.split(".")
|
||||||
if ".conv_shortcut" in m_name:
|
block_num = int(block_num)
|
||||||
|
resnet_num = int(resnet_num)
|
||||||
|
|
||||||
|
# if block_num >= seamless_down_blocks:
|
||||||
|
if block_num >= len(model.down_blocks) - skipped_layers:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
"""
|
if resnet_num > 0 and skip_second_resnet:
|
||||||
if isinstance(model, UNet2DConditionModel):
|
|
||||||
if False and ".upsamplers." in m_name:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if False and ".downsamplers." in m_name:
|
if submodule_name == "conv2" and skip_conv2:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if True and ".resnets." in m_name:
|
|
||||||
if True and ".conv1" in m_name:
|
|
||||||
if False and "down_blocks" in m_name:
|
|
||||||
continue
|
|
||||||
if False and "mid_block" in m_name:
|
|
||||||
continue
|
|
||||||
if False and "up_blocks" in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if True and ".conv2" in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if True and ".conv_shortcut" in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if True and ".attentions." in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if False and m_name in ["conv_in", "conv_out"]:
|
|
||||||
continue
|
|
||||||
"""
|
|
||||||
|
|
||||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
|
||||||
m.asymmetric_padding_mode = {}
|
m.asymmetric_padding_mode = {}
|
||||||
m.asymmetric_padding = {}
|
m.asymmetric_padding = {}
|
||||||
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
m.asymmetric_padding_mode["x"] = "circular" if ("x" in axes) else "constant"
|
||||||
m.asymmetric_padding["x"] = (
|
m.asymmetric_padding["x"] = (
|
||||||
m._reversed_padding_repeated_twice[0],
|
m._reversed_padding_repeated_twice[0],
|
||||||
m._reversed_padding_repeated_twice[1],
|
m._reversed_padding_repeated_twice[1],
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
|
m.asymmetric_padding_mode["y"] = "circular" if ("y" in axes) else "constant"
|
||||||
m.asymmetric_padding["y"] = (
|
m.asymmetric_padding["y"] = (
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
Reference in New Issue
Block a user