feat(nodes): hardcode resize latents downsampling

commit 4a924c9b54
parent 0453d60c64
Author: psychedelicious
Date:   2023-04-27 09:59:22 +10:00


@@ -374,7 +374,7 @@ LATENTS_INTERPOLATION_MODE = Literal[
 class ResizeLatentsInvocation(BaseInvocation):
-    """Resizes latents to explicit width/height."""
+    """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
     type: Literal["lresize"] = "lresize"
@@ -382,7 +382,6 @@ class ResizeLatentsInvocation(BaseInvocation):
     latents: Optional[LatentsField] = Field(description="The latents to resize")
     width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
     height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
-    downsample: Optional[int] = Field(default=8, ge=1, description="The downsampling factor (leave at 8 for SD)")
     mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
     antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
@@ -391,10 +390,7 @@ class ResizeLatentsInvocation(BaseInvocation):
         resized_latents = torch.nn.functional.interpolate(
             latents,
-            size=(
-                self.height // self.downsample,
-                self.width // self.downsample,
-            ),
+            size=(self.height // 8, self.width // 8),
             mode=self.mode,
             antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
         )
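For context (not part of the diff): a minimal sketch of what the hardcoded divisor does. The requested pixel dimensions are floor-divided by 8 before being passed to torch.nn.functional.interpolate, matching Stable Diffusion's 8x VAE downscale. The tensor shape and pixel values below are illustrative assumptions, not code from this repository.

import torch

# Hypothetical latent tensor: [batch, channels, height // 8, width // 8].
latents = torch.randn(1, 4, 64, 64)

# Requested output size in pixels (illustrative values).
width, height = 768, 512

resized = torch.nn.functional.interpolate(
    latents,
    size=(height // 8, width // 8),  # 512 // 8 = 64, 768 // 8 = 96
    mode="bilinear",
    antialias=False,
)
print(resized.shape)  # torch.Size([1, 4, 64, 96])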
@@ -414,7 +410,7 @@ class ScaleLatentsInvocation(BaseInvocation):
     # Inputs
     latents: Optional[LatentsField] = Field(description="The latents to scale")
-    scale_factor: float = Field(ge=0, description="The factor by which to scale the latents")
+    scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
     mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
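Relatedly, the scale_factor constraint tightens from ge=0 to gt=0: a factor of exactly zero would ask interpolate for an empty output. A minimal sketch under the same illustrative assumptions as above:

import torch

latents = torch.randn(1, 4, 64, 96)  # hypothetical latent tensor

scaled = torch.nn.functional.interpolate(
    latents,
    scale_factor=0.5,  # must now be strictly positive (gt=0)
    mode="bilinear",
    antialias=False,
)
print(scaled.shape)  # torch.Size([1, 4, 32, 48])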