diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 7ccf906893..ccacc3303c 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) import inspect +import os from contextlib import ExitStack from typing import Any, Dict, Iterator, List, Optional, Tuple, Union @@ -39,6 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs from invokeai.backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, StableDiffusionGeneratorPipeline, @@ -53,6 +55,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( TextConditioningData, TextConditioningRegions, ) +from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 +from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt +from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES from invokeai.backend.util.devices import TorchDevice @@ -314,9 +321,10 @@ class DenoiseLatentsInvocation(BaseInvocation): context: InvocationContext, positive_conditioning_field: Union[ConditioningField, list[ConditioningField]], negative_conditioning_field: Union[ConditioningField, list[ConditioningField]], - unet: UNet2DConditionModel, latent_height: int, latent_width: int, + device: torch.device, + dtype: torch.dtype, cfg_scale: float | list[float], steps: int, cfg_rescale_multiplier: float, @@ -330,10 +338,10 @@ class DenoiseLatentsInvocation(BaseInvocation): uncond_list = [uncond_list] cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks( - cond_list, context, unet.device, unet.dtype + cond_list, context, device, dtype ) uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks( - uncond_list, context, unet.device, unet.dtype + uncond_list, context, device, dtype ) cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings( @@ -341,14 +349,14 @@ class DenoiseLatentsInvocation(BaseInvocation): masks=cond_text_embedding_masks, latent_height=latent_height, latent_width=latent_width, - dtype=unet.dtype, + dtype=dtype, ) uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings( text_conditionings=uncond_text_embeddings, masks=uncond_text_embedding_masks, latent_height=latent_height, latent_width=latent_width, - dtype=unet.dtype, + dtype=dtype, ) if isinstance(cfg_scale, list): @@ -707,9 +715,108 @@ class DenoiseLatentsInvocation(BaseInvocation): return seed, noise, latents + def invoke(self, context: InvocationContext) -> LatentsOutput: + if os.environ.get("USE_MODULAR_DENOISE", False): + return self._new_invoke(context) + else: + return self._old_invoke(context) + 
@torch.no_grad() @SilenceWarnings() # This quenches the NSFW nag from diffusers. - def invoke(self, context: InvocationContext) -> LatentsOutput: + def _new_invoke(self, context: InvocationContext) -> LatentsOutput: + ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled) + + device = TorchDevice.choose_torch_device() + dtype = TorchDevice.choose_torch_dtype() + + seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) + latents = latents.to(device=device, dtype=dtype) + if noise is not None: + noise = noise.to(device=device, dtype=dtype) + + _, _, latent_height, latent_width = latents.shape + + conditioning_data = self.get_conditioning_data( + context=context, + positive_conditioning_field=self.positive_conditioning, + negative_conditioning_field=self.negative_conditioning, + cfg_scale=self.cfg_scale, + steps=self.steps, + latent_height=latent_height, + latent_width=latent_width, + device=device, + dtype=dtype, + # TODO: old backend, remove + cfg_rescale_multiplier=self.cfg_rescale_multiplier, + ) + + scheduler = get_scheduler( + context=context, + scheduler_info=self.unet.scheduler, + scheduler_name=self.scheduler, + seed=seed, + ) + + timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( + scheduler, + seed=seed, + device=device, + steps=self.steps, + denoising_start=self.denoising_start, + denoising_end=self.denoising_end, + ) + + denoise_ctx = DenoiseContext( + inputs=DenoiseInputs( + orig_latents=latents, + timesteps=timesteps, + init_timestep=init_timestep, + noise=noise, + seed=seed, + scheduler_step_kwargs=scheduler_step_kwargs, + conditioning_data=conditioning_data, + attention_processor_cls=CustomAttnProcessor2_0, + ), + unet=None, + scheduler=scheduler, + ) + + # get the unet's config so that we can pass the base to sd_step_callback() + unet_config = context.models.get_config(self.unet.unet.key) + + ### preview + def step_callback(state: PipelineIntermediateState) -> None: + context.util.sd_step_callback(state, unet_config.base) + + ext_manager.add_extension(PreviewExt(step_callback)) + + # ext: t2i/ip adapter + ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) + + unet_info = context.models.load(self.unet.unet) + assert isinstance(unet_info.model, UNet2DConditionModel) + with ( + unet_info.model_on_device() as (model_state_dict, unet), + ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls), + # ext: controlnet + ext_manager.patch_extensions(unet), + # ext: freeu, seamless, ip adapter, lora + ext_manager.patch_unet(model_state_dict, unet), + ): + sd_backend = StableDiffusionBackend(unet, scheduler) + denoise_ctx.unet = unet + result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + result_latents = result_latents.detach().to("cpu") + TorchDevice.empty_cache() + + name = context.tensors.save(tensor=result_latents) + return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None) + + @torch.no_grad() + @SilenceWarnings() # This quenches the NSFW nag from diffusers. 
+ def _old_invoke(self, context: InvocationContext) -> LatentsOutput: seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents) @@ -788,7 +895,8 @@ class DenoiseLatentsInvocation(BaseInvocation): context=context, positive_conditioning_field=self.positive_conditioning, negative_conditioning_field=self.negative_conditioning, - unet=unet, + device=unet.device, + dtype=unet.dtype, latent_height=latent_height, latent_width=latent_width, cfg_scale=self.cfg_scale, diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index 8b8aa6d5a5..d30f7b3167 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -5,7 +5,7 @@ from __future__ import annotations import pickle from contextlib import contextmanager -from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union import numpy as np import torch @@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras): """ -# TODO: rename smth like ModelPatcher and add TI method? class ModelPatcher: + @staticmethod + @contextmanager + def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]): + """A context manager that patches `unet` with the provided attention processor. + + Args: + unet (UNet2DConditionModel): The UNet model to patch. + processor_cls (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...). + """ + unet_orig_processors = unet.attn_processors + + # create a separate instance for each attention, to be able to modify each attention separately + unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()} + try: + unet.set_attn_processor(unet_new_processors) + yield None + + finally: + unet.set_attn_processor(unet_orig_processors) + @staticmethod def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: assert "."
not in lora_key diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py new file mode 100644 index 0000000000..2b43d3fb0f --- /dev/null +++ b/invokeai/backend/stable_diffusion/denoise_context.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union + +import torch +from diffusers import UNet2DConditionModel +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData + + +@dataclass +class UNetKwargs: + sample: torch.Tensor + timestep: Union[torch.Tensor, float, int] + encoder_hidden_states: torch.Tensor + + class_labels: Optional[torch.Tensor] = None + timestep_cond: Optional[torch.Tensor] = None + attention_mask: Optional[torch.Tensor] = None + cross_attention_kwargs: Optional[Dict[str, Any]] = None + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None + mid_block_additional_residual: Optional[torch.Tensor] = None + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None + encoder_attention_mask: Optional[torch.Tensor] = None + # return_dict: bool = True + + +@dataclass +class DenoiseInputs: + """Initial variables passed to denoise. Supposed to be unchanged.""" + + # The latent-space image to denoise. + # Shape: [batch, channels, latent_height, latent_width] + # - If we are inpainting, this is the initial latent image before noise has been added. + # - If we are generating a new image, this should be initialized to zeros. + # - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner). + orig_latents: torch.Tensor + + # kwargs forwarded to the scheduler.step() method. + scheduler_step_kwargs: dict[str, Any] + + # Text conditioning data. + conditioning_data: TextConditioningData + + # Noise used for two purposes: + # 1. Used by the scheduler to noise the initial `latents` before denoising. + # 2. Used to noise the `masked_latents` when inpainting. + # `noise` should be None if the `latents` tensor has already been noised. + # Shape: [1 or batch, channels, latent_height, latent_width] + noise: Optional[torch.Tensor] + + # The seed used to generate the noise for the denoising process. + # HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the + # same noise used earlier in the pipeline. This should really be handled in a clearer way. + seed: int + + # The timestep schedule for the denoising process. + timesteps: torch.Tensor + + # The first timestep in the schedule. This is used to determine the initial noise level, so + # should be populated if you want noise applied *even* if timesteps is empty. + init_timestep: torch.Tensor + + # Class of attention processor that is used. + attention_processor_cls: Type[Any] + + +@dataclass +class DenoiseContext: + """Context with all variables used in the denoising process.""" + + # Initial variables passed to denoise. Supposed to be unchanged. + inputs: DenoiseInputs + + # Scheduler which is used to apply noise predictions. + scheduler: SchedulerMixin + + # UNet model. + unet: Optional[UNet2DConditionModel] = None + + # Current state of latent-space image in denoising process. + # None until `pre_denoise_loop` callback.
+ # Shape: [batch, channels, latent_height, latent_width] + latents: Optional[torch.Tensor] = None + + # Current denoising step index. + # None until `pre_step` callback. + step_index: Optional[int] = None + + # Current denoising step timestep. + # None until `pre_step` callback. + timestep: Optional[torch.Tensor] = None + + # Arguments which will be passed to UNet model. + # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. + unet_kwargs: Optional[UNetKwargs] = None + + # SchedulerOutput class returned from the step function (normally generated by the scheduler). + # Supposed to be used only in `post_step` callback, otherwise can be None. + step_output: Optional[SchedulerOutput] = None + + # Scaled version of `latents`, which will be passed to unet_kwargs initialization. + # Available in events inside step (between `pre_step` and `post_step`). + # Shape: [batch, channels, latent_height, latent_width] + latent_model_input: Optional[torch.Tensor] = None + + # [TMP] Defines which conditionings the current unet call will be run on. + # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. + conditioning_mode: Optional[ConditioningMode] = None + + # [TMP] Noise predictions from negative conditioning. + # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Shape: [batch, channels, latent_height, latent_width] + negative_noise_pred: Optional[torch.Tensor] = None + + # [TMP] Noise predictions from positive conditioning. + # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Shape: [batch, channels, latent_height, latent_width] + positive_noise_pred: Optional[torch.Tensor] = None + + # Combined noise prediction from passed conditionings. + # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Shape: [batch, channels, latent_height, latent_width] + noise_pred: Optional[torch.Tensor] = None + + # Dictionary for extensions to pass extra info about denoise process to other extensions.
+ extra: dict = field(default_factory=dict) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index ee464f73e1..b3a668518b 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -23,21 +23,12 @@ from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData +from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState from invokeai.backend.util.attention import auto_detect_slice_size from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.hotfixes import ControlNetModel -@dataclass -class PipelineIntermediateState: - step: int - order: int - total_steps: int - timestep: int - latents: torch.Tensor - predicted_original: Optional[torch.Tensor] = None - - @dataclass class AddsMaskGuidance: mask: torch.Tensor diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 85950a01df..5fe1483ebc 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -1,10 +1,17 @@ +from __future__ import annotations + import math from dataclasses import dataclass -from typing import List, Optional, Union +from enum import Enum +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch -from invokeai.backend.ip_adapter.ip_adapter import IPAdapter +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData + +if TYPE_CHECKING: + from invokeai.backend.ip_adapter.ip_adapter import IPAdapter + from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs @dataclass @@ -95,6 +102,12 @@ class TextConditioningRegions: assert self.masks.shape[1] == len(self.ranges) +class ConditioningMode(Enum): + Both = "both" + Negative = "negative" + Positive = "positive" + + class TextConditioningData: def __init__( self, @@ -103,7 +116,7 @@ class TextConditioningData: uncond_regions: Optional[TextConditioningRegions], cond_regions: Optional[TextConditioningRegions], guidance_scale: Union[float, List[float]], - guidance_rescale_multiplier: float = 0, + guidance_rescale_multiplier: float = 0, # TODO: old backend, remove ): self.uncond_text = uncond_text self.cond_text = cond_text @@ -114,6 +127,7 @@ class TextConditioningData: # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate # images that are closely linked to the text `prompt`, usually at the expense of lower image quality. self.guidance_scale = guidance_scale + # TODO: old backend, remove # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7. # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
self.guidance_rescale_multiplier = guidance_rescale_multiplier @@ -121,3 +135,114 @@ class TextConditioningData: def is_sdxl(self): assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo) return isinstance(self.cond_text, SDXLConditioningInfo) + + def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode): + """Fills unet arguments with data from provided conditionings. + + Args: + unet_kwargs (UNetKwargs): Object which stores UNet model arguments. + conditioning_mode (ConditioningMode): Describes which conditionings should be used. + """ + _, _, h, w = unet_kwargs.sample.shape + device = unet_kwargs.sample.device + dtype = unet_kwargs.sample.dtype + + # TODO: combine regions with conditionings + if conditioning_mode == ConditioningMode.Both: + conditionings = [self.uncond_text, self.cond_text] + c_regions = [self.uncond_regions, self.cond_regions] + elif conditioning_mode == ConditioningMode.Positive: + conditionings = [self.cond_text] + c_regions = [self.cond_regions] + elif conditioning_mode == ConditioningMode.Negative: + conditionings = [self.uncond_text] + c_regions = [self.uncond_regions] + else: + raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}") + + encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch( + [c.embeds for c in conditionings] + ) + + unet_kwargs.encoder_hidden_states = encoder_hidden_states + unet_kwargs.encoder_attention_mask = encoder_attention_mask + + if self.is_sdxl(): + added_cond_kwargs = dict( # noqa: C408 + text_embeds=torch.cat([c.pooled_embeds for c in conditionings]), + time_ids=torch.cat([c.add_time_ids for c in conditionings]), + ) + + unet_kwargs.added_cond_kwargs = added_cond_kwargs + + if any(r is not None for r in c_regions): + tmp_regions = [] + for c, r in zip(conditionings, c_regions, strict=True): + if r is None: + r = TextConditioningRegions( + masks=torch.ones((1, 1, h, w), dtype=dtype), + ranges=[Range(start=0, end=c.embeds.shape[1])], + ) + tmp_regions.append(r) + + if unet_kwargs.cross_attention_kwargs is None: + unet_kwargs.cross_attention_kwargs = {} + + unet_kwargs.cross_attention_kwargs.update( + regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype), + ) + + @staticmethod + def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor: + return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim) + + @classmethod + def _pad_conditioning( + cls, + cond: torch.Tensor, + target_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Pads the provided conditioning tensor to target_len with zeros and returns a mask of the unpadded tokens. + + Args: + cond (torch.Tensor): Conditioning tensor to pad with zeros. + target_len (int): Target length (token count) to pad the tensor to.
+ """ + conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype) + + if cond.shape[1] < target_len: + conditioning_attention_mask = cls._pad_zeros( + conditioning_attention_mask, + pad_shape=(cond.shape[0], target_len - cond.shape[1]), + dim=1, + ) + + cond = cls._pad_zeros( + cond, + pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]), + dim=1, + ) + + return cond, conditioning_attention_mask + + @classmethod + def _concat_conditionings_for_batch( + cls, + conditionings: List[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Concatenate the provided conditioning tensors into one batched tensor. + If the tensors have different lengths, they are padded with zeros and an + encoder_attention_mask is created to exclude the padding from attention. + + Args: + conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate. + """ + encoder_attention_mask = None + max_len = max([c.shape[1] for c in conditionings]) + if any(c.shape[1] != max_len for c in conditionings): + encoder_attention_masks = [None] * len(conditionings) + for i in range(len(conditionings)): + conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len) + encoder_attention_mask = torch.cat(encoder_attention_masks) + + return torch.cat(conditionings), encoder_attention_mask diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index f09cc0a0d2..eddd31f0c4 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -1,9 +1,14 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import torch import torch.nn.functional as F -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( - TextConditioningRegions, -) +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + TextConditioningRegions, + ) class RegionalPromptData: diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py new file mode 100644 index 0000000000..806deb5e03 --- /dev/null +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import torch +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput +from tqdm.auto import tqdm + +from invokeai.app.services.config.config_default import get_config +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager + + +class StableDiffusionBackend: + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + ): + self.unet = unet + self.scheduler = scheduler + config = get_config() + self._sequential_guidance = config.sequential_guidance + + def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + if ctx.inputs.init_timestep.shape[0] == 0: + return ctx.inputs.orig_latents + + ctx.latents = ctx.inputs.orig_latents.clone() + + if ctx.inputs.noise is not None: + 
batch_size = ctx.latents.shape[0] + # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers + ctx.latents = ctx.scheduler.add_noise( + ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size) + ) + + # if no work to do, return latents + if ctx.inputs.timesteps.shape[0] == 0: + return ctx.latents + + # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it's needed) + # ext: preview[pre_denoise_loop, priority=low] + ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx) + + for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020 + # ext: inpaint (apply mask to latents on non-inpaint models) + ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx) + + # ext: tiles? [override: step] + ctx.step_output = self.step(ctx, ext_manager) + + # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models) + # ext: preview[post_step, priority=low] + ext_manager.run_callback(ExtensionCallbackType.POST_STEP, ctx) + + ctx.latents = ctx.step_output.prev_sample + + # ext: inpaint[post_denoise_loop] (restore unmasked part) + ext_manager.run_callback(ExtensionCallbackType.POST_DENOISE_LOOP, ctx) + return ctx.latents + + @torch.inference_mode() + def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput: + ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep) + + # TODO: conditionings as list(conditioning_data.to_unet_kwargs - ready) + # Note: The current handling of conditioning doesn't feel very future-proof. + # This might change in the future as new requirements come up, but for now, + # this is the rough plan. + if self._sequential_guidance: + ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative) + ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive) + else: + both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both) + ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2) + + # ext: override apply_cfg + ctx.noise_pred = self.apply_cfg(ctx) + + # ext: cfg_rescale [modify_noise_prediction] + # TODO: rename + ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx) + + # compute the previous noisy sample x_t -> x_t-1 + step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs) + + # clean up locals + ctx.latent_model_input = None + ctx.negative_noise_pred = None + ctx.positive_noise_pred = None + ctx.noise_pred = None + + return step_output + + @staticmethod + def apply_cfg(ctx: DenoiseContext) -> torch.Tensor: + guidance_scale = ctx.inputs.conditioning_data.guidance_scale + if isinstance(guidance_scale, list): + guidance_scale = guidance_scale[ctx.step_index] + + return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) + # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) + + def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode): + sample = ctx.latent_model_input + if conditioning_mode == ConditioningMode.Both: + sample = torch.cat([sample] * 2) + + ctx.unet_kwargs = UNetKwargs( + sample=sample, + timestep=ctx.timestep, + encoder_hidden_states=None, # set later by conditioning + cross_attention_kwargs=dict( # noqa: C408 + percent_through=ctx.step_index / len(ctx.inputs.timesteps), + ), + ) + + 
ctx.conditioning_mode = conditioning_mode + ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) + + # ext: controlnet/ip/t2i [pre_unet] + ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx) + + # ext: inpaint [pre_unet, priority=low] + # or + # ext: inpaint [override: unet_forward] + noise_pred = self._unet_forward(**vars(ctx.unet_kwargs)) + + ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx) + + # clean up locals + ctx.unet_kwargs = None + ctx.conditioning_mode = None + + return noise_pred + + def _unet_forward(self, **kwargs) -> torch.Tensor: + return self.unet(**kwargs).sample diff --git a/invokeai/backend/stable_diffusion/extension_callback_type.py b/invokeai/backend/stable_diffusion/extension_callback_type.py new file mode 100644 index 0000000000..aaefbd7ed0 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extension_callback_type.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class ExtensionCallbackType(Enum): + SETUP = "setup" + PRE_DENOISE_LOOP = "pre_denoise_loop" + POST_DENOISE_LOOP = "post_denoise_loop" + PRE_STEP = "pre_step" + POST_STEP = "post_step" + PRE_UNET = "pre_unet" + POST_UNET = "post_unet" + POST_APPLY_CFG = "post_apply_cfg" diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py new file mode 100644 index 0000000000..802af86e6d --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Dict, List + +import torch +from diffusers import UNet2DConditionModel + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType + + +@dataclass +class CallbackMetadata: + callback_type: ExtensionCallbackType + order: int + + +@dataclass +class CallbackFunctionWithMetadata: + metadata: CallbackMetadata + function: Callable[[DenoiseContext], None] + + +def callback(callback_type: ExtensionCallbackType, order: int = 0): + def _decorator(function): + function._ext_metadata = CallbackMetadata( + callback_type=callback_type, + order=order, + ) + return function + + return _decorator + + +class ExtensionBase: + def __init__(self): + self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} + + # Register all of the callback methods for this instance. 
+ for func_name in dir(self): + func = getattr(self, func_name) + metadata = getattr(func, "_ext_metadata", None) + if metadata is not None and isinstance(metadata, CallbackMetadata): + if metadata.callback_type not in self._callbacks: + self._callbacks[metadata.callback_type] = [] + self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func)) + + def get_callbacks(self): + return self._callbacks + + @contextmanager + def patch_extension(self, context: DenoiseContext): + yield None + + @contextmanager + def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + yield None diff --git a/invokeai/backend/stable_diffusion/extensions/preview.py b/invokeai/backend/stable_diffusion/extensions/preview.py new file mode 100644 index 0000000000..fd0cc0b9a3 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/preview.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Optional + +import torch + +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + + +# TODO: change event to accept image instead of latents +@dataclass +class PipelineIntermediateState: + step: int + order: int + total_steps: int + timestep: int + latents: torch.Tensor + predicted_original: Optional[torch.Tensor] = None + + +class PreviewExt(ExtensionBase): + def __init__(self, callback: Callable[[PipelineIntermediateState], None]): + super().__init__() + self.callback = callback + + # run last so that all other changes are shown + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) + def initial_preview(self, ctx: DenoiseContext): + self.callback( + PipelineIntermediateState( + step=-1, + order=ctx.scheduler.order, + total_steps=len(ctx.inputs.timesteps), + timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it? + latents=ctx.latents, + ) + ) + + # run last so that all other changes are shown + @callback(ExtensionCallbackType.POST_STEP, order=1000) + def step_preview(self, ctx: DenoiseContext): + if hasattr(ctx.step_output, "denoised"): + predicted_original = ctx.step_output.denoised + elif hasattr(ctx.step_output, "pred_original_sample"): + predicted_original = ctx.step_output.pred_original_sample + else: + predicted_original = ctx.step_output.prev_sample + + self.callback( + PipelineIntermediateState( + step=ctx.step_index, + order=ctx.scheduler.order, + total_steps=len(ctx.inputs.timesteps), + timestep=int(ctx.timestep), # TODO: is there any code which uses it? + latents=ctx.step_output.prev_sample, + predicted_original=predicted_original, # TODO: is there any reason for additional field?
+ ) + ) diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py new file mode 100644 index 0000000000..1cae2e4219 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from contextlib import ExitStack, contextmanager +from typing import TYPE_CHECKING, Callable, Dict, List, Optional + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.app.services.session_processor.session_processor_common import CanceledException + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType + from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase + + +class ExtensionsManager: + def __init__(self, is_canceled: Optional[Callable[[], bool]] = None): + self._is_canceled = is_canceled + + # A list of extensions in the order that they were added to the ExtensionsManager. + self._extensions: List[ExtensionBase] = [] + self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} + + def add_extension(self, extension: ExtensionBase): + self._extensions.append(extension) + self._regenerate_ordered_callbacks() + + def _regenerate_ordered_callbacks(self): + """Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added.""" + self._ordered_callbacks = {} + + # Fill the ordered callbacks dictionary. + for extension in self._extensions: + for callback_type, callbacks in extension.get_callbacks().items(): + if callback_type not in self._ordered_callbacks: + self._ordered_callbacks[callback_type] = [] + self._ordered_callbacks[callback_type].extend(callbacks) + + # Sort each callback list. + for callback_type, callbacks in self._ordered_callbacks.items(): + # Note that sorted() is stable, so if two callbacks have the same order, the order in which the extensions were + # added will be preserved.
+ self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order) + + def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext): + if self._is_canceled and self._is_canceled(): + raise CanceledException + + callbacks = self._ordered_callbacks.get(callback_type, []) + for cb in callbacks: + cb.function(ctx) + + @contextmanager + def patch_extensions(self, context: DenoiseContext): + if self._is_canceled and self._is_canceled(): + raise CanceledException + + with ExitStack() as exit_stack: + for ext in self._extensions: + exit_stack.enter_context(ext.patch_extension(context)) + + yield None + + @contextmanager + def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + if self._is_canceled and self._is_canceled(): + raise CanceledException + + # TODO: create logic in PR with extension which uses it + yield None diff --git a/tests/backend/stable_diffusion/extensions/test_base.py b/tests/backend/stable_diffusion/extensions/test_base.py new file mode 100644 index 0000000000..d024c551a2 --- /dev/null +++ b/tests/backend/stable_diffusion/extensions/test_base.py @@ -0,0 +1,46 @@ +from unittest import mock + +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + + +class MockExtension(ExtensionBase): + """A mock ExtensionBase subclass for testing purposes.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +def test_extension_base_callback_registration(): + """Test that a callback can be successfully registered with an extension.""" + val = 5 + mock_extension = MockExtension(val) + + mock_ctx = mock.MagicMock() + + callbacks = mock_extension.get_callbacks() + pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, []) + assert len(pre_denoise_loop_cbs) == 1 + + # Call the mock callback. + pre_denoise_loop_cbs[0].function(mock_ctx) + + # Confirm that the callback ran. + assert mock_ctx.step_index == val + + +def test_extension_base_empty_callback_type(): + """Test that an empty list is returned when no callbacks are registered for a given callback type.""" + mock_extension = MockExtension(5) + + # There should be no callbacks registered for POST_DENOISE_LOOP. 
+ callbacks = mock_extension.get_callbacks() + + post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, []) + assert len(post_denoise_loop_cbs) == 0 diff --git a/tests/backend/stable_diffusion/test_extension_manager.py b/tests/backend/stable_diffusion/test_extension_manager.py new file mode 100644 index 0000000000..889f8316e5 --- /dev/null +++ b/tests/backend/stable_diffusion/test_extension_manager.py @@ -0,0 +1,112 @@ +from unittest import mock + +import pytest + +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback +from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager + + +class MockExtension(ExtensionBase): + """A mock ExtensionBase subclass for testing purposes.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + # Note that order is not specified. It should default to 0. + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +class MockExtensionLate(ExtensionBase): + """A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +def test_extension_manager_run_callback(): + """Test that run_callback runs all callbacks for the given callback type.""" + + em = ExtensionsManager() + mock_extension_1 = MockExtension(1) + em.add_extension(mock_extension_1) + + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + assert mock_ctx.step_index == 1 + + +def test_extension_manager_run_callback_no_callbacks(): + """Test that run_callback does not raise an error when there are no callbacks for the given callback type.""" + em = ExtensionsManager() + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + +@pytest.mark.parametrize( + ["extension_1", "extension_2"], + # Regardless of initialization order, we expect MockExtensionLate to run last. 
+ [(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))], +) +def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase): + """Test that run_callback runs callbacks in the correct order.""" + em = ExtensionsManager() + em.add_extension(extension_1) + em.add_extension(extension_2) + + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + assert mock_ctx.step_index == 2 + + +class MockExtensionStableSort(ExtensionBase): + """A mock extension with three PRE_DENOISE_LOOP callbacks, each with a different order value.""" + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000) + def early(self, ctx: DenoiseContext): + pass + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def middle(self, ctx: DenoiseContext): + pass + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) + def late(self, ctx: DenoiseContext): + pass + + +def test_extension_manager_stable_sort(): + """Test that when two callbacks have the same 'order' value, they are sorted based on the order they were added to + the ExtensionsManager.""" + + em = ExtensionsManager() + + mock_extension_1 = MockExtensionStableSort() + mock_extension_2 = MockExtensionStableSort() + + em.add_extension(mock_extension_1) + em.add_extension(mock_extension_2) + + expected_order = [ + mock_extension_1.early, + mock_extension_2.early, + mock_extension_1.middle, + mock_extension_2.middle, + mock_extension_1.late, + mock_extension_2.late, + ] + + # It's not ideal that we are accessing a private attribute here, but this was the most direct way to assert the + # desired behaviour. + assert [cb.function for cb in em._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]] == expected_order
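
To illustrate how the pieces introduced in this patch fit together, here is a minimal sketch of a hypothetical extension built on the new callback API. `ExtensionBase`, the `callback` decorator, `ExtensionCallbackType`, `DenoiseContext`, and `ExtensionsManager` come from this diff; the `CFGRescaleExt` name and the rescale formula are illustrative assumptions, not code from the PR. An extension declares callbacks with `@callback`, the manager collects and order-sorts them in `_regenerate_ordered_callbacks()`, and `StableDiffusionBackend` fires them via `run_callback()` at the corresponding points of the denoise loop.

```python
# Hypothetical example (not part of this PR): a cfg_rescale-style extension on the new callback API.
import torch

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class CFGRescaleExt(ExtensionBase):
    """Hypothetical extension: rescales the combined noise prediction after CFG is applied."""

    def __init__(self, multiplier: float):
        super().__init__()
        self._multiplier = multiplier

    @callback(ExtensionCallbackType.POST_APPLY_CFG)
    def rescale_noise_pred(self, ctx: DenoiseContext):
        # Match the std of the CFG result to the std of the positive prediction, then blend.
        # (ctx.noise_pred and ctx.positive_noise_pred are available in post_apply_cfg callbacks.)
        std_pos = ctx.positive_noise_pred.std(dim=[1, 2, 3], keepdim=True)
        std_cfg = ctx.noise_pred.std(dim=[1, 2, 3], keepdim=True)
        rescaled = ctx.noise_pred * (std_pos / std_cfg)
        ctx.noise_pred = torch.lerp(ctx.noise_pred, rescaled, self._multiplier)


# Wiring mirrors DenoiseLatentsInvocation._new_invoke(): extensions are registered with the manager,
# which then invokes their callbacks at the matching points of the denoise loop
# (StableDiffusionBackend calls ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
# right after apply_cfg()).
ext_manager = ExtensionsManager()
ext_manager.add_extension(CFGRescaleExt(multiplier=0.7))
```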