mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: Black lint
This commit is contained in:
parent
b9731cb434
commit
3ef36707a8
@ -33,7 +33,7 @@ class UNetField(BaseModel):
|
|||||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description="Axes(\"x\" and \"y\") to which apply seamless")
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
|
|
||||||
|
|
||||||
class ClipField(BaseModel):
|
class ClipField(BaseModel):
|
||||||
@ -46,7 +46,7 @@ class ClipField(BaseModel):
|
|||||||
class VaeField(BaseModel):
|
class VaeField(BaseModel):
|
||||||
# TODO: better naming?
|
# TODO: better naming?
|
||||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description="Axes(\"x\" and \"y\") to which apply seamless")
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
|
|
||||||
|
|
||||||
class ModelLoaderOutput(BaseInvocationOutput):
|
class ModelLoaderOutput(BaseInvocationOutput):
|
||||||
@ -401,6 +401,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
|||||||
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
|
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@title("Seamless")
|
@title("Seamless")
|
||||||
@tags("seamless", "model")
|
@tags("seamless", "model")
|
||||||
class SeamlessModeInvocation(BaseInvocation):
|
class SeamlessModeInvocation(BaseInvocation):
|
||||||
@ -409,32 +410,24 @@ class SeamlessModeInvocation(BaseInvocation):
|
|||||||
type: Literal["seamless"] = "seamless"
|
type: Literal["seamless"] = "seamless"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
unet: UNetField = InputField(
|
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet")
|
||||||
description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
vae: VaeField = InputField(description=FieldDescriptions.vae_model, input=Input.Any, title="VAE")
|
||||||
)
|
|
||||||
vae: VaeField = InputField(
|
|
||||||
description=FieldDescriptions.vae_model, input=Input.Any, title="VAE"
|
|
||||||
)
|
|
||||||
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
||||||
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||||
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
|
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
|
||||||
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
|
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
|
||||||
unet = copy.deepcopy(self.unet)
|
unet = copy.deepcopy(self.unet)
|
||||||
vae = copy.deepcopy(self.vae)
|
vae = copy.deepcopy(self.vae)
|
||||||
|
|
||||||
seamless_axes_list = []
|
seamless_axes_list = []
|
||||||
|
|
||||||
if self.seamless_x:
|
if self.seamless_x:
|
||||||
seamless_axes_list.append('x')
|
seamless_axes_list.append("x")
|
||||||
if self.seamless_y:
|
if self.seamless_y:
|
||||||
seamless_axes_list.append('y')
|
seamless_axes_list.append("y")
|
||||||
|
|
||||||
unet.seamless_axes = seamless_axes_list
|
unet.seamless_axes = seamless_axes_list
|
||||||
vae.seamless_axes = seamless_axes_list
|
vae.seamless_axes = seamless_axes_list
|
||||||
|
|
||||||
return SeamlessModeOutput(
|
return SeamlessModeOutput(unet=unet, vae=vae)
|
||||||
unet=unet,
|
|
||||||
vae=vae
|
|
||||||
)
|
|
||||||
|
@ -6,6 +6,7 @@ import diffusers
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from diffusers.models import UNet2DModel, AutoencoderKL
|
from diffusers.models import UNet2DModel, AutoencoderKL
|
||||||
|
|
||||||
|
|
||||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||||
"""
|
"""
|
||||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||||
@ -23,13 +24,14 @@ def _conv_forward_asymmetric(self, input, weight, bias):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
ModelType = TypeVar('ModelType', UNet2DModel, AutoencoderKL)
|
ModelType = TypeVar("ModelType", UNet2DModel, AutoencoderKL)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_seamless(model: ModelType, seamless_axes):
|
def set_seamless(model: ModelType, seamless_axes):
|
||||||
try:
|
try:
|
||||||
to_restore = []
|
to_restore = []
|
||||||
|
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||||
m.asymmetric_padding_mode = {}
|
m.asymmetric_padding_mode = {}
|
||||||
@ -64,4 +66,6 @@ def set_seamless(model: ModelType, seamless_axes):
|
|||||||
if hasattr(m, "asymmetric_padding"):
|
if hasattr(m, "asymmetric_padding"):
|
||||||
del m.asymmetric_padding
|
del m.asymmetric_padding
|
||||||
if isinstance(m, diffusers.models.lora.LoRACompatibleConv):
|
if isinstance(m, diffusers.models.lora.LoRACompatibleConv):
|
||||||
m.forward = diffusers.models.lora.LoRACompatibleConv.forward.__get__(m,diffusers.models.lora.LoRACompatibleConv)
|
m.forward = diffusers.models.lora.LoRACompatibleConv.forward.__get__(
|
||||||
|
m, diffusers.models.lora.LoRACompatibleConv
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user