Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov 2024-07-23 01:46:29 +03:00
parent 3cb13d6288
commit 4e8dcb7a1a
2 changed files with 13 additions and 11 deletions

View File

@ -495,7 +495,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
resize_mode=control_info.resize_mode,
)
)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
def prep_ip_adapter_image_prompts(
self,

View File

@ -8,7 +8,7 @@ import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
@ -27,8 +27,8 @@ class ControlNetExt(ExtensionBase):
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
control_mode: str,
resize_mode: str,
control_mode: CONTROLNET_MODE_VALUES,
resize_mode: CONTROLNET_RESIZE_VALUES,
):
super().__init__()
self._model = model
@ -43,8 +43,8 @@ class ControlNetExt(ExtensionBase):
@contextmanager
def patch_extension(self, ctx: DenoiseContext):
original_processors = self._model.attn_processors
try:
original_processors = self._model.attn_processors
self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
yield None
@ -62,8 +62,6 @@ class ControlNetExt(ExtensionBase):
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=ctx.latents.device,
dtype=ctx.latents.dtype,
control_mode=self._control_mode,
@ -125,7 +123,7 @@ class ControlNetExt(ExtensionBase):
cn_unet_kwargs = UNetKwargs(
sample=model_input,
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning
encoder_hidden_states=None, # set later by conditioning
cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / total_steps,
),
@ -139,9 +137,14 @@ class ControlNetExt(ExtensionBase):
weight = weight[ctx.step_index]
tmp_kwargs = vars(cn_unet_kwargs)
tmp_kwargs.pop("down_block_additional_residuals", None)
tmp_kwargs.pop("mid_block_additional_residual", None)
tmp_kwargs.pop("down_intrablock_additional_residuals", None)
# Remove kwargs not related to ControlNet unet
# ControlNet guidance fields
del tmp_kwargs["down_block_additional_residuals"]
del tmp_kwargs["mid_block_additional_residual"]
# T2i Adapter guidance fields
del tmp_kwargs["down_intrablock_additional_residuals"]
# controlnet(s) inference
down_samples, mid_sample = self._model(