feat(nodes): add resize and scale latents nodes

- these resize/scale latents nodes are what is needed for hires fix
- also remove the unused `seed` field from t2l (`TextToLatentsInvocation`)
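
For hires fix, the idea is to run a first t2l pass at low resolution, upscale the resulting latents with one of these nodes, then denoise again at the higher resolution. A minimal sketch of constructing the new nodes (the node ids and the upstream latents name are illustrative, not part of this commit):

    # Hypothetical wiring: "first_pass" names latents saved by an earlier node.
    resize = ResizeLatentsInvocation(
        id="resize_1",  # hypothetical node id
        latents=LatentsField(latents_name="first_pass"),
        width=1024,     # target *image* size; the node divides by `downsample`
        height=1024,
        downsample=8,   # leave at 8 for SD, per the field description
        mode="bilinear",
    )

    # Or scale relative to the current size instead of to explicit dimensions:
    scale = ScaleLatentsInvocation(
        id="scale_1",   # hypothetical node id
        latents=LatentsField(latents_name="first_pass"),
        scale=2,
        mode="bilinear",
    )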
psychedelicious 2023-04-24 22:07:53 +10:00
parent 00a0cb3403
commit 3e80eaa342

@@ -146,7 +146,6 @@ class TextToLatentsInvocation(BaseInvocation):
     # TODO: consider making prompt optional to enable providing prompt through a link
     # fmt: off
     prompt: Optional[str] = Field(description="The prompt to generate an image from")
-    seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
@@ -363,9 +362,87 @@ class LatentsToImageInvocation(BaseInvocation):
                 session_id=context.graph_execution_state_id, node=self
             )
             torch.cuda.empty_cache()
             context.services.images.save(image_type, image_name, image, metadata)
             return build_image_output(
-                image_type=image_type,
-                image_name=image_name,
-                image=image
+                image_type=image_type, image_name=image_name, image=image
             )
+
+
+LATENTS_INTERPOLATION_MODE = Literal[
+    "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
+]
+
+
+class ResizeLatentsInvocation(BaseInvocation):
+    """Resizes latents to explicit width/height."""
+
+    type: Literal["lresize"] = "lresize"
+
+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to resize")
+    width: int = Field(ge=64, multiple_of=8, description="The width to resize to")
+    height: int = Field(ge=64, multiple_of=8, description="The height to resize to")
+    downsample: int = Field(
+        default=8, ge=1, description="The downsampling factor (leave at 8 for SD)"
+    )
+    mode: LATENTS_INTERPOLATION_MODE = Field(
+        default="bilinear", description="The interpolation mode"
+    )
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.services.latents.get(self.latents.latents_name)
+
+        # resizing
+        resized_latents = torch.nn.functional.interpolate(
+            latents,
+            size=(
+                self.height // self.downsample,
+                self.width // self.downsample,
+            ),
+            mode=self.mode,
+        )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, resized_latents)
+        return LatentsOutput(latents=LatentsField(latents_name=name))
+
+
+class ScaleLatentsInvocation(BaseInvocation):
+    """Scales latents by a given factor."""
+
+    type: Literal["lscale"] = "lscale"
+
+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to scale")
+    scale: int = Field(
+        default=2, ge=1, description="The factor by which to scale the latents"
+    )
+    mode: LATENTS_INTERPOLATION_MODE = Field(
+        default="bilinear", description="The interpolation mode"
+    )
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.services.latents.get(self.latents.latents_name)
+
+        (_, _, h, w) = latents.size()
+
+        # resizing
+        resized_latents = torch.nn.functional.interpolate(
+            latents,
+            size=(
+                h * self.scale,
+                w * self.scale,
+            ),
+            mode=self.mode,
+        )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, resized_latents)
+        return LatentsOutput(latents=LatentsField(latents_name=name))
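
As a sanity check on the shape math, here is a minimal standalone sketch of the underlying torch call, assuming a typical SD latent batch of shape (1, 4, 64, 64) (i.e. a 512x512 image with the default downsample of 8):

    import torch
    import torch.nn.functional as F

    latents = torch.randn(1, 4, 64, 64)  # (batch, channels, height//8, width//8)

    # What ScaleLatentsInvocation does with scale=2: double both spatial dims.
    scaled = F.interpolate(latents, size=(64 * 2, 64 * 2), mode="bilinear")
    print(scaled.shape)  # torch.Size([1, 4, 128, 128])

    # What ResizeLatentsInvocation does with width=768, height=768, downsample=8.
    resized = F.interpolate(latents, size=(768 // 8, 768 // 8), mode="bilinear")
    print(resized.shape)  # torch.Size([1, 4, 96, 96])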