Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov 2024-07-23 01:46:29 +03:00
parent 3cb13d6288
commit 4e8dcb7a1a
2 changed files with 13 additions and 11 deletions

View File

@ -495,7 +495,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
resize_mode=control_info.resize_mode, resize_mode=control_info.resize_mode,
) )
) )
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
def prep_ip_adapter_image_prompts( def prep_ip_adapter_image_prompts(
self, self,

View File

@ -8,7 +8,7 @@ import torch
from PIL.Image import Image from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
@ -27,8 +27,8 @@ class ControlNetExt(ExtensionBase):
weight: Union[float, List[float]], weight: Union[float, List[float]],
begin_step_percent: float, begin_step_percent: float,
end_step_percent: float, end_step_percent: float,
control_mode: str, control_mode: CONTROLNET_MODE_VALUES,
resize_mode: str, resize_mode: CONTROLNET_RESIZE_VALUES,
): ):
super().__init__() super().__init__()
self._model = model self._model = model
@ -43,8 +43,8 @@ class ControlNetExt(ExtensionBase):
@contextmanager @contextmanager
def patch_extension(self, ctx: DenoiseContext): def patch_extension(self, ctx: DenoiseContext):
try:
original_processors = self._model.attn_processors original_processors = self._model.attn_processors
try:
self._model.set_attn_processor(ctx.inputs.attention_processor_cls()) self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
yield None yield None
@ -62,8 +62,6 @@ class ControlNetExt(ExtensionBase):
do_classifier_free_guidance=False, do_classifier_free_guidance=False,
width=image_width, width=image_width,
height=image_height, height=image_height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=ctx.latents.device, device=ctx.latents.device,
dtype=ctx.latents.dtype, dtype=ctx.latents.dtype,
control_mode=self._control_mode, control_mode=self._control_mode,
@ -125,7 +123,7 @@ class ControlNetExt(ExtensionBase):
cn_unet_kwargs = UNetKwargs( cn_unet_kwargs = UNetKwargs(
sample=model_input, sample=model_input,
timestep=ctx.timestep, timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning encoder_hidden_states=None, # set later by conditioning
cross_attention_kwargs=dict( # noqa: C408 cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / total_steps, percent_through=ctx.step_index / total_steps,
), ),
@ -139,9 +137,14 @@ class ControlNetExt(ExtensionBase):
weight = weight[ctx.step_index] weight = weight[ctx.step_index]
tmp_kwargs = vars(cn_unet_kwargs) tmp_kwargs = vars(cn_unet_kwargs)
tmp_kwargs.pop("down_block_additional_residuals", None)
tmp_kwargs.pop("mid_block_additional_residual", None) # Remove kwargs not related to ControlNet unet
tmp_kwargs.pop("down_intrablock_additional_residuals", None) # ControlNet guidance fields
del tmp_kwargs["down_block_additional_residuals"]
del tmp_kwargs["mid_block_additional_residual"]
# T2i Adapter guidance fields
del tmp_kwargs["down_intrablock_additional_residuals"]
# controlnet(s) inference # controlnet(s) inference
down_samples, mid_sample = self._model( down_samples, mid_sample = self._model(