From e1c53a2465583869126833704b5b5a0b1d42f062 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 27 Nov 2023 11:12:15 -0500 Subject: [PATCH] Use LATENT_SCALE_FACTOR = 8 constant in CropLatentsInvocation. --- invokeai/app/invocations/latent.py | 34 ++++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c143eb891c..b5c9c876c8 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -79,6 +79,12 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device()) SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] +# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to +# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale +# factor is hard-coded to a literal '8' rather than using this constant. +# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. +LATENT_SCALE_FACTOR = 8 + @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): @@ -390,9 +396,9 @@ class DenoiseLatentsInvocation(BaseInvocation): exit_stack: ExitStack, do_classifier_free_guidance: bool = True, ) -> List[ControlNetData]: - # assuming fixed dimensional scaling of 8:1 for image:latents - control_height_resize = latents_shape[2] * 8 - control_width_resize = latents_shape[3] * 8 + # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. + control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR + control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR if control_input is None: control_list = None elif isinstance(control_input, list) and len(control_input) == 0: @@ -905,12 +911,12 @@ class ResizeLatentsInvocation(BaseInvocation): ) width: int = InputField( ge=64, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, description=FieldDescriptions.width, ) height: int = InputField( ge=64, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, description=FieldDescriptions.width, ) mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) @@ -924,7 +930,7 @@ class ResizeLatentsInvocation(BaseInvocation): resized_latents = torch.nn.functional.interpolate( latents.to(device), - size=(self.height // 8, self.width // 8), + size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR), mode=self.mode, antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) @@ -1180,32 +1186,32 @@ class CropLatentsInvocation(BaseInvocation): ) width: int = InputField( ge=64, - multiple_of=_downsampling_factor, + multiple_of=LATENT_SCALE_FACTOR, description=FieldDescriptions.width, ) height: int = InputField( ge=64, - multiple_of=_downsampling_factor, + multiple_of=LATENT_SCALE_FACTOR, description=FieldDescriptions.width, ) x_offset: int = InputField( ge=0, - multiple_of=_downsampling_factor, + multiple_of=LATENT_SCALE_FACTOR, description="x-coordinate", ) y_offset: int = InputField( ge=0, - multiple_of=_downsampling_factor, + multiple_of=LATENT_SCALE_FACTOR, description="y-coordinate", ) def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) - x1 = self.x_offset // _downsampling_factor - y1 = self.y_offset // _downsampling_factor - x2 = x1 + (self.width // _downsampling_factor) - y2 = y1 + (self.height // _downsampling_factor) + x1 = self.x_offset // LATENT_SCALE_FACTOR + y1 = self.y_offset // LATENT_SCALE_FACTOR + x2 = x1 + (self.width // LATENT_SCALE_FACTOR) + y2 = y1 + (self.height // LATENT_SCALE_FACTOR) cropped_latents = latents[:, :, y1:y2, x1:x2]