fix(nodes): fix slatents and rlatents bugs

psychedelicious 2023-04-26 21:45:05 +10:00
parent c4f4f8b1b8
commit 0453d60c64


@@ -380,18 +380,15 @@ class ResizeLatentsInvocation(BaseInvocation):
     # Inputs
     latents: Optional[LatentsField] = Field(description="The latents to resize")
-    width: int = Field(ge=64, multiple_of=8, description="The width to resize to")
-    height: int = Field(ge=64, multiple_of=8, description="The height to resize to")
-    downsample: int = Field(
-        default=8, ge=1, description="The downsampling factor (leave at 8 for SD)"
-    )
-    mode: LATENTS_INTERPOLATION_MODE = Field(
-        default="bilinear", description="The interpolation mode"
-    )
+    width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
+    height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
+    downsample: Optional[int] = Field(default=8, ge=1, description="The downsampling factor (leave at 8 for SD)")
+    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
+    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
-        # resizing
         resized_latents = torch.nn.functional.interpolate(
             latents,
             size=(
@@ -399,6 +396,7 @@ class ResizeLatentsInvocation(BaseInvocation):
                 self.width // self.downsample,
             ),
             mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
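
As context for the new antialias guard: PyTorch's torch.nn.functional.interpolate only accepts antialias=True for the "bilinear" and "bicubic" modes and raises an error for the others, which is why the node forces the flag back to False elsewhere. A minimal standalone sketch of the same pattern (tensor shapes and target size are illustrative, not taken from the commit):

import torch
import torch.nn.functional as F

# Illustrative SD-style latents tensor: (batch, channels, height // 8, width // 8)
latents = torch.randn(1, 4, 64, 64)

mode = "bilinear"
antialias = True

# antialias is only valid for bilinear/bicubic, so guard it exactly as the node does
resized = F.interpolate(
    latents,
    size=(96, 96),  # hypothetical target: (height // downsample, width // downsample)
    mode=mode,
    antialias=antialias if mode in ["bilinear", "bicubic"] else False,
)
print(resized.shape)  # torch.Size([1, 4, 96, 96])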
@@ -415,27 +413,20 @@ class ScaleLatentsInvocation(BaseInvocation):
     type: Literal["lscale"] = "lscale"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to resize")
-    scale: int = Field(
-        default=2, ge=1, description="The factor by which to scale the latents"
-    )
-    mode: LATENTS_INTERPOLATION_MODE = Field(
-        default="bilinear", description="The interpolation mode"
-    )
+    latents: Optional[LatentsField] = Field(description="The latents to scale")
+    scale_factor: float = Field(ge=0, description="The factor by which to scale the latents")
+    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
+    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
-        (_, _, h, w) = latents.size()
-
         # resizing
         resized_latents = torch.nn.functional.interpolate(
             latents,
-            size=(
-                h * self.scale,
-                w * self.scale,
-            ),
+            scale_factor=self.scale_factor,
             mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
         )
 
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
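
For the lscale change: interpolate accepts either an explicit size or a scale_factor, so the manual (_, _, h, w) arithmetic is no longer needed, and switching the field from scale: int (ge=1) to scale_factor: float (ge=0) lets the node shrink latents as well as enlarge them. A hedged sketch of the before/after behaviour (names and shapes illustrative):

import torch
import torch.nn.functional as F

latents = torch.randn(1, 4, 64, 64)  # illustrative latents

# Old behaviour: integer upscale only, via an explicitly computed target size
(_, _, h, w) = latents.size()
up = F.interpolate(latents, size=(h * 2, w * 2), mode="bilinear")

# New behaviour: float scale_factor, so fractional scales (downscaling) work too
down = F.interpolate(latents, scale_factor=0.5, mode="bilinear", antialias=True)

print(up.shape)    # torch.Size([1, 4, 128, 128])
print(down.shape)  # torch.Size([1, 4, 32, 32])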