From 421f5b7d75f4cd1eba0e7725eba1d5471905e4c0 Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Mon, 28 Aug 2023 08:43:08 -0400 Subject: [PATCH] Seamless Updates --- invokeai/app/invocations/model.py | 10 ++++++---- invokeai/backend/model_management/seamless.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 89b292d223..a2cc4eb349 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -410,8 +410,8 @@ class SeamlessModeInvocation(BaseInvocation): type: Literal["seamless"] = "seamless" # Inputs - unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet") - vae: VaeField = InputField(description=FieldDescriptions.vae_model, input=Input.Any, title="VAE") + unet: Optional[UNetField] = InputField(default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet") + vae: Optional[VaeField] = InputField(default=None, description=FieldDescriptions.vae_model, input=Input.Any, title="VAE") seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") @@ -427,7 +427,9 @@ class SeamlessModeInvocation(BaseInvocation): if self.seamless_y: seamless_axes_list.append("y") - unet.seamless_axes = seamless_axes_list - vae.seamless_axes = seamless_axes_list + if unet is not None: + unet.seamless_axes = seamless_axes_list + if vae is not None: + vae.seamless_axes = seamless_axes_list return SeamlessModeOutput(unet=unet, vae=vae) diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management/seamless.py index 1801c6e057..ec81ed9a74 100644 --- a/invokeai/backend/model_management/seamless.py +++ b/invokeai/backend/model_management/seamless.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TypeVar +from typing import Union import diffusers import torch.nn as nn from diffusers.models import UNet2DModel, AutoencoderKL @@ -24,7 +24,7 @@ def _conv_forward_asymmetric(self, input, weight, bias): ) -ModelType = TypeVar("ModelType", UNet2DModel, AutoencoderKL) +ModelType = Union[UNet2DModel, AutoencoderKL] @contextmanager