Seamless Updates

This commit is contained in:
Kent Keirsey 2023-08-28 08:43:08 -04:00
parent 3ef36707a8
commit 421f5b7d75
2 changed files with 8 additions and 6 deletions

View File

@ -410,8 +410,8 @@ class SeamlessModeInvocation(BaseInvocation):
type: Literal["seamless"] = "seamless"
# Inputs
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet")
vae: VaeField = InputField(description=FieldDescriptions.vae_model, input=Input.Any, title="VAE")
unet: Optional[UNetField] = InputField(default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet")
vae: Optional[VaeField] = InputField(default=None, description=FieldDescriptions.vae_model, input=Input.Any, title="VAE")
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
@ -427,7 +427,9 @@ class SeamlessModeInvocation(BaseInvocation):
if self.seamless_y:
seamless_axes_list.append("y")
unet.seamless_axes = seamless_axes_list
vae.seamless_axes = seamless_axes_list
if unet is not None:
unet.seamless_axes = seamless_axes_list
if vae is not None:
vae.seamless_axes = seamless_axes_list
return SeamlessModeOutput(unet=unet, vae=vae)

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TypeVar
from typing import Union
import diffusers
import torch.nn as nn
from diffusers.models import UNet2DModel, AutoencoderKL
@ -24,7 +24,7 @@ def _conv_forward_asymmetric(self, input, weight, bias):
)
ModelType = TypeVar("ModelType", UNet2DModel, AutoencoderKL)
ModelType = Union[UNet2DModel, AutoencoderKL]
@contextmanager