mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add ControlNet support to denoise
This commit is contained in:
parent
f9c61f1b6c
commit
42356ec866
@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
@ -463,6 +464,39 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
return controlnet_data
|
return controlnet_data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_controlnet_field(
|
||||||
|
exit_stack: ExitStack,
|
||||||
|
context: InvocationContext,
|
||||||
|
control_input: ControlField | list[ControlField] | None,
|
||||||
|
ext_manager: ExtensionsManager,
|
||||||
|
) -> None:
|
||||||
|
# Normalize control_input to a list.
|
||||||
|
control_list: list[ControlField]
|
||||||
|
if isinstance(control_input, ControlField):
|
||||||
|
control_list = [control_input]
|
||||||
|
elif isinstance(control_input, list):
|
||||||
|
control_list = control_input
|
||||||
|
elif control_input is None:
|
||||||
|
control_list = []
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
|
||||||
|
|
||||||
|
for control_info in control_list:
|
||||||
|
model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||||
|
ext_manager.add_extension(
|
||||||
|
ControlNetExt(
|
||||||
|
model=model,
|
||||||
|
image=context.images.get_pil(control_info.image.image_name),
|
||||||
|
weight=control_info.control_weight,
|
||||||
|
begin_step_percent=control_info.begin_step_percent,
|
||||||
|
end_step_percent=control_info.end_step_percent,
|
||||||
|
control_mode=control_info.control_mode,
|
||||||
|
resize_mode=control_info.resize_mode,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||||
|
|
||||||
def prep_ip_adapter_image_prompts(
|
def prep_ip_adapter_image_prompts(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
@ -790,22 +824,30 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
ext_manager.add_extension(PreviewExt(step_callback))
|
ext_manager.add_extension(PreviewExt(step_callback))
|
||||||
|
|
||||||
# ext: t2i/ip adapter
|
# context for loading additional models
|
||||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
with ExitStack() as exit_stack:
|
||||||
|
# later should be smth like:
|
||||||
|
# for extension_field in self.extensions:
|
||||||
|
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
|
||||||
|
# ext_manager.add_extension(ext)
|
||||||
|
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
|
||||||
|
|
||||||
unet_info = context.models.load(self.unet.unet)
|
# ext: t2i/ip adapter
|
||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||||
with (
|
|
||||||
unet_info.model_on_device() as (model_state_dict, unet),
|
unet_info = context.models.load(self.unet.unet)
|
||||||
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
# ext: controlnet
|
with (
|
||||||
ext_manager.patch_extensions(unet),
|
unet_info.model_on_device() as (model_state_dict, unet),
|
||||||
# ext: freeu, seamless, ip adapter, lora
|
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
||||||
ext_manager.patch_unet(model_state_dict, unet),
|
# ext: controlnet
|
||||||
):
|
ext_manager.patch_extensions(denoise_ctx),
|
||||||
sd_backend = StableDiffusionBackend(unet, scheduler)
|
# ext: freeu, seamless, ip adapter, lora
|
||||||
denoise_ctx.unet = unet
|
ext_manager.patch_unet(model_state_dict, unet),
|
||||||
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
|
):
|
||||||
|
sd_backend = StableDiffusionBackend(unet, scheduler)
|
||||||
|
denoise_ctx.unet = unet
|
||||||
|
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
result_latents = result_latents.detach().to("cpu")
|
result_latents = result_latents.detach().to("cpu")
|
||||||
|
155
invokeai/backend/stable_diffusion/extensions/controlnet.py
Normal file
155
invokeai/backend/stable_diffusion/extensions/controlnet.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL.Image import Image
|
||||||
|
|
||||||
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||||
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetExt(ExtensionBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: ControlNetModel,
|
||||||
|
image: Image,
|
||||||
|
weight: Union[float, List[float]],
|
||||||
|
begin_step_percent: float,
|
||||||
|
end_step_percent: float,
|
||||||
|
control_mode: str,
|
||||||
|
resize_mode: str,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.image = image
|
||||||
|
self.weight = weight
|
||||||
|
self.begin_step_percent = begin_step_percent
|
||||||
|
self.end_step_percent = end_step_percent
|
||||||
|
self.control_mode = control_mode
|
||||||
|
self.resize_mode = resize_mode
|
||||||
|
|
||||||
|
self.image_tensor: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch_extension(self, ctx: DenoiseContext):
|
||||||
|
try:
|
||||||
|
original_processors = self.model.attn_processors
|
||||||
|
self.model.set_attn_processor(ctx.inputs.attention_processor_cls())
|
||||||
|
|
||||||
|
yield None
|
||||||
|
finally:
|
||||||
|
self.model.set_attn_processor(original_processors)
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
|
def resize_image(self, ctx: DenoiseContext):
|
||||||
|
_, _, latent_height, latent_width = ctx.latents.shape
|
||||||
|
image_height = latent_height * LATENT_SCALE_FACTOR
|
||||||
|
image_width = latent_width * LATENT_SCALE_FACTOR
|
||||||
|
|
||||||
|
self.image_tensor = prepare_control_image(
|
||||||
|
image=self.image,
|
||||||
|
do_classifier_free_guidance=False,
|
||||||
|
width=image_width,
|
||||||
|
height=image_height,
|
||||||
|
# batch_size=batch_size * num_images_per_prompt,
|
||||||
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=ctx.latents.device,
|
||||||
|
dtype=ctx.latents.dtype,
|
||||||
|
control_mode=self.control_mode,
|
||||||
|
resize_mode=self.resize_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_UNET)
|
||||||
|
def pre_unet_step(self, ctx: DenoiseContext):
|
||||||
|
# skip if model not active in current step
|
||||||
|
total_steps = len(ctx.inputs.timesteps)
|
||||||
|
first_step = math.floor(self.begin_step_percent * total_steps)
|
||||||
|
last_step = math.ceil(self.end_step_percent * total_steps)
|
||||||
|
if ctx.step_index < first_step or ctx.step_index > last_step:
|
||||||
|
return
|
||||||
|
|
||||||
|
# convert mode to internal flags
|
||||||
|
soft_injection = self.control_mode in ["more_prompt", "more_control"]
|
||||||
|
cfg_injection = self.control_mode in ["more_control", "unbalanced"]
|
||||||
|
|
||||||
|
# no negative conditioning in cfg_injection mode
|
||||||
|
if cfg_injection:
|
||||||
|
if ctx.conditioning_mode == ConditioningMode.Negative:
|
||||||
|
return
|
||||||
|
down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)
|
||||||
|
|
||||||
|
if ctx.conditioning_mode == ConditioningMode.Both:
|
||||||
|
# add zeros as samples for negative conditioning
|
||||||
|
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
|
||||||
|
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
|
||||||
|
|
||||||
|
else:
|
||||||
|
down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)
|
||||||
|
|
||||||
|
if (
|
||||||
|
ctx.unet_kwargs.down_block_additional_residuals is None
|
||||||
|
and ctx.unet_kwargs.mid_block_additional_residual is None
|
||||||
|
):
|
||||||
|
ctx.unet_kwargs.down_block_additional_residuals = down_samples
|
||||||
|
ctx.unet_kwargs.mid_block_additional_residual = mid_sample
|
||||||
|
else:
|
||||||
|
# add controlnet outputs together if have multiple controlnets
|
||||||
|
ctx.unet_kwargs.down_block_additional_residuals = [
|
||||||
|
samples_prev + samples_curr
|
||||||
|
for samples_prev, samples_curr in zip(
|
||||||
|
ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
ctx.unet_kwargs.mid_block_additional_residual += mid_sample
|
||||||
|
|
||||||
|
def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
|
||||||
|
total_steps = len(ctx.inputs.timesteps)
|
||||||
|
|
||||||
|
model_input = ctx.latent_model_input
|
||||||
|
image_tensor = self.image_tensor
|
||||||
|
if conditioning_mode == ConditioningMode.Both:
|
||||||
|
model_input = torch.cat([model_input] * 2)
|
||||||
|
image_tensor = torch.cat([image_tensor] * 2)
|
||||||
|
|
||||||
|
cn_unet_kwargs = UNetKwargs(
|
||||||
|
sample=model_input,
|
||||||
|
timestep=ctx.timestep,
|
||||||
|
encoder_hidden_states=None, # set later by conditoning
|
||||||
|
cross_attention_kwargs=dict( # noqa: C408
|
||||||
|
percent_through=ctx.step_index / total_steps,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
|
||||||
|
|
||||||
|
# get static weight, or weight corresponding to current step
|
||||||
|
weight = self.weight
|
||||||
|
if isinstance(weight, list):
|
||||||
|
weight = weight[ctx.step_index]
|
||||||
|
|
||||||
|
tmp_kwargs = vars(cn_unet_kwargs)
|
||||||
|
tmp_kwargs.pop("down_block_additional_residuals", None)
|
||||||
|
tmp_kwargs.pop("mid_block_additional_residual", None)
|
||||||
|
tmp_kwargs.pop("down_intrablock_additional_residuals", None)
|
||||||
|
|
||||||
|
# controlnet(s) inference
|
||||||
|
down_samples, mid_sample = self.model(
|
||||||
|
controlnet_cond=image_tensor,
|
||||||
|
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
|
||||||
|
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||||
|
return_dict=False,
|
||||||
|
**vars(cn_unet_kwargs),
|
||||||
|
)
|
||||||
|
|
||||||
|
return down_samples, mid_sample
|
Loading…
Reference in New Issue
Block a user