Base of modular backend (#6606)

## Summary

Base code of new modular backend from #6577.
Contains normal generation and regional prompts support.
Also preview extension included to test if extensions logic works.

## Related Issues / Discussions


https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without set `USE_MODULAR_DENOISE` environment.
Currently only normal and regional conditionings supported, so just
generate some images and compare with main output.

## Merge Plan

Discuss a bit more about injection point names?
As if for example in future unet will be overridable, current
`pre_unet`/`post_unet` assumes to name override as `unet` what feels a
bit odd.
Also `apply_cfg` - future implementation could ignore/not use cfg, so in
this case `combine_noise_predictions`/`combine_noise` seems more
suitable.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
This commit is contained in:
Ryan Dick 2024-07-19 16:37:57 -04:00 committed by GitHub
commit 473f4cc1c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 908 additions and 25 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import os
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@ -39,6 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
@ -53,6 +55,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
@ -314,9 +321,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
device: torch.device,
dtype: torch.dtype,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
@ -330,10 +338,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
cond_list, context, device, dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
uncond_list, context, device, dtype
)
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
@ -341,14 +349,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)
if isinstance(cfg_scale, list):
@ -707,9 +715,108 @@ class DenoiseLatentsInvocation(BaseInvocation):
return seed, noise, latents
def invoke(self, context: InvocationContext) -> LatentsOutput:
if os.environ.get("USE_MODULAR_DENOISE", False):
return self._new_invoke(context)
else:
return self._old_invoke(context)
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
cfg_scale=self.cfg_scale,
steps=self.steps,
latent_height=latent_height,
latent_width=latent_width,
device=device,
dtype=dtype,
# TODO: old backend, remove
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
seed=seed,
device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
)
denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
)
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
ext_manager.add_extension(PreviewExt(step_callback))
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(unet),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.detach().to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@ -788,7 +895,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
import numpy as np
import torch
@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
"""
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
@contextmanager
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
"""A context manager that patches `unet` with the provided attention processor.
Args:
unet (UNet2DConditionModel): The UNet model to patch.
processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
"""
unet_orig_processors = unet.attn_processors
# create separate instance for each attention, to be able modify each attention separately
unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
try:
unet.set_attn_processor(unet_new_processors)
yield None
finally:
unet.set_attn_processor(unet_orig_processors)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key

View File

@ -0,0 +1,131 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData
@dataclass
class UNetKwargs:
sample: torch.Tensor
timestep: Union[torch.Tensor, float, int]
encoder_hidden_states: torch.Tensor
class_labels: Optional[torch.Tensor] = None
timestep_cond: Optional[torch.Tensor] = None
attention_mask: Optional[torch.Tensor] = None
cross_attention_kwargs: Optional[Dict[str, Any]] = None
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
mid_block_additional_residual: Optional[torch.Tensor] = None
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
encoder_attention_mask: Optional[torch.Tensor] = None
# return_dict: bool = True
@dataclass
class DenoiseInputs:
"""Initial variables passed to denoise. Supposed to be unchanged."""
# The latent-space image to denoise.
# Shape: [batch, channels, latent_height, latent_width]
# - If we are inpainting, this is the initial latent image before noise has been added.
# - If we are generating a new image, this should be initialized to zeros.
# - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
orig_latents: torch.Tensor
# kwargs forwarded to the scheduler.step() method.
scheduler_step_kwargs: dict[str, Any]
# Text conditionging data.
conditioning_data: TextConditioningData
# Noise used for two purposes:
# 1. Used by the scheduler to noise the initial `latents` before denoising.
# 2. Used to noise the `masked_latents` when inpainting.
# `noise` should be None if the `latents` tensor has already been noised.
# Shape: [1 or batch, channels, latent_height, latent_width]
noise: Optional[torch.Tensor]
# The seed used to generate the noise for the denoising process.
# HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
# same noise used earlier in the pipeline. This should really be handled in a clearer way.
seed: int
# The timestep schedule for the denoising process.
timesteps: torch.Tensor
# The first timestep in the schedule. This is used to determine the initial noise level, so
# should be populated if you want noise applied *even* if timesteps is empty.
init_timestep: torch.Tensor
# Class of attention processor that is used.
attention_processor_cls: Type[Any]
@dataclass
class DenoiseContext:
"""Context with all variables in denoise"""
# Initial variables passed to denoise. Supposed to be unchanged.
inputs: DenoiseInputs
# Scheduler which used to apply noise predictions.
scheduler: SchedulerMixin
# UNet model.
unet: Optional[UNet2DConditionModel] = None
# Current state of latent-space image in denoising process.
# None until `pre_denoise_loop` callback.
# Shape: [batch, channels, latent_height, latent_width]
latents: Optional[torch.Tensor] = None
# Current denoising step index.
# None until `pre_step` callback.
step_index: Optional[int] = None
# Current denoising step timestep.
# None until `pre_step` callback.
timestep: Optional[torch.Tensor] = None
# Arguments which will be passed to UNet model.
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
unet_kwargs: Optional[UNetKwargs] = None
# SchedulerOutput class returned from step function(normally, generated by scheduler).
# Supposed to be used only in `post_step` callback, otherwise can be None.
step_output: Optional[SchedulerOutput] = None
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
# Available in events inside step(between `pre_step` and `post_stop`).
# Shape: [batch, channels, latent_height, latent_width]
latent_model_input: Optional[torch.Tensor] = None
# [TMP] Defines on which conditionings current unet call will be runned.
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
conditioning_mode: Optional[ConditioningMode] = None
# [TMP] Noise predictions from negative conditioning.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
negative_noise_pred: Optional[torch.Tensor] = None
# [TMP] Noise predictions from positive conditioning.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
positive_noise_pred: Optional[torch.Tensor] = None
# Combined noise prediction from passed conditionings.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
noise_pred: Optional[torch.Tensor] = None
# Dictionary for extensions to pass extra info about denoise process to other extensions.
extra: dict = field(default_factory=dict)

View File

@ -23,21 +23,12 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@dataclass
class PipelineIntermediateState:
step: int
order: int
total_steps: int
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None
@dataclass
class AddsMaskGuidance:
mask: torch.Tensor

View File

@ -1,10 +1,17 @@
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import List, Optional, Union
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
if TYPE_CHECKING:
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
@dataclass
@ -95,6 +102,12 @@ class TextConditioningRegions:
assert self.masks.shape[1] == len(self.ranges)
class ConditioningMode(Enum):
Both = "both"
Negative = "negative"
Positive = "positive"
class TextConditioningData:
def __init__(
self,
@ -103,7 +116,7 @@ class TextConditioningData:
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0,
guidance_rescale_multiplier: float = 0, # TODO: old backend, remove
):
self.uncond_text = uncond_text
self.cond_text = cond_text
@ -114,6 +127,7 @@ class TextConditioningData:
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
self.guidance_scale = guidance_scale
# TODO: old backend, remove
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
self.guidance_rescale_multiplier = guidance_rescale_multiplier
@ -121,3 +135,114 @@ class TextConditioningData:
def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo)
def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
"""Fills unet arguments with data from provided conditionings.
Args:
unet_kwargs (UNetKwargs): Object which stores UNet model arguments.
conditioning_mode (ConditioningMode): Describes which conditionings should be used.
"""
_, _, h, w = unet_kwargs.sample.shape
device = unet_kwargs.sample.device
dtype = unet_kwargs.sample.dtype
# TODO: combine regions with conditionings
if conditioning_mode == ConditioningMode.Both:
conditionings = [self.uncond_text, self.cond_text]
c_regions = [self.uncond_regions, self.cond_regions]
elif conditioning_mode == ConditioningMode.Positive:
conditionings = [self.cond_text]
c_regions = [self.cond_regions]
elif conditioning_mode == ConditioningMode.Negative:
conditionings = [self.uncond_text]
c_regions = [self.uncond_regions]
else:
raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
[c.embeds for c in conditionings]
)
unet_kwargs.encoder_hidden_states = encoder_hidden_states
unet_kwargs.encoder_attention_mask = encoder_attention_mask
if self.is_sdxl():
added_cond_kwargs = dict( # noqa: C408
text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
time_ids=torch.cat([c.add_time_ids for c in conditionings]),
)
unet_kwargs.added_cond_kwargs = added_cond_kwargs
if any(r is not None for r in c_regions):
tmp_regions = []
for c, r in zip(conditionings, c_regions, strict=True):
if r is None:
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=dtype),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
tmp_regions.append(r)
if unet_kwargs.cross_attention_kwargs is None:
unet_kwargs.cross_attention_kwargs = {}
unet_kwargs.cross_attention_kwargs.update(
regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
)
@staticmethod
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
@classmethod
def _pad_conditioning(
cls,
cond: torch.Tensor,
target_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pad provided conditioning tensor to target_len by zeros and returns mask of unpadded bytes.
Args:
cond (torch.Tensor): Conditioning tensor which to pads by zeros.
target_len (int): To which length(tokens count) pad tensor.
"""
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < target_len:
conditioning_attention_mask = cls._pad_zeros(
conditioning_attention_mask,
pad_shape=(cond.shape[0], target_len - cond.shape[1]),
dim=1,
)
cond = cls._pad_zeros(
cond,
pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
dim=1,
)
return cond, conditioning_attention_mask
@classmethod
def _concat_conditionings_for_batch(
cls,
conditionings: List[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Concatenate provided conditioning tensors to one batched tensor.
If tensors have different sizes then pad them by zeros and creates
encoder_attention_mask to exclude padding from attention.
Args:
conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
"""
encoder_attention_mask = None
max_len = max([c.shape[1] for c in conditionings])
if any(c.shape[1] != max_len for c in conditionings):
encoder_attention_masks = [None] * len(conditionings)
for i in range(len(conditionings)):
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
encoder_attention_mask = torch.cat(encoder_attention_masks)
return torch.cat(conditionings), encoder_attention_mask

View File

@ -1,9 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
class RegionalPromptData:

View File

@ -0,0 +1,140 @@
from __future__ import annotations
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from tqdm.auto import tqdm
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
class StableDiffusionBackend:
def __init__(
self,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
):
self.unet = unet
self.scheduler = scheduler
config = get_config()
self._sequential_guidance = config.sequential_guidance
def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
if ctx.inputs.init_timestep.shape[0] == 0:
return ctx.inputs.orig_latents
ctx.latents = ctx.inputs.orig_latents.clone()
if ctx.inputs.noise is not None:
batch_size = ctx.latents.shape[0]
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
ctx.latents = ctx.scheduler.add_noise(
ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
)
# if no work to do, return latents
if ctx.inputs.timesteps.shape[0] == 0:
return ctx.latents
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
# ext: preview[pre_denoise_loop, priority=low]
ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
# ext: inpaint (apply mask to latents on non-inpaint models)
ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)
# ext: tiles? [override: step]
ctx.step_output = self.step(ctx, ext_manager)
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
# ext: preview[post_step, priority=low]
ext_manager.run_callback(ExtensionCallbackType.POST_STEP, ctx)
ctx.latents = ctx.step_output.prev_sample
# ext: inpaint[post_denoise_loop] (restore unmasked part)
ext_manager.run_callback(ExtensionCallbackType.POST_DENOISE_LOOP, ctx)
return ctx.latents
@torch.inference_mode()
def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)
# TODO: conditionings as list(conditioning_data.to_unet_kwargs - ready)
# Note: The current handling of conditioning doesn't feel very future-proof.
# This might change in the future as new requirements come up, but for now,
# this is the rough plan.
if self._sequential_guidance:
ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative)
ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive)
else:
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
# ext: override apply_cfg
ctx.noise_pred = self.apply_cfg(ctx)
# ext: cfg_rescale [modify_noise_prediction]
# TODO: rename
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
# compute the previous noisy sample x_t -> x_t-1
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
# clean up locals
ctx.latent_model_input = None
ctx.negative_noise_pred = None
ctx.positive_noise_pred = None
ctx.noise_pred = None
return step_output
@staticmethod
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[ctx.step_index]
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
sample = ctx.latent_model_input
if conditioning_mode == ConditioningMode.Both:
sample = torch.cat([sample] * 2)
ctx.unet_kwargs = UNetKwargs(
sample=sample,
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning
cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / len(ctx.inputs.timesteps),
),
)
ctx.conditioning_mode = conditioning_mode
ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
# ext: controlnet/ip/t2i [pre_unet]
ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)
# ext: inpaint [pre_unet, priority=low]
# or
# ext: inpaint [override: unet_forward]
noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)
# clean up locals
ctx.unet_kwargs = None
ctx.conditioning_mode = None
return noise_pred
def _unet_forward(self, **kwargs) -> torch.Tensor:
return self.unet(**kwargs).sample

View File

@ -0,0 +1,12 @@
from enum import Enum
class ExtensionCallbackType(Enum):
SETUP = "setup"
PRE_DENOISE_LOOP = "pre_denoise_loop"
POST_DENOISE_LOOP = "post_denoise_loop"
PRE_STEP = "pre_step"
POST_STEP = "post_step"
PRE_UNET = "pre_unet"
POST_UNET = "post_unet"
POST_APPLY_CFG = "post_apply_cfg"

View File

@ -0,0 +1,60 @@
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List
import torch
from diffusers import UNet2DConditionModel
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
@dataclass
class CallbackMetadata:
callback_type: ExtensionCallbackType
order: int
@dataclass
class CallbackFunctionWithMetadata:
metadata: CallbackMetadata
function: Callable[[DenoiseContext], None]
def callback(callback_type: ExtensionCallbackType, order: int = 0):
def _decorator(function):
function._ext_metadata = CallbackMetadata(
callback_type=callback_type,
order=order,
)
return function
return _decorator
class ExtensionBase:
def __init__(self):
self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
# Register all of the callback methods for this instance.
for func_name in dir(self):
func = getattr(self, func_name)
metadata = getattr(func, "_ext_metadata", None)
if metadata is not None and isinstance(metadata, CallbackMetadata):
if metadata.callback_type not in self._callbacks:
self._callbacks[metadata.callback_type] = []
self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
def get_callbacks(self):
return self._callbacks
@contextmanager
def patch_extension(self, context: DenoiseContext):
yield None
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
yield None

View File

@ -0,0 +1,63 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional
import torch
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
# TODO: change event to accept image instead of latents
@dataclass
class PipelineIntermediateState:
step: int
order: int
total_steps: int
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None
class PreviewExt(ExtensionBase):
def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
super().__init__()
self.callback = callback
# do last so that all other changes shown
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
def initial_preview(self, ctx: DenoiseContext):
self.callback(
PipelineIntermediateState(
step=-1,
order=ctx.scheduler.order,
total_steps=len(ctx.inputs.timesteps),
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
latents=ctx.latents,
)
)
# do last so that all other changes shown
@callback(ExtensionCallbackType.POST_STEP, order=1000)
def step_preview(self, ctx: DenoiseContext):
if hasattr(ctx.step_output, "denoised"):
predicted_original = ctx.step_output.denoised
elif hasattr(ctx.step_output, "pred_original_sample"):
predicted_original = ctx.step_output.pred_original_sample
else:
predicted_original = ctx.step_output.prev_sample
self.callback(
PipelineIntermediateState(
step=ctx.step_index,
order=ctx.scheduler.order,
total_steps=len(ctx.inputs.timesteps),
timestep=int(ctx.timestep), # TODO: is there any code which uses it?
latents=ctx.step_output.prev_sample,
predicted_original=predicted_original, # TODO: is there any reason for additional field?
)
)

View File

@ -0,0 +1,71 @@
from __future__ import annotations
from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch
from diffusers import UNet2DConditionModel
from invokeai.app.services.session_processor.session_processor_common import CanceledException
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
class ExtensionsManager:
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
self._is_canceled = is_canceled
# A list of extensions in the order that they were added to the ExtensionsManager.
self._extensions: List[ExtensionBase] = []
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
def add_extension(self, extension: ExtensionBase):
self._extensions.append(extension)
self._regenerate_ordered_callbacks()
def _regenerate_ordered_callbacks(self):
"""Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
self._ordered_callbacks = {}
# Fill the ordered callbacks dictionary.
for extension in self._extensions:
for callback_type, callbacks in extension.get_callbacks().items():
if callback_type not in self._ordered_callbacks:
self._ordered_callbacks[callback_type] = []
self._ordered_callbacks[callback_type].extend(callbacks)
# Sort each callback list.
for callback_type, callbacks in self._ordered_callbacks.items():
# Note that sorted() is stable, so if two callbacks have the same order, the order that they extensions were
# added will be preserved.
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
if self._is_canceled and self._is_canceled():
raise CanceledException
callbacks = self._ordered_callbacks.get(callback_type, [])
for cb in callbacks:
cb.function(ctx)
@contextmanager
def patch_extensions(self, context: DenoiseContext):
if self._is_canceled and self._is_canceled():
raise CanceledException
with ExitStack() as exit_stack:
for ext in self._extensions:
exit_stack.enter_context(ext.patch_extension(context))
yield None
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
if self._is_canceled and self._is_canceled():
raise CanceledException
# TODO: create logic in PR with extension which uses it
yield None

View File

@ -0,0 +1,46 @@
from unittest import mock
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
class MockExtension(ExtensionBase):
"""A mock ExtensionBase subclass for testing purposes."""
def __init__(self, x: int):
super().__init__()
self._x = x
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def set_step_index(self, ctx: DenoiseContext):
ctx.step_index = self._x
def test_extension_base_callback_registration():
"""Test that a callback can be successfully registered with an extension."""
val = 5
mock_extension = MockExtension(val)
mock_ctx = mock.MagicMock()
callbacks = mock_extension.get_callbacks()
pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
assert len(pre_denoise_loop_cbs) == 1
# Call the mock callback.
pre_denoise_loop_cbs[0].function(mock_ctx)
# Confirm that the callback ran.
assert mock_ctx.step_index == val
def test_extension_base_empty_callback_type():
"""Test that an empty list is returned when no callbacks are registered for a given callback type."""
mock_extension = MockExtension(5)
# There should be no callbacks registered for POST_DENOISE_LOOP.
callbacks = mock_extension.get_callbacks()
post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
assert len(post_denoise_loop_cbs) == 0

View File

@ -0,0 +1,112 @@
from unittest import mock
import pytest
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
class MockExtension(ExtensionBase):
"""A mock ExtensionBase subclass for testing purposes."""
def __init__(self, x: int):
super().__init__()
self._x = x
# Note that order is not specified. It should default to 0.
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def set_step_index(self, ctx: DenoiseContext):
ctx.step_index = self._x
class MockExtensionLate(ExtensionBase):
"""A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback."""
def __init__(self, x: int):
super().__init__()
self._x = x
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
def set_step_index(self, ctx: DenoiseContext):
ctx.step_index = self._x
def test_extension_manager_run_callback():
"""Test that run_callback runs all callbacks for the given callback type."""
em = ExtensionsManager()
mock_extension_1 = MockExtension(1)
em.add_extension(mock_extension_1)
mock_ctx = mock.MagicMock()
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
assert mock_ctx.step_index == 1
def test_extension_manager_run_callback_no_callbacks():
"""Test that run_callback does not raise an error when there are no callbacks for the given callback type."""
em = ExtensionsManager()
mock_ctx = mock.MagicMock()
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
@pytest.mark.parametrize(
["extension_1", "extension_2"],
# Regardless of initialization order, we expect MockExtensionLate to run last.
[(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))],
)
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
"""Test that run_callback runs callbacks in the correct order."""
em = ExtensionsManager()
em.add_extension(extension_1)
em.add_extension(extension_2)
mock_ctx = mock.MagicMock()
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
assert mock_ctx.step_index == 2
class MockExtensionStableSort(ExtensionBase):
"""A mock extension with three PRE_DENOISE_LOOP callbacks, each with a different order value."""
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000)
def early(self, ctx: DenoiseContext):
pass
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def middle(self, ctx: DenoiseContext):
pass
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
def late(self, ctx: DenoiseContext):
pass
def test_extension_manager_stable_sort():
"""Test that when two callbacks have the same 'order' value, they are sorted based on the order they were added to
the ExtensionsManager."""
em = ExtensionsManager()
mock_extension_1 = MockExtensionStableSort()
mock_extension_2 = MockExtensionStableSort()
em.add_extension(mock_extension_1)
em.add_extension(mock_extension_2)
expected_order = [
mock_extension_1.early,
mock_extension_2.early,
mock_extension_1.middle,
mock_extension_2.middle,
mock_extension_1.late,
mock_extension_2.late,
]
# It's not ideal that we are accessing a private attribute here, but this was the most direct way to assert the
# desired behaviour.
assert [cb.function for cb in em._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]] == expected_order