fix(nodes): fix slatents and rlatents bugs

This commit is contained in:
psychedelicious 2023-04-26 21:45:05 +10:00
parent c4f4f8b1b8
commit 0453d60c64

View File

@ -379,19 +379,16 @@ class ResizeLatentsInvocation(BaseInvocation):
type: Literal["lresize"] = "lresize"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to")
downsample: int = Field(
default=8, ge=1, description="The downsampling factor (leave at 8 for SD)"
)
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode"
)
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
downsample: Optional[int] = Field(default=8, ge=1, description="The downsampling factor (leave at 8 for SD)")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# resizing
resized_latents = torch.nn.functional.interpolate(
latents,
size=(
@ -399,6 +396,7 @@ class ResizeLatentsInvocation(BaseInvocation):
self.width // self.downsample,
),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@ -415,27 +413,20 @@ class ScaleLatentsInvocation(BaseInvocation):
type: Literal["lscale"] = "lscale"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
scale: int = Field(
default=2, ge=1, description="The factor by which to scale the latents"
)
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode"
)
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(ge=0, description="The factor by which to scale the latents")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
(_, _, h, w) = latents.size()
# resizing
resized_latents = torch.nn.functional.interpolate(
latents,
size=(
h * self.scale,
w * self.scale,
),
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699