From 93ced0bec6879494d01419d56444aa592755d7f5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 11 May 2023 20:23:02 +1000 Subject: [PATCH] feat(nodes): add w/h to latents outputs This reduces the number of nodes needed when working with latents (i.e. fewer plain integer value nodes). Also correct a few mistakes in the fields. --- invokeai/app/invocations/latent.py | 65 +++++++++++++++++++----------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 40575c1f64..3b6eb95fa0 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -40,17 +40,40 @@ class LatentsField(BaseModel): class LatentsOutput(BaseInvocationOutput): """Base class for invocations that output latents""" #fmt: off - type: Literal["latent_output"] = "latent_output" - latents: LatentsField = Field(default=None, description="The output latents") + type: Literal["latents_output"] = "latents_output" + + # Outputs + latents: LatentsField = Field(default=None, description="The output latents") + width: int = Field(description="The width of the latents in pixels") + height: int = Field(description="The height of the latents in pixels") #fmt: on + +def build_latents_output(latents_name: str, latents: torch.Tensor): + return LatentsOutput( + latents=LatentsField(latents_name=latents_name), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) + class NoiseOutput(BaseInvocationOutput): """Invocation noise output""" #fmt: off - type: Literal["noise_output"] = "noise_output" + type: Literal["noise_output"] = "noise_output" + + # Outputs noise: LatentsField = Field(default=None, description="The output noise") + width: int = Field(description="The width of the noise in pixels") + height: int = Field(description="The height of the noise in pixels") #fmt: on +def build_noise_output(latents_name: str, latents: torch.Tensor): + 
return NoiseOutput( + noise=LatentsField(latents_name=latents_name), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) + # TODO: this seems like a hack scheduler_map = dict( @@ -130,9 +153,7 @@ class NoiseInvocation(BaseInvocation): name = f'{context.graph_execution_state_id}__{self.id}' context.services.latents.set(name, noise) - return NoiseOutput( - noise=LatentsField(latents_name=name) - ) + return build_noise_output(latents_name=name, latents=noise) # Text to image @@ -248,9 +269,7 @@ class TextToLatentsInvocation(BaseInvocation): name = f'{context.graph_execution_state_id}__{self.id}' context.services.latents.set(name, result_latents) - return LatentsOutput( - latents=LatentsField(latents_name=name) - ) + return build_latents_output(latents_name=name, latents=result_latents) class LatentsToLatentsInvocation(TextToLatentsInvocation): @@ -313,9 +332,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): name = f'{context.graph_execution_state_id}__{self.id}' context.services.latents.set(name, result_latents) - return LatentsOutput( - latents=LatentsField(latents_name=name) - ) + return build_latents_output(latents_name=name, latents=result_latents) # Latent to image @@ -379,11 +396,11 @@ class ResizeLatentsInvocation(BaseInvocation): type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") - height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") - mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode") - antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field(description="The latents to resize") + width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") + 
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") + mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") + antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -400,7 +417,7 @@ class ResizeLatentsInvocation(BaseInvocation): name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.set(name, resized_latents) - return LatentsOutput(latents=LatentsField(latents_name=name)) + return build_latents_output(latents_name=name, latents=resized_latents) class ScaleLatentsInvocation(BaseInvocation): @@ -409,10 +426,10 @@ class ScaleLatentsInvocation(BaseInvocation): type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to scale") - scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") - mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode") - antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field(description="The latents to scale") + scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") + mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") + antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -430,7 +447,7 @@ class ScaleLatentsInvocation(BaseInvocation): name = f"{context.graph_execution_state_id}__{self.id}" 
context.services.latents.set(name, resized_latents) - return LatentsOutput(latents=LatentsField(latents_name=name)) + return build_latents_output(latents_name=name, latents=resized_latents) class ImageToLatentsInvocation(BaseInvocation): @@ -474,4 +491,4 @@ class ImageToLatentsInvocation(BaseInvocation): name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.set(name, latents) - return LatentsOutput(latents=LatentsField(latents_name=name)) + return build_latents_output(latents_name=name, latents=latents)