From 7e943503511c8abf503f424f3f9fedc7eaffd74a Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 12 Jun 2024 11:48:07 -0400 Subject: [PATCH] Tidy DenoiseLatentsInvocation.prep_control_data(...) and fix some type errors. --- invokeai/app/invocations/denoise_latents.py | 41 +++++++++++---------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index a572cd6f05..c275243b96 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -55,6 +55,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ) from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.hotfixes import ControlNetModel from invokeai.backend.util.mask import to_standard_float_mask from invokeai.backend.util.silence_warnings import SilenceWarnings @@ -389,35 +390,35 @@ class DenoiseLatentsInvocation(BaseInvocation): @staticmethod def prep_control_data( context: InvocationContext, - control_input: Optional[Union[ControlField, List[ControlField]]], + control_input: ControlField | list[ControlField] | None, latents_shape: List[int], exit_stack: ExitStack, do_classifier_free_guidance: bool = True, - ) -> Optional[List[ControlNetData]]: - # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. - control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR - control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR - if control_input is None: - control_list = None - elif isinstance(control_input, list) and len(control_input) == 0: - control_list = None - elif isinstance(control_input, ControlField): + ) -> list[ControlNetData] | None: + # Normalize control_input to a list. + control_list: list[ControlField] + if isinstance(control_input, ControlField): control_list = [control_input] - elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField): + elif isinstance(control_input, list): control_list = control_input + elif control_input is None: + control_list = [] else: - control_list = None - if control_list is None: - return None - # After above handling, any control that is not None should now be of type list[ControlField]. + raise ValueError(f"Unexpected control_input type: {type(control_input)}") - # FIXME: add checks to skip entry if model or image is None - # and if weight is None, populate with default 1.0? - controlnet_data = [] + if len(control_list) == 0: + return None + + # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. + _, _, latent_height, latent_width = latents_shape + control_height_resize = latent_height * LATENT_SCALE_FACTOR + control_width_resize = latent_width * LATENT_SCALE_FACTOR + + controlnet_data: list[ControlNetData] = [] for control_info in control_list: control_model = exit_stack.enter_context(context.models.load(control_info.control_model)) + assert isinstance(control_model, ControlNetModel) - # control_models.append(control_model) control_image_field = control_info.image input_image = context.images.get_pil(control_image_field.image_name) # self.image.image_type, self.image.image_name @@ -438,7 +439,7 @@ class DenoiseLatentsInvocation(BaseInvocation): resize_mode=control_info.resize_mode, ) control_item = ControlNetData( - model=control_model, # model object + model=control_model, image_tensor=control_image, weight=control_info.control_weight, begin_step_percent=control_info.begin_step_percent,