Mirror of https://github.com/invoke-ai/InvokeAI
fix(nodes): fix slatents and rlatents bugs

commit 0453d60c64
parent c4f4f8b1b8
@@ -379,19 +379,16 @@ class ResizeLatentsInvocation(BaseInvocation):
     type: Literal["lresize"] = "lresize"

     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to resize")
-    width: int = Field(ge=64, multiple_of=8, description="The width to resize to")
-    height: int = Field(ge=64, multiple_of=8, description="The height to resize to")
-    downsample: int = Field(
-        default=8, ge=1, description="The downsampling factor (leave at 8 for SD)"
-    )
-    mode: LATENTS_INTERPOLATION_MODE = Field(
-        default="bilinear", description="The interpolation mode"
-    )
+    latents: Optional[LatentsField] = Field(description="The latents to resize")
+    width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
+    height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
+    downsample: Optional[int] = Field(default=8, ge=1, description="The downsampling factor (leave at 8 for SD)")
+    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
+    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
-        # resizing
+
         resized_latents = torch.nn.functional.interpolate(
             latents,
             size=(
@@ -399,6 +396,7 @@ class ResizeLatentsInvocation(BaseInvocation):
                 self.width // self.downsample,
             ),
             mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
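For reference, the fixed resize path boils down to a plain torch.nn.functional.interpolate call. Below is a minimal standalone sketch of that math; the tensor shape and target sizes are hypothetical, not taken from the commit:

import torch

# Illustrative stand-in for the node's resize math. SD latents live at
# 1/8 pixel resolution, so a 512x768 pixel target with downsample=8
# becomes a 64x96 latent grid.
latents = torch.randn(1, 4, 48, 80)  # hypothetical [batch, channels, h, w]
height, width, downsample = 512, 768, 8
mode, antialias = "bilinear", True

resized = torch.nn.functional.interpolate(
    latents,
    size=(height // downsample, width // downsample),
    mode=mode,
    # antialias is only supported for bilinear/bicubic, hence the guard
    # the commit adds around the new field:
    antialias=antialias if mode in ["bilinear", "bicubic"] else False,
)
print(resized.shape)  # torch.Size([1, 4, 64, 96])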
@@ -415,27 +413,20 @@ class ScaleLatentsInvocation(BaseInvocation):
     type: Literal["lscale"] = "lscale"

     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to resize")
-    scale: int = Field(
-        default=2, ge=1, description="The factor by which to scale the latents"
-    )
-    mode: LATENTS_INTERPOLATION_MODE = Field(
-        default="bilinear", description="The interpolation mode"
-    )
+    latents: Optional[LatentsField] = Field(description="The latents to scale")
+    scale_factor: float = Field(ge=0, description="The factor by which to scale the latents")
+    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
+    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)

-        (_, _, h, w) = latents.size()
-
         # resizing
         resized_latents = torch.nn.functional.interpolate(
             latents,
-            size=(
-                h * self.scale,
-                w * self.scale,
-            ),
+            scale_factor=self.scale_factor,
             mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
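The scale fix swaps the hand-computed size tuple (the removed h * self.scale and w * self.scale lines) for interpolate's own scale_factor argument, which also makes fractional factors legal now that the field is a float. A minimal sketch of the difference, with a hypothetical tensor and factors not taken from the commit:

import torch

latents = torch.randn(1, 4, 64, 64)  # hypothetical [batch, channels, h, w]

# scale_factor lets interpolate derive the output size itself; with a
# float field, 0.5 (downscale) is now as valid as 2.0 (upscale).
halved = torch.nn.functional.interpolate(
    latents, scale_factor=0.5, mode="bilinear", antialias=True
)
doubled = torch.nn.functional.interpolate(
    latents, scale_factor=2.0, mode="bilinear", antialias=True
)
print(halved.shape)   # torch.Size([1, 4, 32, 32])
print(doubled.shape)  # torch.Size([1, 4, 128, 128])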