diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 825847cf79..ac7139d031 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -41,17 +41,40 @@ class LatentsField(BaseModel): class LatentsOutput(BaseInvocationOutput): """Base class for invocations that output latents""" #fmt: off - type: Literal["latent_output"] = "latent_output" - latents: LatentsField = Field(default=None, description="The output latents") + type: Literal["latents_output"] = "latents_output" + + # Inputs + latents: LatentsField = Field(default=None, description="The output latents") + width: int = Field(description="The width of the latents in pixels") + height: int = Field(description="The height of the latents in pixels") #fmt: on + +def build_latents_output(latents_name: str, latents: torch.Tensor): + return LatentsOutput( + latents=LatentsField(latents_name=latents_name), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) + class NoiseOutput(BaseInvocationOutput): """Invocation noise output""" #fmt: off - type: Literal["noise_output"] = "noise_output" + type: Literal["noise_output"] = "noise_output" + + # Inputs noise: LatentsField = Field(default=None, description="The output noise") + width: int = Field(description="The width of the noise in pixels") + height: int = Field(description="The height of the noise in pixels") #fmt: on +def build_noise_output(latents_name: str, latents: torch.Tensor): + return NoiseOutput( + noise=LatentsField(latents_name=latents_name), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) + SAMPLER_NAME_VALUES = Literal[ tuple(list(SCHEDULER_MAP.keys())) @@ -122,9 +145,7 @@ class NoiseInvocation(BaseInvocation): name = f'{context.graph_execution_state_id}__{self.id}' context.services.latents.set(name, noise) - return NoiseOutput( - noise=LatentsField(latents_name=name) - ) + return build_noise_output(latents_name=name, latents=noise) # Text to image @@ -240,9 +261,7 @@ class TextToLatentsInvocation(BaseInvocation): name = f'{context.graph_execution_state_id}__{self.id}' context.services.latents.set(name, result_latents) - return LatentsOutput( - latents=LatentsField(latents_name=name) - ) + return build_latents_output(latents_name=name, latents=result_latents) class LatentsToLatentsInvocation(TextToLatentsInvocation): @@ -301,9 +320,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): name = f'{context.graph_execution_state_id}__{self.id}' context.services.latents.set(name, result_latents) - return LatentsOutput( - latents=LatentsField(latents_name=name) - ) + return build_latents_output(latents_name=name, latents=result_latents) # Latent to image @@ -367,11 +384,11 @@ class ResizeLatentsInvocation(BaseInvocation): type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") - height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") - mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode") - antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field(description="The latents to resize") + width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") + mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") + antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -388,7 +405,7 @@ class ResizeLatentsInvocation(BaseInvocation): name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.set(name, resized_latents) - return LatentsOutput(latents=LatentsField(latents_name=name)) + return build_latents_output(latents_name=name, latents=resized_latents) class ScaleLatentsInvocation(BaseInvocation): @@ -397,10 +414,10 @@ class ScaleLatentsInvocation(BaseInvocation): type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to scale") - scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") - mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode") - antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field(description="The latents to scale") + scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") + mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") + antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -418,7 +435,7 @@ class ScaleLatentsInvocation(BaseInvocation): name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.set(name, resized_latents) - return LatentsOutput(latents=LatentsField(latents_name=name)) + return build_latents_output(latents_name=name, latents=resized_latents) class ImageToLatentsInvocation(BaseInvocation): @@ -462,4 +479,4 @@ class ImageToLatentsInvocation(BaseInvocation): name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.set(name, latents) - return LatentsOutput(latents=LatentsField(latents_name=name)) + return build_latents_output(latents_name=name, latents=latents)