Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Modular backend - add ControlNet (#6642)
## Summary

ControlNet code from #6577.

## Related Issues / Discussions

#6606
https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without the `USE_MODULAR_DENOISE` environment variable set.

## Merge Plan

Merge #6641 first, to be able to see the output difference properly. If you think there should be some kind of tests, feel free to add them.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
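For QA, a minimal sketch of running the app with the flag from Python; the variable name comes from this PR, but the launch command is an assumption — adjust to however you start InvokeAI:

```python
# Hedged QA sketch: run once with and once without USE_MODULAR_DENOISE.
import os
import subprocess

env = dict(os.environ, USE_MODULAR_DENOISE="1")  # drop the key to test the legacy path
subprocess.run(["invokeai-web"], env=env)  # assumed entry point
```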
This commit is contained in: commit `7c975f0d00`
```diff
@@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
```
```diff
@@ -465,6 +466,38 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return controlnet_data
 
+    @staticmethod
+    def parse_controlnet_field(
+        exit_stack: ExitStack,
+        context: InvocationContext,
+        control_input: ControlField | list[ControlField] | None,
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        # Normalize control_input to a list.
+        control_list: list[ControlField]
+        if isinstance(control_input, ControlField):
+            control_list = [control_input]
+        elif isinstance(control_input, list):
+            control_list = control_input
+        elif control_input is None:
+            control_list = []
+        else:
+            raise ValueError(f"Unexpected control_input type: {type(control_input)}")
+
+        for control_info in control_list:
+            model = exit_stack.enter_context(context.models.load(control_info.control_model))
+            ext_manager.add_extension(
+                ControlNetExt(
+                    model=model,
+                    image=context.images.get_pil(control_info.image.image_name),
+                    weight=control_info.control_weight,
+                    begin_step_percent=control_info.begin_step_percent,
+                    end_step_percent=control_info.end_step_percent,
+                    control_mode=control_info.control_mode,
+                    resize_mode=control_info.resize_mode,
+                )
+            )
+
     def prep_ip_adapter_image_prompts(
         self,
         context: InvocationContext,
```
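As an aside, the single-item/list/None normalization above is a small reusable pattern. A self-contained sketch, with `Control` as a hypothetical stand-in (the real `ControlField` carries model, image, weight, and scheduling fields):

```python
# Sketch of the normalization used by parse_controlnet_field.
from dataclasses import dataclass


@dataclass
class Control:  # hypothetical stand-in for ControlField
    name: str


def normalize(control_input: Control | list[Control] | None) -> list[Control]:
    if isinstance(control_input, Control):
        return [control_input]
    elif isinstance(control_input, list):
        return control_input
    elif control_input is None:
        return []
    else:
        raise ValueError(f"Unexpected control_input type: {type(control_input)}")


assert normalize(None) == []
assert normalize(Control("canny")) == [Control("canny")]
assert normalize([Control("a"), Control("b")]) == [Control("a"), Control("b")]
```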
```diff
@@ -800,22 +833,30 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
 
-        # ext: t2i/ip adapter
-        ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
-
-        unet_info = context.models.load(self.unet.unet)
-        assert isinstance(unet_info.model, UNet2DConditionModel)
-        with (
-            unet_info.model_on_device() as (cached_weights, unet),
-            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
-            # ext: controlnet
-            ext_manager.patch_extensions(unet),
-            # ext: freeu, seamless, ip adapter, lora
-            ext_manager.patch_unet(unet, cached_weights),
-        ):
-            sd_backend = StableDiffusionBackend(unet, scheduler)
-            denoise_ctx.unet = unet
-            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+        # context for loading additional models
+        with ExitStack() as exit_stack:
+            # later should be smth like:
+            # for extension_field in self.extensions:
+            #     ext = extension_field.to_extension(exit_stack, context, ext_manager)
+            #     ext_manager.add_extension(ext)
+            self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+
+            # ext: t2i/ip adapter
+            ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
+
+            unet_info = context.models.load(self.unet.unet)
+            assert isinstance(unet_info.model, UNet2DConditionModel)
+            with (
+                unet_info.model_on_device() as (cached_weights, unet),
+                ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+                # ext: controlnet
+                ext_manager.patch_extensions(denoise_ctx),
+                # ext: freeu, seamless, ip adapter, lora
+                ext_manager.patch_unet(unet, cached_weights),
+            ):
+                sd_backend = StableDiffusionBackend(unet, scheduler)
+                denoise_ctx.unet = unet
+                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.detach().to("cpu")
```
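For readers unfamiliar with the scaffolding: the outer `ExitStack` keeps extension models loaded for the whole denoise, while the parenthesized `with` applies patches only around the UNet run. A runnable toy version of that shape — all names here are illustrative stand-ins, not InvokeAI APIs:

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def load_model(name: str):
    print(f"load {name}")
    try:
        yield name
    finally:
        print(f"unload {name}")


@contextmanager
def patch(name: str):
    print(f"patch {name}")
    try:
        yield
    finally:
        print(f"unpatch {name}")


with ExitStack() as exit_stack:
    cn = exit_stack.enter_context(load_model("controlnet"))  # scoped to the stack
    with (
        load_model("unet") as unet,
        patch("attention processors"),
        patch("extensions"),
    ):
        print(f"denoise with {unet} + {cn}")
# on exit: unpatch extensions -> unpatch attention -> unload unet -> unload controlnet
```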
```diff
@@ -52,7 +52,7 @@ class ExtensionBase:
         return self._callbacks
 
     @contextmanager
-    def patch_extension(self, context: DenoiseContext):
+    def patch_extension(self, ctx: DenoiseContext):
         yield None
 
     @contextmanager
```
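The base class yields immediately, so patching is opt-in per extension. A hedged sketch of a subclass overriding the hook — `LoggingExt` is hypothetical, while `ExtensionBase` and `DenoiseContext` are the types from this diff:

```python
from contextlib import contextmanager

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase


class LoggingExt(ExtensionBase):
    @contextmanager
    def patch_extension(self, ctx: DenoiseContext):
        # Apply a temporary change here; it stays active while denoising runs.
        print("patched for denoise")
        try:
            yield None
        finally:
            print("unpatched")
```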
`invokeai/backend/stable_diffusion/extensions/controlnet.py` (new file, 158 lines):

```python
from __future__ import annotations

import math
from contextlib import contextmanager
from typing import TYPE_CHECKING, List, Optional, Union

import torch
from PIL.Image import Image

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.util.hotfixes import ControlNetModel


class ControlNetExt(ExtensionBase):
    def __init__(
        self,
        model: ControlNetModel,
        image: Image,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
        control_mode: CONTROLNET_MODE_VALUES,
        resize_mode: CONTROLNET_RESIZE_VALUES,
    ):
        super().__init__()
        self._model = model
        self._image = image
        self._weight = weight
        self._begin_step_percent = begin_step_percent
        self._end_step_percent = end_step_percent
        self._control_mode = control_mode
        self._resize_mode = resize_mode

        self._image_tensor: Optional[torch.Tensor] = None

    @contextmanager
    def patch_extension(self, ctx: DenoiseContext):
        original_processors = self._model.attn_processors
        try:
            self._model.set_attn_processor(ctx.inputs.attention_processor_cls())

            yield None
        finally:
            self._model.set_attn_processor(original_processors)

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def resize_image(self, ctx: DenoiseContext):
        _, _, latent_height, latent_width = ctx.latents.shape
        image_height = latent_height * LATENT_SCALE_FACTOR
        image_width = latent_width * LATENT_SCALE_FACTOR

        self._image_tensor = prepare_control_image(
            image=self._image,
            do_classifier_free_guidance=False,
            width=image_width,
            height=image_height,
            device=ctx.latents.device,
            dtype=ctx.latents.dtype,
            control_mode=self._control_mode,
            resize_mode=self._resize_mode,
        )

    @callback(ExtensionCallbackType.PRE_UNET)
    def pre_unet_step(self, ctx: DenoiseContext):
        # skip if model not active in current step
        total_steps = len(ctx.inputs.timesteps)
        first_step = math.floor(self._begin_step_percent * total_steps)
        last_step = math.ceil(self._end_step_percent * total_steps)
        if ctx.step_index < first_step or ctx.step_index > last_step:
            return

        # convert mode to internal flags
        soft_injection = self._control_mode in ["more_prompt", "more_control"]
        cfg_injection = self._control_mode in ["more_control", "unbalanced"]

        # no negative conditioning in cfg_injection mode
        if cfg_injection:
            if ctx.conditioning_mode == ConditioningMode.Negative:
                return
            down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)

            if ctx.conditioning_mode == ConditioningMode.Both:
                # add zeros as samples for negative conditioning
                down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

        else:
            down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)

        if (
            ctx.unet_kwargs.down_block_additional_residuals is None
            and ctx.unet_kwargs.mid_block_additional_residual is None
        ):
            ctx.unet_kwargs.down_block_additional_residuals = down_samples
            ctx.unet_kwargs.mid_block_additional_residual = mid_sample
        else:
            # add controlnet outputs together if have multiple controlnets
            ctx.unet_kwargs.down_block_additional_residuals = [
                samples_prev + samples_curr
                for samples_prev, samples_curr in zip(
                    ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
                )
            ]
            ctx.unet_kwargs.mid_block_additional_residual += mid_sample

    def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
        total_steps = len(ctx.inputs.timesteps)

        model_input = ctx.latent_model_input
        image_tensor = self._image_tensor
        if conditioning_mode == ConditioningMode.Both:
            model_input = torch.cat([model_input] * 2)
            image_tensor = torch.cat([image_tensor] * 2)

        cn_unet_kwargs = UNetKwargs(
            sample=model_input,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / total_steps,
            ),
        )

        ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)

        # get static weight, or weight corresponding to current step
        weight = self._weight
        if isinstance(weight, list):
            weight = weight[ctx.step_index]

        tmp_kwargs = vars(cn_unet_kwargs)

        # Remove kwargs not related to ControlNet unet
        # ControlNet guidance fields
        del tmp_kwargs["down_block_additional_residuals"]
        del tmp_kwargs["mid_block_additional_residual"]

        # T2i Adapter guidance fields
        del tmp_kwargs["down_intrablock_additional_residuals"]

        # controlnet(s) inference
        down_samples, mid_sample = self._model(
            controlnet_cond=image_tensor,
            conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
            guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
            return_dict=False,
            **vars(cn_unet_kwargs),
        )

        return down_samples, mid_sample
```
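To make the `cfg_injection` branch concrete: the ControlNet runs only on the positive half, and zeros are concatenated in front so the residual lines up with the UNet's [negative, positive] batch. A toy check, independent of InvokeAI:

```python
import torch

positive_residual = torch.ones(1, 4, 8, 8)  # stand-in for one ControlNet output
both = torch.cat([torch.zeros_like(positive_residual), positive_residual])

assert both.shape == (2, 4, 8, 8)  # batch now matches [negative, positive]
assert both[0].abs().sum() == 0    # negative half receives no ControlNet guidance
assert torch.equal(both[1], positive_residual)
```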
```diff
@@ -52,13 +52,13 @@ class ExtensionsManager:
             cb.function(ctx)
 
     @contextmanager
-    def patch_extensions(self, context: DenoiseContext):
+    def patch_extensions(self, ctx: DenoiseContext):
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
         with ExitStack() as exit_stack:
             for ext in self._extensions:
-                exit_stack.enter_context(ext.patch_extension(context))
+                exit_stack.enter_context(ext.patch_extension(ctx))
 
             yield None
```