mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add resize and scale latents nodes
- this resize/scale latents is what is needed for hires fix - also remove unused `seed` from t2l
This commit is contained in:
parent
00a0cb3403
commit
3e80eaa342
@ -146,7 +146,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
@ -363,9 +362,87 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
context.services.images.save(image_type, image_name, image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image
|
||||
image_type=image_type, image_name=image_name, image=image
|
||||
)
|
||||
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal[
|
||||
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
||||
]
|
||||
|
||||
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height."""
|
||||
|
||||
type: Literal["lresize"] = "lresize"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to")
|
||||
downsample: int = Field(
|
||||
default=8, ge=1, description="The downsampling factor (leave at 8 for SD)"
|
||||
)
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||
default="bilinear", description="The interpolation mode"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
size=(
|
||||
self.height // self.downsample,
|
||||
self.width // self.downsample,
|
||||
),
|
||||
mode=self.mode,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
||||
|
||||
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
type: Literal["lscale"] = "lscale"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
scale: int = Field(
|
||||
default=2, ge=1, description="The factor by which to scale the latents"
|
||||
)
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||
default="bilinear", description="The interpolation mode"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
(_, _, h, w) = latents.size()
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
size=(
|
||||
h * self.scale,
|
||||
w * self.scale,
|
||||
),
|
||||
mode=self.mode,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
||||
|
Loading…
Reference in New Issue
Block a user