Tidy DenoiseLatentsInvocation.prep_control_data(...) and fix some type errors.

This commit is contained in:
Ryan Dick 2024-06-12 11:48:07 -04:00
parent 66cf2c59bd
commit 7ee5db87ad

View File

@ -55,6 +55,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings
@ -389,35 +390,35 @@ class DenoiseLatentsInvocation(BaseInvocation):
@staticmethod
def prep_control_data(
context: InvocationContext,
control_input: Optional[Union[ControlField, List[ControlField]]],
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> Optional[List[ControlNetData]]:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
if control_input is None:
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
control_list = None
elif isinstance(control_input, ControlField):
) -> list[ControlNetData] | None:
# Normalize control_input to a list.
control_list: list[ControlField]
if isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
elif isinstance(control_input, list):
control_list = control_input
elif control_input is None:
control_list = []
else:
control_list = None
if control_list is None:
return None
# After above handling, any control that is not None should now be of type list[ControlField].
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
if len(control_list) == 0:
return None
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
_, _, latent_height, latent_width = latents_shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR
control_width_resize = latent_width * LATENT_SCALE_FACTOR
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
assert isinstance(control_model, ControlNetModel)
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.images.get_pil(control_image_field.image_name)
# self.image.image_type, self.image.image_name
@ -438,7 +439,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,