From 0453d60c645de3415612281dc34f15133c5ee06e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 26 Apr 2023 21:45:05 +1000 Subject: [PATCH] fix(nodes): fix slatents and rlatents bugs --- invokeai/app/invocations/latent.py | 37 +++++++++++------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index a7ed313d17..ca65b4d9ed 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -379,19 +379,16 @@ class ResizeLatentsInvocation(BaseInvocation): type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - width: int = Field(ge=64, multiple_of=8, description="The width to resize to") - height: int = Field(ge=64, multiple_of=8, description="The height to resize to") - downsample: int = Field( - default=8, ge=1, description="The downsampling factor (leave at 8 for SD)" - ) - mode: LATENTS_INTERPOLATION_MODE = Field( - default="bilinear", description="The interpolation mode" - ) + latents: Optional[LatentsField] = Field(description="The latents to resize") + width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") + downsample: Optional[int] = Field(default=8, ge=1, description="The downsampling factor (leave at 8 for SD)") + mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode") + antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) - # resizing + resized_latents = torch.nn.functional.interpolate( latents, size=( @@ -399,6 +396,7 @@ class ResizeLatentsInvocation(BaseInvocation): self.width // self.downsample, ), mode=self.mode, + antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 @@ -415,27 +413,20 @@ class ScaleLatentsInvocation(BaseInvocation): type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - scale: int = Field( - default=2, ge=1, description="The factor by which to scale the latents" - ) - mode: LATENTS_INTERPOLATION_MODE = Field( - default="bilinear", description="The interpolation mode" - ) + latents: Optional[LatentsField] = Field(description="The latents to scale") + scale_factor: float = Field(ge=0, description="The factor by which to scale the latents") + mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode") + antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) - (_, _, h, w) = latents.size() - # resizing resized_latents = torch.nn.functional.interpolate( latents, - size=( - h * self.scale, - w * self.scale, - ), + scale_factor=self.scale_factor, mode=self.mode, + antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699