Mirror of https://github.com/invoke-ai/InvokeAI
feat(nodes): add w/h to latents outputs (#3389)
This reduces the number of nodes needed when working with latents (i.e. fewer plain integer value nodes). It also corrects a few mistakes in the fields.
Commit e559730b6e
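To illustrate the "fewer plain integer value nodes" point from the commit message: once width and height ride along on the latents/noise outputs, a downstream node can take its dimensions straight from the upstream output instead of from separate integer value nodes wired in parallel. The sketch below is purely illustrative -- the node IDs, field names, and edge layout are hypothetical, not the exact InvokeAI graph schema:

# Hypothetical wiring, for illustration only (not the real InvokeAI graph schema).

# Before: a shared integer value node has to feed "width" to both the noise
# node and any later consumer that needs the same dimension.
edges_before = [
    {"source": ("width_int_node", "a"), "destination": ("noise", "width")},
    {"source": ("width_int_node", "a"), "destination": ("resize", "width")},
]

# After: the consumer reads "width" directly from the noise node's output.
edges_after = [
    {"source": ("noise", "width"), "destination": ("resize", "width")},
]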
@@ -41,17 +41,40 @@ class LatentsField(BaseModel):
 class LatentsOutput(BaseInvocationOutput):
     """Base class for invocations that output latents"""
     #fmt: off
-    type: Literal["latent_output"] = "latent_output"
+    type: Literal["latents_output"] = "latents_output"

+    # Inputs
     latents: LatentsField = Field(default=None, description="The output latents")
+    width: int = Field(description="The width of the latents in pixels")
+    height: int = Field(description="The height of the latents in pixels")
     #fmt: on
+
+
+def build_latents_output(latents_name: str, latents: torch.Tensor):
+    return LatentsOutput(
+        latents=LatentsField(latents_name=latents_name),
+        width=latents.size()[3] * 8,
+        height=latents.size()[2] * 8,
+    )


 class NoiseOutput(BaseInvocationOutput):
     """Invocation noise output"""
     #fmt: off
     type: Literal["noise_output"] = "noise_output"

+    # Inputs
     noise: LatentsField = Field(default=None, description="The output noise")
+    width: int = Field(description="The width of the noise in pixels")
+    height: int = Field(description="The height of the noise in pixels")
     #fmt: on
+
+
+def build_noise_output(latents_name: str, latents: torch.Tensor):
+    return NoiseOutput(
+        noise=LatentsField(latents_name=latents_name),
+        width=latents.size()[3] * 8,
+        height=latents.size()[2] * 8,
+    )


 SAMPLER_NAME_VALUES = Literal[
     tuple(list(SCHEDULER_MAP.keys()))
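The new build_latents_output / build_noise_output helpers derive the pixel dimensions from the latent tensor itself. A minimal sketch of that arithmetic, assuming the usual Stable Diffusion latent layout of [batch, channels, height / 8, width / 8]:

import torch

# Dummy latents in the usual SD layout: [batch, channels, height / 8, width / 8].
latents = torch.zeros(1, 4, 64, 96)

# Same arithmetic as the helpers above: the VAE downsamples by a factor of 8,
# so pixel dimensions are the spatial latent dimensions multiplied by 8.
width = latents.size()[3] * 8   # 96 * 8 = 768
height = latents.size()[2] * 8  # 64 * 8 = 512

assert (width, height) == (768, 512)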
@@ -122,9 +145,7 @@ class NoiseInvocation(BaseInvocation):

         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, noise)
-        return NoiseOutput(
-            noise=LatentsField(latents_name=name)
-        )
+        return build_noise_output(latents_name=name, latents=noise)


 # Text to image
@@ -240,9 +261,7 @@ class TextToLatentsInvocation(BaseInvocation):

         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, result_latents)
-        return LatentsOutput(
-            latents=LatentsField(latents_name=name)
-        )
+        return build_latents_output(latents_name=name, latents=result_latents)


 class LatentsToLatentsInvocation(TextToLatentsInvocation):
@@ -301,9 +320,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):

         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, result_latents)
-        return LatentsOutput(
-            latents=LatentsField(latents_name=name)
-        )
+        return build_latents_output(latents_name=name, latents=result_latents)


 # Latent to image
@@ -370,8 +387,8 @@ class ResizeLatentsInvocation(BaseInvocation):
     latents: Optional[LatentsField] = Field(description="The latents to resize")
     width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
     height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
-    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
-    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
+    antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
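The mode and antialias changes above are part of the "mistakes in the fields" mentioned in the commit message: annotating a field as Optional[...] lets callers pass an explicit None even though a usable default exists, and that None would then flow into the resize call. A minimal sketch of the difference, using stand-in pydantic models rather than the real invocation classes (the literal values here are illustrative):

from typing import Literal, Optional

from pydantic import BaseModel, Field, ValidationError

MODE = Literal["nearest", "bilinear", "bicubic"]  # stand-in for LATENTS_INTERPOLATION_MODE

class LooseFields(BaseModel):
    # Optional[...] accepts an explicit None despite the default.
    mode: Optional[MODE] = Field(default="bilinear")

class StrictFields(BaseModel):
    # Without Optional, None is rejected; the default still applies when the field is omitted.
    mode: MODE = Field(default="bilinear")

print(LooseFields(mode=None).mode)  # None
print(StrictFields().mode)          # bilinear
try:
    StrictFields(mode=None)
except ValidationError:
    print("explicit None rejected")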
@@ -388,7 +405,7 @@ class ResizeLatentsInvocation(BaseInvocation):

         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, resized_latents)
-        return LatentsOutput(latents=LatentsField(latents_name=name))
+        return build_latents_output(latents_name=name, latents=resized_latents)


 class ScaleLatentsInvocation(BaseInvocation):
@@ -399,8 +416,8 @@ class ScaleLatentsInvocation(BaseInvocation):
     # Inputs
     latents: Optional[LatentsField] = Field(description="The latents to scale")
     scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
-    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
-    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
+    antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
@@ -418,7 +435,7 @@ class ScaleLatentsInvocation(BaseInvocation):

         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, resized_latents)
-        return LatentsOutput(latents=LatentsField(latents_name=name))
+        return build_latents_output(latents_name=name, latents=resized_latents)


 class ImageToLatentsInvocation(BaseInvocation):
@@ -462,4 +479,4 @@ class ImageToLatentsInvocation(BaseInvocation):

         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, latents)
-        return LatentsOutput(latents=LatentsField(latents_name=name))
+        return build_latents_output(latents_name=name, latents=latents)
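With the last call site switched over, every latents-producing invocation returns the same enriched output. A quick sketch of what such an output carries, again using stand-in models that mirror the fields in this diff rather than the real InvokeAI classes:

from typing import Literal, Optional

from pydantic import BaseModel, Field

# Stand-ins mirroring the fields above; not the real InvokeAI classes.
class LatentsField(BaseModel):
    latents_name: Optional[str] = Field(default=None, description="The name of the latents")

class LatentsOutput(BaseModel):
    type: Literal["latents_output"] = "latents_output"
    latents: LatentsField = Field(default=None, description="The output latents")
    width: int = Field(description="The width of the latents in pixels")
    height: int = Field(description="The height of the latents in pixels")

out = LatentsOutput(latents=LatentsField(latents_name="session__node"), width=512, height=512)
print(out.dict())  # .dict() is pydantic v1 style, as in the diff; use .model_dump() on v2
# {'type': 'latents_output', 'latents': {'latents_name': 'session__node'}, 'width': 512, 'height': 512}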