Use LATENT_SCALE_FACTOR = 8 constant in CropLatentsInvocation.

This commit is contained in:
Ryan Dick 2023-11-27 11:12:15 -05:00
parent 121b930abf
commit e1c53a2465

View File

@ -79,6 +79,12 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
# factor is hard-coded to a literal '8' rather than using this constant.
# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
LATENT_SCALE_FACTOR = 8
@invocation_output("scheduler_output") @invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput): class SchedulerOutput(BaseInvocationOutput):
@ -390,9 +396,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
exit_stack: ExitStack, exit_stack: ExitStack,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]: ) -> List[ControlNetData]:
# assuming fixed dimensional scaling of 8:1 for image:latents # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * 8 control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * 8 control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
if control_input is None: if control_input is None:
control_list = None control_list = None
elif isinstance(control_input, list) and len(control_input) == 0: elif isinstance(control_input, list) and len(control_input) == 0:
@ -905,12 +911,12 @@ class ResizeLatentsInvocation(BaseInvocation):
) )
width: int = InputField( width: int = InputField(
ge=64, ge=64,
multiple_of=8, multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width, description=FieldDescriptions.width,
) )
height: int = InputField( height: int = InputField(
ge=64, ge=64,
multiple_of=8, multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width, description=FieldDescriptions.width,
) )
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
@ -924,7 +930,7 @@ class ResizeLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents.to(device), latents.to(device),
size=(self.height // 8, self.width // 8), size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
mode=self.mode, mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
) )
@ -1180,32 +1186,32 @@ class CropLatentsInvocation(BaseInvocation):
) )
width: int = InputField( width: int = InputField(
ge=64, ge=64,
multiple_of=_downsampling_factor, multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width, description=FieldDescriptions.width,
) )
height: int = InputField( height: int = InputField(
ge=64, ge=64,
multiple_of=_downsampling_factor, multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width, description=FieldDescriptions.width,
) )
x_offset: int = InputField( x_offset: int = InputField(
ge=0, ge=0,
multiple_of=_downsampling_factor, multiple_of=LATENT_SCALE_FACTOR,
description="x-coordinate", description="x-coordinate",
) )
y_offset: int = InputField( y_offset: int = InputField(
ge=0, ge=0,
multiple_of=_downsampling_factor, multiple_of=LATENT_SCALE_FACTOR,
description="y-coordinate", description="y-coordinate",
) )
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
x1 = self.x_offset // _downsampling_factor x1 = self.x_offset // LATENT_SCALE_FACTOR
y1 = self.y_offset // _downsampling_factor y1 = self.y_offset // LATENT_SCALE_FACTOR
x2 = x1 + (self.width // _downsampling_factor) x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
y2 = y1 + (self.height // _downsampling_factor) y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
cropped_latents = latents[:, :, y1:y2, x1:x2] cropped_latents = latents[:, :, y1:y2, x1:x2]