mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Tidy DenoiseLatentsInvocation.prep_control_data(...) and fix some type errors.
This commit is contained in:
parent
66cf2c59bd
commit
7ee5db87ad
@ -55,6 +55,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||||
from invokeai.backend.util.mask import to_standard_float_mask
|
from invokeai.backend.util.mask import to_standard_float_mask
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
|
||||||
@ -389,35 +390,35 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def prep_control_data(
|
def prep_control_data(
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
control_input: Optional[Union[ControlField, List[ControlField]]],
|
control_input: ControlField | list[ControlField] | None,
|
||||||
latents_shape: List[int],
|
latents_shape: List[int],
|
||||||
exit_stack: ExitStack,
|
exit_stack: ExitStack,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
) -> Optional[List[ControlNetData]]:
|
) -> list[ControlNetData] | None:
|
||||||
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
# Normalize control_input to a list.
|
||||||
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
|
control_list: list[ControlField]
|
||||||
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
|
if isinstance(control_input, ControlField):
|
||||||
if control_input is None:
|
|
||||||
control_list = None
|
|
||||||
elif isinstance(control_input, list) and len(control_input) == 0:
|
|
||||||
control_list = None
|
|
||||||
elif isinstance(control_input, ControlField):
|
|
||||||
control_list = [control_input]
|
control_list = [control_input]
|
||||||
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
|
elif isinstance(control_input, list):
|
||||||
control_list = control_input
|
control_list = control_input
|
||||||
|
elif control_input is None:
|
||||||
|
control_list = []
|
||||||
else:
|
else:
|
||||||
control_list = None
|
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
|
||||||
if control_list is None:
|
|
||||||
return None
|
|
||||||
# After above handling, any control that is not None should now be of type list[ControlField].
|
|
||||||
|
|
||||||
# FIXME: add checks to skip entry if model or image is None
|
if len(control_list) == 0:
|
||||||
# and if weight is None, populate with default 1.0?
|
return None
|
||||||
controlnet_data = []
|
|
||||||
|
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||||
|
_, _, latent_height, latent_width = latents_shape
|
||||||
|
control_height_resize = latent_height * LATENT_SCALE_FACTOR
|
||||||
|
control_width_resize = latent_width * LATENT_SCALE_FACTOR
|
||||||
|
|
||||||
|
controlnet_data: list[ControlNetData] = []
|
||||||
for control_info in control_list:
|
for control_info in control_list:
|
||||||
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||||
|
assert isinstance(control_model, ControlNetModel)
|
||||||
|
|
||||||
# control_models.append(control_model)
|
|
||||||
control_image_field = control_info.image
|
control_image_field = control_info.image
|
||||||
input_image = context.images.get_pil(control_image_field.image_name)
|
input_image = context.images.get_pil(control_image_field.image_name)
|
||||||
# self.image.image_type, self.image.image_name
|
# self.image.image_type, self.image.image_name
|
||||||
@ -438,7 +439,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
resize_mode=control_info.resize_mode,
|
resize_mode=control_info.resize_mode,
|
||||||
)
|
)
|
||||||
control_item = ControlNetData(
|
control_item = ControlNetData(
|
||||||
model=control_model, # model object
|
model=control_model,
|
||||||
image_tensor=control_image,
|
image_tensor=control_image,
|
||||||
weight=control_info.control_weight,
|
weight=control_info.control_weight,
|
||||||
begin_step_percent=control_info.begin_step_percent,
|
begin_step_percent=control_info.begin_step_percent,
|
||||||
|
Loading…
Reference in New Issue
Block a user