Tidy DenoiseLatentsInvocation.prep_control_data(...) and fix some type errors.

This commit is contained in:
Ryan Dick 2024-06-12 11:48:07 -04:00
parent 66cf2c59bd
commit 7ee5db87ad

View File

@ -55,6 +55,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
) )
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
@ -389,35 +390,35 @@ class DenoiseLatentsInvocation(BaseInvocation):
@staticmethod @staticmethod
def prep_control_data( def prep_control_data(
context: InvocationContext, context: InvocationContext,
control_input: Optional[Union[ControlField, List[ControlField]]], control_input: ControlField | list[ControlField] | None,
latents_shape: List[int], latents_shape: List[int],
exit_stack: ExitStack, exit_stack: ExitStack,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
) -> Optional[List[ControlNetData]]: ) -> list[ControlNetData] | None:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. # Normalize control_input to a list.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR control_list: list[ControlField]
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR if isinstance(control_input, ControlField):
if control_input is None:
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
control_list = None
elif isinstance(control_input, ControlField):
control_list = [control_input] control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField): elif isinstance(control_input, list):
control_list = control_input control_list = control_input
elif control_input is None:
control_list = []
else: else:
control_list = None raise ValueError(f"Unexpected control_input type: {type(control_input)}")
if control_list is None:
return None
# After above handling, any control that is not None should now be of type list[ControlField].
# FIXME: add checks to skip entry if model or image is None if len(control_list) == 0:
# and if weight is None, populate with default 1.0? return None
controlnet_data = []
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
_, _, latent_height, latent_width = latents_shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR
control_width_resize = latent_width * LATENT_SCALE_FACTOR
controlnet_data: list[ControlNetData] = []
for control_info in control_list: for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model)) control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
assert isinstance(control_model, ControlNetModel)
# control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.images.get_pil(control_image_field.image_name) input_image = context.images.get_pil(control_image_field.image_name)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
@ -438,7 +439,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
resize_mode=control_info.resize_mode, resize_mode=control_info.resize_mode,
) )
control_item = ControlNetData( control_item = ControlNetData(
model=control_model, # model object model=control_model,
image_tensor=control_image, image_tensor=control_image,
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,