mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
updates
This commit is contained in:
parent
95883c2efd
commit
3de45af734
@ -33,6 +33,7 @@ class UNetField(BaseModel):
|
|||||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||||
|
seamless_axes: List[str] = Field(default_factory=list, description="Axes(\"x\" and \"y\") to which apply seamless")
|
||||||
|
|
||||||
|
|
||||||
class ClipField(BaseModel):
|
class ClipField(BaseModel):
|
||||||
@ -388,3 +389,45 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SeamlessModeOutput(BaseInvocationOutput):
|
||||||
|
"""Modified Seamless Model output"""
|
||||||
|
|
||||||
|
type: Literal["seamless_output"] = "seamless_output"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
|
|
||||||
|
@title("Seamless")
|
||||||
|
@tags("seamless", "model")
|
||||||
|
class SeamlessModeInvocation(BaseInvocation):
|
||||||
|
"""Apply seamless mode to unet."""
|
||||||
|
|
||||||
|
type: Literal["seamless"] = "seamless"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
unet: UNetField = InputField(
|
||||||
|
description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||||
|
)
|
||||||
|
|
||||||
|
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
||||||
|
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||||
|
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
|
||||||
|
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
|
||||||
|
unet = copy.deepcopy(self.unet)
|
||||||
|
|
||||||
|
seamless_axes_list = []
|
||||||
|
|
||||||
|
if self.seamless_x:
|
||||||
|
seamless_axes_list.append('x')
|
||||||
|
if self.seamless_y:
|
||||||
|
seamless_axes_list.append('y')
|
||||||
|
|
||||||
|
unet.seamless_axes = seamless_axes_list
|
||||||
|
|
||||||
|
return SeamlessModeOutput(
|
||||||
|
unet=unet,
|
||||||
|
)
|
60
invokeai/backend/model_management/seamless.py
Normal file
60
invokeai/backend/model_management/seamless.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from diffusers.models import UNet2DModel
|
||||||
|
|
||||||
|
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||||
|
"""
|
||||||
|
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||||
|
"""
|
||||||
|
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
|
||||||
|
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
|
||||||
|
return nn.functional.conv2d(
|
||||||
|
working,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
self.stride,
|
||||||
|
nn.modules.utils._pair(0),
|
||||||
|
self.dilation,
|
||||||
|
self.groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def set_unet_seamless(model: UNet2DModel, seamless: bool, seamless_axes):
|
||||||
|
try:
|
||||||
|
to_restore = dict()
|
||||||
|
if seamless:
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||||
|
m.asymmetric_padding_mode = {}
|
||||||
|
m.asymmetric_padding = {}
|
||||||
|
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
||||||
|
m.asymmetric_padding["x"] = (
|
||||||
|
m._reversed_padding_repeated_twice[0],
|
||||||
|
m._reversed_padding_repeated_twice[1],
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
|
||||||
|
m.asymmetric_padding["y"] = (
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
m._reversed_padding_repeated_twice[2],
|
||||||
|
m._reversed_padding_repeated_twice[3],
|
||||||
|
)
|
||||||
|
|
||||||
|
to_restore.append((m, m._conv_forward))
|
||||||
|
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
finally:
|
||||||
|
for module, orig_conv_forward in to_restore:
|
||||||
|
module._conv_forward = orig_conv_forward
|
||||||
|
if hasattr(m, "asymmetric_padding_mode"):
|
||||||
|
del m.asymmetric_padding_mode
|
||||||
|
if hasattr(m, "asymmetric_padding"):
|
||||||
|
del m.asymmetric_padding
|
Loading…
Reference in New Issue
Block a user