mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
3cb13d6288
commit
4e8dcb7a1a
@ -495,7 +495,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
resize_mode=control_info.resize_mode,
|
resize_mode=control_info.resize_mode,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
|
||||||
|
|
||||||
def prep_ip_adapter_image_prompts(
|
def prep_ip_adapter_image_prompts(
|
||||||
self,
|
self,
|
||||||
|
@ -8,7 +8,7 @@ import torch
|
|||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
|
|
||||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
|
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
@ -27,8 +27,8 @@ class ControlNetExt(ExtensionBase):
|
|||||||
weight: Union[float, List[float]],
|
weight: Union[float, List[float]],
|
||||||
begin_step_percent: float,
|
begin_step_percent: float,
|
||||||
end_step_percent: float,
|
end_step_percent: float,
|
||||||
control_mode: str,
|
control_mode: CONTROLNET_MODE_VALUES,
|
||||||
resize_mode: str,
|
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._model = model
|
self._model = model
|
||||||
@ -43,8 +43,8 @@ class ControlNetExt(ExtensionBase):
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_extension(self, ctx: DenoiseContext):
|
def patch_extension(self, ctx: DenoiseContext):
|
||||||
try:
|
|
||||||
original_processors = self._model.attn_processors
|
original_processors = self._model.attn_processors
|
||||||
|
try:
|
||||||
self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
|
self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
|
||||||
|
|
||||||
yield None
|
yield None
|
||||||
@ -62,8 +62,6 @@ class ControlNetExt(ExtensionBase):
|
|||||||
do_classifier_free_guidance=False,
|
do_classifier_free_guidance=False,
|
||||||
width=image_width,
|
width=image_width,
|
||||||
height=image_height,
|
height=image_height,
|
||||||
# batch_size=batch_size * num_images_per_prompt,
|
|
||||||
# num_images_per_prompt=num_images_per_prompt,
|
|
||||||
device=ctx.latents.device,
|
device=ctx.latents.device,
|
||||||
dtype=ctx.latents.dtype,
|
dtype=ctx.latents.dtype,
|
||||||
control_mode=self._control_mode,
|
control_mode=self._control_mode,
|
||||||
@ -125,7 +123,7 @@ class ControlNetExt(ExtensionBase):
|
|||||||
cn_unet_kwargs = UNetKwargs(
|
cn_unet_kwargs = UNetKwargs(
|
||||||
sample=model_input,
|
sample=model_input,
|
||||||
timestep=ctx.timestep,
|
timestep=ctx.timestep,
|
||||||
encoder_hidden_states=None, # set later by conditoning
|
encoder_hidden_states=None, # set later by conditioning
|
||||||
cross_attention_kwargs=dict( # noqa: C408
|
cross_attention_kwargs=dict( # noqa: C408
|
||||||
percent_through=ctx.step_index / total_steps,
|
percent_through=ctx.step_index / total_steps,
|
||||||
),
|
),
|
||||||
@ -139,9 +137,14 @@ class ControlNetExt(ExtensionBase):
|
|||||||
weight = weight[ctx.step_index]
|
weight = weight[ctx.step_index]
|
||||||
|
|
||||||
tmp_kwargs = vars(cn_unet_kwargs)
|
tmp_kwargs = vars(cn_unet_kwargs)
|
||||||
tmp_kwargs.pop("down_block_additional_residuals", None)
|
|
||||||
tmp_kwargs.pop("mid_block_additional_residual", None)
|
# Remove kwargs not related to ControlNet unet
|
||||||
tmp_kwargs.pop("down_intrablock_additional_residuals", None)
|
# ControlNet guidance fields
|
||||||
|
del tmp_kwargs["down_block_additional_residuals"]
|
||||||
|
del tmp_kwargs["mid_block_additional_residual"]
|
||||||
|
|
||||||
|
# T2i Adapter guidance fields
|
||||||
|
del tmp_kwargs["down_intrablock_additional_residuals"]
|
||||||
|
|
||||||
# controlnet(s) inference
|
# controlnet(s) inference
|
||||||
down_samples, mid_sample = self._model(
|
down_samples, mid_sample = self._model(
|
||||||
|
Loading…
Reference in New Issue
Block a user