mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add w/h to latents outputs
This reduces the number of nodes needed when working with latents (ie fewer plain integer value nodes) Also correct a few mistakes in the fields
This commit is contained in:
parent
4333852c37
commit
93ced0bec6
@ -40,17 +40,40 @@ class LatentsField(BaseModel):
|
|||||||
class LatentsOutput(BaseInvocationOutput):
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output latents"""
|
"""Base class for invocations that output latents"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["latent_output"] = "latent_output"
|
type: Literal["latents_output"] = "latents_output"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
latents: LatentsField = Field(default=None, description="The output latents")
|
latents: LatentsField = Field(default=None, description="The output latents")
|
||||||
|
width: int = Field(description="The width of the latents in pixels")
|
||||||
|
height: int = Field(description="The height of the latents in pixels")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||||
|
return LatentsOutput(
|
||||||
|
latents=LatentsField(latents_name=latents_name),
|
||||||
|
width=latents.size()[3] * 8,
|
||||||
|
height=latents.size()[2] * 8,
|
||||||
|
)
|
||||||
|
|
||||||
class NoiseOutput(BaseInvocationOutput):
|
class NoiseOutput(BaseInvocationOutput):
|
||||||
"""Invocation noise output"""
|
"""Invocation noise output"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["noise_output"] = "noise_output"
|
type: Literal["noise_output"] = "noise_output"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
noise: LatentsField = Field(default=None, description="The output noise")
|
noise: LatentsField = Field(default=None, description="The output noise")
|
||||||
|
width: int = Field(description="The width of the noise in pixels")
|
||||||
|
height: int = Field(description="The height of the noise in pixels")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
def build_noise_output(latents_name: str, latents: torch.Tensor):
|
||||||
|
return NoiseOutput(
|
||||||
|
noise=LatentsField(latents_name=latents_name),
|
||||||
|
width=latents.size()[3] * 8,
|
||||||
|
height=latents.size()[2] * 8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: this seems like a hack
|
# TODO: this seems like a hack
|
||||||
scheduler_map = dict(
|
scheduler_map = dict(
|
||||||
@ -130,9 +153,7 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.set(name, noise)
|
context.services.latents.set(name, noise)
|
||||||
return NoiseOutput(
|
return build_noise_output(latents_name=name, latents=noise)
|
||||||
noise=LatentsField(latents_name=name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
@ -248,9 +269,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.set(name, result_latents)
|
context.services.latents.set(name, result_latents)
|
||||||
return LatentsOutput(
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
latents=LatentsField(latents_name=name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||||
@ -313,9 +332,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.set(name, result_latents)
|
context.services.latents.set(name, result_latents)
|
||||||
return LatentsOutput(
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
latents=LatentsField(latents_name=name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Latent to image
|
# Latent to image
|
||||||
@ -382,8 +399,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -400,7 +417,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.set(name, resized_latents)
|
context.services.latents.set(name, resized_latents)
|
||||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||||
|
|
||||||
|
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
@ -411,8 +428,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -430,7 +447,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.set(name, resized_latents)
|
context.services.latents.set(name, resized_latents)
|
||||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||||
|
|
||||||
|
|
||||||
class ImageToLatentsInvocation(BaseInvocation):
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
@ -474,4 +491,4 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.set(name, latents)
|
context.services.latents.set(name, latents)
|
||||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
return build_latents_output(latents_name=name, latents=latents)
|
||||||
|
Loading…
Reference in New Issue
Block a user