feat(nodes): add w/h to latents outputs

This reduces the number of nodes needed when working with latents (ie fewer plain integer value nodes)

Also correct a few mistakes in the fields
This commit is contained in:
psychedelicious 2023-05-11 20:23:02 +10:00
parent 4333852c37
commit 93ced0bec6

View File

@ -40,17 +40,40 @@ class LatentsField(BaseModel):
class LatentsOutput(BaseInvocationOutput): class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents""" """Base class for invocations that output latents"""
#fmt: off #fmt: off
type: Literal["latent_output"] = "latent_output" type: Literal["latents_output"] = "latents_output"
# Inputs
latents: LatentsField = Field(default=None, description="The output latents") latents: LatentsField = Field(default=None, description="The output latents")
width: int = Field(description="The width of the latents in pixels")
height: int = Field(description="The height of the latents in pixels")
#fmt: on #fmt: on
def build_latents_output(latents_name: str, latents: torch.Tensor):
return LatentsOutput(
latents=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
class NoiseOutput(BaseInvocationOutput): class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output""" """Invocation noise output"""
#fmt: off #fmt: off
type: Literal["noise_output"] = "noise_output" type: Literal["noise_output"] = "noise_output"
# Inputs
noise: LatentsField = Field(default=None, description="The output noise") noise: LatentsField = Field(default=None, description="The output noise")
width: int = Field(description="The width of the noise in pixels")
height: int = Field(description="The height of the noise in pixels")
#fmt: on #fmt: on
def build_noise_output(latents_name: str, latents: torch.Tensor):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
# TODO: this seems like a hack # TODO: this seems like a hack
scheduler_map = dict( scheduler_map = dict(
@ -130,9 +153,7 @@ class NoiseInvocation(BaseInvocation):
name = f'{context.graph_execution_state_id}__{self.id}' name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise) context.services.latents.set(name, noise)
return NoiseOutput( return build_noise_output(latents_name=name, latents=noise)
noise=LatentsField(latents_name=name)
)
# Text to image # Text to image
@ -248,9 +269,7 @@ class TextToLatentsInvocation(BaseInvocation):
name = f'{context.graph_execution_state_id}__{self.id}' name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents) context.services.latents.set(name, result_latents)
return LatentsOutput( return build_latents_output(latents_name=name, latents=result_latents)
latents=LatentsField(latents_name=name)
)
class LatentsToLatentsInvocation(TextToLatentsInvocation): class LatentsToLatentsInvocation(TextToLatentsInvocation):
@ -313,9 +332,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
name = f'{context.graph_execution_state_id}__{self.id}' name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents) context.services.latents.set(name, result_latents)
return LatentsOutput( return build_latents_output(latents_name=name, latents=result_latents)
latents=LatentsField(latents_name=name)
)
# Latent to image # Latent to image
@ -382,8 +399,8 @@ class ResizeLatentsInvocation(BaseInvocation):
latents: Optional[LatentsField] = Field(description="The latents to resize") latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode") mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
@ -400,7 +417,7 @@ class ResizeLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents) context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name)) return build_latents_output(latents_name=name, latents=resized_latents)
class ScaleLatentsInvocation(BaseInvocation): class ScaleLatentsInvocation(BaseInvocation):
@ -411,8 +428,8 @@ class ScaleLatentsInvocation(BaseInvocation):
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale") latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode") mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
@ -430,7 +447,7 @@ class ScaleLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents) context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name)) return build_latents_output(latents_name=name, latents=resized_latents)
class ImageToLatentsInvocation(BaseInvocation): class ImageToLatentsInvocation(BaseInvocation):
@ -474,4 +491,4 @@ class ImageToLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, latents) context.services.latents.set(name, latents)
return LatentsOutput(latents=LatentsField(latents_name=name)) return build_latents_output(latents_name=name, latents=latents)