mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use LATENT_SCALE_FACTOR = 8 constant in CropLatentsInvocation.
This commit is contained in:
parent
121b930abf
commit
e1c53a2465
@ -79,6 +79,12 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
|||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
||||||
|
|
||||||
|
# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
|
||||||
|
# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
|
||||||
|
# factor is hard-coded to a literal '8' rather than using this constant.
|
||||||
|
# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
|
||||||
|
LATENT_SCALE_FACTOR = 8
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("scheduler_output")
|
@invocation_output("scheduler_output")
|
||||||
class SchedulerOutput(BaseInvocationOutput):
|
class SchedulerOutput(BaseInvocationOutput):
|
||||||
@ -390,9 +396,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
exit_stack: ExitStack,
|
exit_stack: ExitStack,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
) -> List[ControlNetData]:
|
) -> List[ControlNetData]:
|
||||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||||
control_height_resize = latents_shape[2] * 8
|
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
|
||||||
control_width_resize = latents_shape[3] * 8
|
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
|
||||||
if control_input is None:
|
if control_input is None:
|
||||||
control_list = None
|
control_list = None
|
||||||
elif isinstance(control_input, list) and len(control_input) == 0:
|
elif isinstance(control_input, list) and len(control_input) == 0:
|
||||||
@ -905,12 +911,12 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
width: int = InputField(
|
width: int = InputField(
|
||||||
ge=64,
|
ge=64,
|
||||||
multiple_of=8,
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
description=FieldDescriptions.width,
|
description=FieldDescriptions.width,
|
||||||
)
|
)
|
||||||
height: int = InputField(
|
height: int = InputField(
|
||||||
ge=64,
|
ge=64,
|
||||||
multiple_of=8,
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
description=FieldDescriptions.width,
|
description=FieldDescriptions.width,
|
||||||
)
|
)
|
||||||
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||||
@ -924,7 +930,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents.to(device),
|
latents.to(device),
|
||||||
size=(self.height // 8, self.width // 8),
|
size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
|
||||||
mode=self.mode,
|
mode=self.mode,
|
||||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||||
)
|
)
|
||||||
@ -1180,32 +1186,32 @@ class CropLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
width: int = InputField(
|
width: int = InputField(
|
||||||
ge=64,
|
ge=64,
|
||||||
multiple_of=_downsampling_factor,
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
description=FieldDescriptions.width,
|
description=FieldDescriptions.width,
|
||||||
)
|
)
|
||||||
height: int = InputField(
|
height: int = InputField(
|
||||||
ge=64,
|
ge=64,
|
||||||
multiple_of=_downsampling_factor,
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
description=FieldDescriptions.width,
|
description=FieldDescriptions.width,
|
||||||
)
|
)
|
||||||
x_offset: int = InputField(
|
x_offset: int = InputField(
|
||||||
ge=0,
|
ge=0,
|
||||||
multiple_of=_downsampling_factor,
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
description="x-coordinate",
|
description="x-coordinate",
|
||||||
)
|
)
|
||||||
y_offset: int = InputField(
|
y_offset: int = InputField(
|
||||||
ge=0,
|
ge=0,
|
||||||
multiple_of=_downsampling_factor,
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
description="y-coordinate",
|
description="y-coordinate",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
x1 = self.x_offset // _downsampling_factor
|
x1 = self.x_offset // LATENT_SCALE_FACTOR
|
||||||
y1 = self.y_offset // _downsampling_factor
|
y1 = self.y_offset // LATENT_SCALE_FACTOR
|
||||||
x2 = x1 + (self.width // _downsampling_factor)
|
x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
|
||||||
y2 = y1 + (self.height // _downsampling_factor)
|
y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
|
||||||
|
|
||||||
cropped_latents = latents[:, :, y1:y2, x1:x2]
|
cropped_latents = latents[:, :, y1:y2, x1:x2]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user