This commit is contained in:
Kent Keirsey
2023-08-27 14:53:57 -04:00
parent 0d2e194213
commit ea40a7844a
3 changed files with 48 additions and 5 deletions

View File

@ -33,7 +33,7 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.models import BaseModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.seamless import set_unet_seamless
from ...backend.model_management.seamless import set_unet_seamless, set_vae_seamless
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
@ -491,7 +491,7 @@ class LatentsToImageInvocation(BaseInvocation):
context=context,
)
with vae_info as vae:
with set_vae_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)

View File

@ -46,6 +46,7 @@ class ClipField(BaseModel):
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description="Axes(\"x\" and \"y\") to which apply seamless")
class ModelLoaderOutput(BaseInvocationOutput):
@ -398,6 +399,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
# Outputs
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@title("Seamless")
@tags("seamless", "model")
@ -410,7 +412,9 @@ class SeamlessModeInvocation(BaseInvocation):
unet: UNetField = InputField(
description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
)
vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
)
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
@ -418,6 +422,7 @@ class SeamlessModeInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
unet = copy.deepcopy(self.unet)
vae = copy.deepcopy(self.vae)
seamless_axes_list = []
@ -427,7 +432,9 @@ class SeamlessModeInvocation(BaseInvocation):
seamless_axes_list.append('y')
unet.seamless_axes = seamless_axes_list
vae.seamless_axes = seamless_axes_list
return SeamlessModeOutput(
unet=unet,
vae=vae
)