From 9cc852cf7f4441521ca844b647c94ff765b7bb9c Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 12 Jul 2024 20:31:26 +0300 Subject: [PATCH 01/25] Base code from draft PR --- invokeai/app/invocations/denoise_latents.py | 111 ++++++++- .../stable_diffusion/denoise_context.py | 60 +++++ .../diffusion/conditioning_data.py | 128 +++++++++- .../diffusion/regional_prompt_data.py | 11 +- .../stable_diffusion/diffusion_backend.py | 220 ++++++++++++++++++ .../stable_diffusion/extensions/__init__.py | 9 + .../stable_diffusion/extensions/base.py | 58 +++++ .../stable_diffusion/extensions_manager.py | 195 ++++++++++++++++ 8 files changed, 781 insertions(+), 11 deletions(-) create mode 100644 invokeai/backend/stable_diffusion/denoise_context.py create mode 100644 invokeai/backend/stable_diffusion/diffusion_backend.py create mode 100644 invokeai/backend/stable_diffusion/extensions/__init__.py create mode 100644 invokeai/backend/stable_diffusion/extensions/base.py create mode 100644 invokeai/backend/stable_diffusion/extensions_manager.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 7ccf906893..bec8741936 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) import inspect +import os from contextlib import ExitStack from typing import Any, Dict, Iterator, List, Optional, Tuple, Union @@ -39,6 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext from invokeai.backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, StableDiffusionGeneratorPipeline, @@ -53,6 +55,9 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( TextConditioningData, TextConditioningRegions, ) +from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 +from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend +from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES from invokeai.backend.util.devices import TorchDevice @@ -314,9 +319,10 @@ class DenoiseLatentsInvocation(BaseInvocation): context: InvocationContext, positive_conditioning_field: Union[ConditioningField, list[ConditioningField]], negative_conditioning_field: Union[ConditioningField, list[ConditioningField]], - unet: UNet2DConditionModel, latent_height: int, latent_width: int, + device: torch.device, + dtype: torch.dtype, cfg_scale: float | list[float], steps: int, cfg_rescale_multiplier: float, @@ -330,10 +336,10 @@ class DenoiseLatentsInvocation(BaseInvocation): uncond_list = [uncond_list] cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks( - cond_list, context, unet.device, unet.dtype + cond_list, context, device, dtype ) uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks( - uncond_list, context, unet.device, unet.dtype + uncond_list, context, device, dtype ) cond_text_embedding, cond_regions = 
DenoiseLatentsInvocation._concat_regional_text_embeddings( @@ -341,14 +347,14 @@ class DenoiseLatentsInvocation(BaseInvocation): masks=cond_text_embedding_masks, latent_height=latent_height, latent_width=latent_width, - dtype=unet.dtype, + dtype=dtype, ) uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings( text_conditionings=uncond_text_embeddings, masks=uncond_text_embedding_masks, latent_height=latent_height, latent_width=latent_width, - dtype=unet.dtype, + dtype=dtype, ) if isinstance(cfg_scale, list): @@ -707,9 +713,99 @@ class DenoiseLatentsInvocation(BaseInvocation): return seed, noise, latents + def invoke(self, context: InvocationContext) -> LatentsOutput: + if os.environ.get("USE_MODULAR_DENOISE", False): + return self._new_invoke(context) + else: + return self._old_invoke(context) + @torch.no_grad() @SilenceWarnings() # This quenches the NSFW nag from diffusers. - def invoke(self, context: InvocationContext) -> LatentsOutput: + def _new_invoke(self, context: InvocationContext) -> LatentsOutput: + with ExitStack() as exit_stack: + ext_manager = ExtensionsManager() + + device = TorchDevice.choose_torch_device() + dtype = TorchDevice.choose_torch_dtype() + + seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) + latents = latents.to(device=device, dtype=dtype) + if noise is not None: + noise = noise.to(device=device, dtype=dtype) + + _, _, latent_height, latent_width = latents.shape + + conditioning_data = self.get_conditioning_data( + context=context, + positive_conditioning_field=self.positive_conditioning, + negative_conditioning_field=self.negative_conditioning, + cfg_scale=self.cfg_scale, + steps=self.steps, + latent_height=latent_height, + latent_width=latent_width, + device=device, + dtype=dtype, + # TODO: old backend, remove + cfg_rescale_multiplier=self.cfg_rescale_multiplier, + ) + + scheduler = get_scheduler( + context=context, + scheduler_info=self.unet.scheduler, + scheduler_name=self.scheduler, + seed=seed, + ) + + timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( + scheduler, + seed=seed, + device=device, + steps=self.steps, + denoising_start=self.denoising_start, + denoising_end=self.denoising_end, + ) + + denoise_ctx = DenoiseContext( + latents=latents, + timesteps=timesteps, + init_timestep=init_timestep, + noise=noise, + seed=seed, + scheduler_step_kwargs=scheduler_step_kwargs, + conditioning_data=conditioning_data, + unet=None, + scheduler=scheduler, + ) + + # get the unet's config so that we can pass the base to sd_step_callback() + unet_config = context.models.get_config(self.unet.unet.key) + + # ext: t2i/ip adapter + ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager) + + unet_info = context.models.load(self.unet.unet) + assert isinstance(unet_info.model, UNet2DConditionModel) + with ( + unet_info.model_on_device() as (model_state_dict, unet), + # ext: controlnet + ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0), + # ext: freeu, seamless, ip adapter, lora + ext_manager.patch_unet(model_state_dict, unet), + ): + sd_backend = StableDiffusionBackend(unet, scheduler) + denoise_ctx.unet = unet + result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + result_latents = result_latents.to("cpu") # TODO: detach? 
+ TorchDevice.empty_cache() + + name = context.tensors.save(tensor=result_latents) + return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None) + + @torch.no_grad() + @SilenceWarnings() # This quenches the NSFW nag from diffusers. + def _old_invoke(self, context: InvocationContext) -> LatentsOutput: seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents) @@ -788,7 +884,8 @@ class DenoiseLatentsInvocation(BaseInvocation): context=context, positive_conditioning_field=self.positive_conditioning, negative_conditioning_field=self.negative_conditioning, - unet=unet, + device=unet.device, + dtype=unet.dtype, latent_height=latent_height, latent_width=latent_width, cfg_scale=self.cfg_scale, diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py new file mode 100644 index 0000000000..b56f095948 --- /dev/null +++ b/invokeai/backend/stable_diffusion/denoise_context.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +import torch +from diffusers import UNet2DConditionModel +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData + + +@dataclass +class UNetKwargs: + sample: torch.Tensor + timestep: Union[torch.Tensor, float, int] + encoder_hidden_states: torch.Tensor + + class_labels: Optional[torch.Tensor] = None + timestep_cond: Optional[torch.Tensor] = None + attention_mask: Optional[torch.Tensor] = None + cross_attention_kwargs: Optional[Dict[str, Any]] = None + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None + mid_block_additional_residual: Optional[torch.Tensor] = None + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None + encoder_attention_mask: Optional[torch.Tensor] = None + # return_dict: bool = True + + +@dataclass +class DenoiseContext: + latents: torch.Tensor + scheduler_step_kwargs: dict[str, Any] + conditioning_data: TextConditioningData + noise: Optional[torch.Tensor] + seed: int + timesteps: torch.Tensor + init_timestep: torch.Tensor + + scheduler: SchedulerMixin + unet: Optional[UNet2DConditionModel] = None + + orig_latents: Optional[torch.Tensor] = None + step_index: Optional[int] = None + timestep: Optional[torch.Tensor] = None + unet_kwargs: Optional[UNetKwargs] = None + step_output: Optional[SchedulerOutput] = None + + latent_model_input: Optional[torch.Tensor] = None + conditioning_mode: Optional[str] = None + negative_noise_pred: Optional[torch.Tensor] = None + positive_noise_pred: Optional[torch.Tensor] = None + noise_pred: Optional[torch.Tensor] = None + + extra: dict = field(default_factory=dict) + + def __delattr__(self, name: str): + setattr(self, name, None) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 85950a01df..3a758839ea 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -5,6 +5,7 @@ from typing import List, Optional, Union import torch from invokeai.backend.ip_adapter.ip_adapter import IPAdapter +from 
invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData @dataclass @@ -103,7 +104,7 @@ class TextConditioningData: uncond_regions: Optional[TextConditioningRegions], cond_regions: Optional[TextConditioningRegions], guidance_scale: Union[float, List[float]], - guidance_rescale_multiplier: float = 0, + guidance_rescale_multiplier: float = 0, # TODO: old backend, remove ): self.uncond_text = uncond_text self.cond_text = cond_text @@ -114,6 +115,7 @@ class TextConditioningData: # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate # images that are closely linked to the text `prompt`, usually at the expense of lower image quality. self.guidance_scale = guidance_scale + # TODO: old backend, remove # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7. # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). self.guidance_rescale_multiplier = guidance_rescale_multiplier @@ -121,3 +123,127 @@ class TextConditioningData: def is_sdxl(self): assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo) return isinstance(self.cond_text, SDXLConditioningInfo) + + def to_unet_kwargs(self, unet_kwargs, conditioning_mode): + if conditioning_mode == "both": + encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch( + self.uncond_text.embeds, self.cond_text.embeds + ) + elif conditioning_mode == "positive": + encoder_hidden_states = self.cond_text.embeds + encoder_attention_mask = None + else: # elif conditioning_mode == "negative": + encoder_hidden_states = self.uncond_text.embeds + encoder_attention_mask = None + + unet_kwargs.encoder_hidden_states = encoder_hidden_states + unet_kwargs.encoder_attention_mask = encoder_attention_mask + + if self.is_sdxl(): + if conditioning_mode == "negative": + added_cond_kwargs = dict( # noqa: C408 + text_embeds=self.cond_text.pooled_embeds, + time_ids=self.cond_text.add_time_ids, + ) + elif conditioning_mode == "positive": + added_cond_kwargs = dict( # noqa: C408 + text_embeds=self.uncond_text.pooled_embeds, + time_ids=self.uncond_text.add_time_ids, + ) + else: # elif conditioning_mode == "both": + added_cond_kwargs = dict( # noqa: C408 + text_embeds=torch.cat( + [ + # TODO: how to pad? just by zeros? or even truncate? + self.uncond_text.pooled_embeds, + self.cond_text.pooled_embeds, + ], + ), + time_ids=torch.cat( + [ + self.uncond_text.add_time_ids, + self.cond_text.add_time_ids, + ], + ), + ) + + unet_kwargs.added_cond_kwargs = added_cond_kwargs + + if self.cond_regions is not None or self.uncond_regions is not None: + # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings + # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems + # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of + # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly + # awkward to handle both standard conditioning and sequential conditioning further up the stack. 
+ + _tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions + _, _, h, w = _tmp_regions.masks.shape + dtype = self.cond_text.embeds.dtype + device = self.cond_text.embeds.device + + regions = [] + for c, r in [ + (self.uncond_text, self.uncond_regions), + (self.cond_text, self.cond_regions), + ]: + if r is None: + # Create a dummy mask and range for text conditioning that doesn't have region masks. + r = TextConditioningRegions( + masks=torch.ones((1, 1, h, w), dtype=dtype), + ranges=[Range(start=0, end=c.embeds.shape[1])], + ) + regions.append(r) + + if unet_kwargs.cross_attention_kwargs is None: + unet_kwargs.cross_attention_kwargs = {} + + unet_kwargs.cross_attention_kwargs.update( + regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype), + ) + + def _concat_conditionings_for_batch(self, unconditioning, conditioning): + def _pad_conditioning(cond, target_len, encoder_attention_mask): + conditioning_attention_mask = torch.ones( + (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype + ) + + if cond.shape[1] < max_len: + conditioning_attention_mask = torch.cat( + [ + conditioning_attention_mask, + torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype), + ], + dim=1, + ) + + cond = torch.cat( + [ + cond, + torch.zeros( + (cond.shape[0], max_len - cond.shape[1], cond.shape[2]), + device=cond.device, + dtype=cond.dtype, + ), + ], + dim=1, + ) + + if encoder_attention_mask is None: + encoder_attention_mask = conditioning_attention_mask + else: + encoder_attention_mask = torch.cat( + [ + encoder_attention_mask, + conditioning_attention_mask, + ] + ) + + return cond, encoder_attention_mask + + encoder_attention_mask = None + if unconditioning.shape[1] != conditioning.shape[1]: + max_len = max(unconditioning.shape[1], conditioning.shape[1]) + unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask) + conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask) + + return torch.cat([unconditioning, conditioning]), encoder_attention_mask diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index f09cc0a0d2..eddd31f0c4 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -1,9 +1,14 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import torch import torch.nn.functional as F -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( - TextConditioningRegions, -) +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + TextConditioningRegions, + ) class RegionalPromptData: diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py new file mode 100644 index 0000000000..264fed2fe6 --- /dev/null +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import PIL.Image +import torch +import torchvision.transforms as T +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput +from tqdm.auto import tqdm + +from invokeai.app.services.config.config_default import get_config +from 
invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs +from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager + + +def trim_to_multiple_of(*args, multiple_of=8): + return tuple((x - x % multiple_of) for x in args) + + +def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor: + """ + + :param image: input image + :param normalize: scale the range to [-1, 1] instead of [0, 1] + :param multiple_of: resize the input so both dimensions are a multiple of this + """ + w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of) + transformation = T.Compose( + [ + T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True), + T.ToTensor(), + ] + ) + tensor = transformation(image) + if normalize: + tensor = tensor * 2.0 - 1.0 + return tensor + + +class StableDiffusionBackend: + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + ): + self.unet = unet + self.scheduler = scheduler + config = get_config() + self.sequential_guidance = config.sequential_guidance + + def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + if ctx.init_timestep.shape[0] == 0: + return ctx.latents + + ctx.orig_latents = ctx.latents.clone() + + if ctx.noise is not None: + batch_size = ctx.latents.shape[0] + # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers + ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size)) + + # if no work to do, return latents + if ctx.timesteps.shape[0] == 0: + return ctx.latents + + # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed) + # ext: preview[pre_denoise_loop, priority=low] + ext_manager.modifiers.pre_denoise_loop(ctx) + + for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020 + # ext: inpaint (apply mask to latents on non-inpaint models) + ext_manager.modifiers.pre_step(ctx) + + # ext: tiles? 
[override: step] + ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager) + + # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models) + # ext: preview[post_step, priority=low] + ext_manager.modifiers.post_step(ctx) + + ctx.latents = ctx.step_output.prev_sample + + # ext: inpaint[post_denoise_loop] (restore unmasked part) + ext_manager.modifiers.post_denoise_loop(ctx) + return ctx.latents + + @torch.inference_mode() + def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput: + ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep) + + if self.sequential_guidance: + conditioning_call = self._apply_standard_conditioning_sequentially + else: + conditioning_call = self._apply_standard_conditioning + + # not sure if here needed override + ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager) + + # ext: override combine_noise + ctx.noise_pred = ext_manager.overrides.combine_noise(self.combine_noise, ctx) + + # ext: cfg_rescale [modify_noise_prediction] + ext_manager.modifiers.modify_noise_prediction(ctx) + + # compute the previous noisy sample x_t -> x_t-1 + step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs) + + # del locals + del ctx.latent_model_input + del ctx.negative_noise_pred + del ctx.positive_noise_pred + del ctx.noise_pred + + return step_output + + @staticmethod + def combine_noise(ctx: DenoiseContext) -> torch.Tensor: + guidance_scale = ctx.conditioning_data.guidance_scale + if isinstance(guidance_scale, list): + guidance_scale = guidance_scale[ctx.step_index] + + return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) + # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) + + def _apply_standard_conditioning( + self, ctx: DenoiseContext, ext_manager: ExtensionsManager + ) -> tuple[torch.Tensor, torch.Tensor]: + """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at + the cost of higher memory usage. + """ + + ctx.unet_kwargs = UNetKwargs( + sample=torch.cat([ctx.latent_model_input] * 2), + timestep=ctx.timestep, + encoder_hidden_states=None, # set later by conditoning + cross_attention_kwargs=dict( # noqa: C408 + percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, + ), + ) + + ctx.conditioning_mode = "both" + ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) + + # ext: controlnet/ip/t2i [pre_unet_forward] + ext_manager.modifiers.pre_unet_forward(ctx) + + # ext: inpaint [pre_unet_forward, priority=low] + # or + # ext: inpaint [override: unet_forward] + both_results = self._unet_forward(**vars(ctx.unet_kwargs)) + negative_next_x, positive_next_x = both_results.chunk(2) + # del locals + del ctx.unet_kwargs + del ctx.conditioning_mode + return negative_next_x, positive_next_x + + def _apply_standard_conditioning_sequentially(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of + slower execution speed. 
+ """ + + ################### + # Negative pass + ################### + + ctx.unet_kwargs = UNetKwargs( + sample=ctx.latent_model_input, + timestep=ctx.timestep, + encoder_hidden_states=None, # set later by conditoning + cross_attention_kwargs=dict( # noqa: C408 + percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, + ), + ) + + ctx.conditioning_mode = "negative" + ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative") + + # ext: controlnet/ip/t2i [pre_unet_forward] + ext_manager.modifiers.pre_unet_forward(ctx) + + # ext: inpaint [pre_unet_forward, priority=low] + # or + # ext: inpaint [override: unet_forward] + negative_next_x = self._unet_forward(**vars(ctx.unet_kwargs)) + + del ctx.unet_kwargs + del ctx.conditioning_mode + # TODO: gc.collect() ? + + ################### + # Positive pass + ################### + + ctx.unet_kwargs = UNetKwargs( + sample=ctx.latent_model_input, + timestep=ctx.timestep, + encoder_hidden_states=None, # set later by conditoning + cross_attention_kwargs=dict( # noqa: C408 + percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, + ), + ) + + ctx.conditioning_mode = "positive" + ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive") + + # ext: controlnet/ip/t2i [pre_unet_forward] + ext_manager.modifiers.pre_unet_forward(ctx) + + # ext: inpaint [pre_unet_forward, priority=low] + # or + # ext: inpaint [override: unet_forward] + positive_next_x = self._unet_forward(**vars(ctx.unet_kwargs)) + + del ctx.unet_kwargs + del ctx.conditioning_mode + # TODO: gc.collect() ? + + return negative_next_x, positive_next_x + + def _unet_forward(self, **kwargs) -> torch.Tensor: + return self.unet(**kwargs).sample diff --git a/invokeai/backend/stable_diffusion/extensions/__init__.py b/invokeai/backend/stable_diffusion/extensions/__init__.py new file mode 100644 index 0000000000..395d9c1d9a --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/__init__.py @@ -0,0 +1,9 @@ +""" +Initialization file for the invokeai.backend.stable_diffusion.extensions package +""" + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase + +__all__ = [ + "ExtensionBase", +] diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py new file mode 100644 index 0000000000..d3414eea6f --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -0,0 +1,58 @@ +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional + +import torch +from diffusers import UNet2DConditionModel + + +@dataclass +class InjectionInfo: + type: str + name: str + order: Optional[str] + function: Callable + + +def modifier(name: str, order: str = "any"): + def _decorator(func): + func.__inj_info__ = { + "type": "modifier", + "name": name, + "order": order, + } + return func + + return _decorator + + +def override(name: str): + def _decorator(func): + func.__inj_info__ = { + "type": "override", + "name": name, + "order": None, + } + return func + + return _decorator + + +class ExtensionBase: + def __init__(self, priority: int): + self.priority = priority + self.injections: List[InjectionInfo] = [] + for func_name in dir(self): + func = getattr(self, func_name) + if not callable(func) or not hasattr(func, "__inj_info__"): + continue + + self.injections.append(InjectionInfo(**func.__inj_info__, function=func)) + + @contextmanager + def patch_attention_processor(self, attention_processor_cls: object): + yield 
None + + @contextmanager + def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + yield None diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py new file mode 100644 index 0000000000..2e6882e0ca --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from contextlib import ExitStack, contextmanager +from functools import partial +from typing import TYPE_CHECKING, Callable, Dict + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.util.devices import TorchDevice + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.stable_diffusion.extensions import ExtensionBase + + +class ExtModifiersApi(ABC): + @abstractmethod + def pre_denoise_loop(self, ctx: DenoiseContext): + pass + + @abstractmethod + def post_denoise_loop(self, ctx: DenoiseContext): + pass + + @abstractmethod + def pre_step(self, ctx: DenoiseContext): + pass + + @abstractmethod + def post_step(self, ctx: DenoiseContext): + pass + + @abstractmethod + def modify_noise_prediction(self, ctx: DenoiseContext): + pass + + @abstractmethod + def pre_unet_forward(self, ctx: DenoiseContext): + pass + + @abstractmethod + def pre_unet_load(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + pass + + +class ExtOverridesApi(ABC): + @abstractmethod + def step(self, orig_func: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager): + pass + + @abstractmethod + def combine_noise(self, orig_func: Callable, ctx: DenoiseContext): + pass + + +class ProxyCallsClass: + def __init__(self, handler): + self._handler = handler + + def __getattr__(self, item): + return partial(self._handler, item) + + +class ModifierInjectionPoint: + def __init__(self): + self.first = [] + self.any = [] + self.last = [] + + def add(self, func: Callable, order: str): + if order == "first": + self.first.append(func) + elif order == "last": + self.last.append(func) + else: # elif order == "any": + self.any.append(func) + + def __call__(self, *args, **kwargs): + for func in self.first: + func(*args, **kwargs) + for func in self.any: + func(*args, **kwargs) + for func in reversed(self.last): + func(*args, **kwargs) + + +class ExtensionsManager: + def __init__(self): + self.extensions = [] + + self._overrides = {} + self._modifiers = {} + + self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier) + self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override) + + def add_extension(self, ext: ExtensionBase): + self.extensions.append(ext) + ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority) + + self._overrides.clear() + self._modifiers.clear() + + for ext in ordered_extensions: + for inj_info in ext.injections: + if inj_info.type == "modifier": + if inj_info.name not in self._modifiers: + self._modifiers[inj_info.name] = ModifierInjectionPoint() + self._modifiers[inj_info.name].add(inj_info.function, inj_info.order) + + else: + if inj_info.name in self._overrides: + raise Exception(f"Already overloaded - {inj_info.name}") + self._overrides[inj_info.name] = inj_info.function + + def call_modifier(self, name: str, *args, **kwargs): + if name in self._modifiers: + self._modifiers[name](*args, **kwargs) + + def call_override(self, name: str, orig_func: Callable, *args, **kwargs): + if name in 
self._overrides: + return self._overrides[name](orig_func, *args, **kwargs) + else: + return orig_func(*args, **kwargs) + + # TODO: is there any need in such high abstarction + # @contextmanager + # def patch_extensions(self): + # exit_stack = ExitStack() + # try: + # for ext in self.extensions: + # exit_stack.enter_context(ext.patch_extension(self)) + # + # yield None + # + # finally: + # exit_stack.close() + + @contextmanager + def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object): + unet_orig_processors = unet.attn_processors + exit_stack = ExitStack() + try: + # just to be sure that attentions have not same processor instance + attn_procs = {} + for name in unet.attn_processors.keys(): + attn_procs[name] = attn_processor_cls() + unet.set_attn_processor(attn_procs) + + for ext in self.extensions: + exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls)) + + yield None + + finally: + unet.set_attn_processor(unet_orig_processors) + exit_stack.close() + + @contextmanager + def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + exit_stack = ExitStack() + try: + changed_keys = set() + changed_unknown_keys = {} + + ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority) + for ext in ordered_extensions: + patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet)) + if patch_result is None: + continue + new_keys, new_unk_keys = patch_result + changed_keys.update(new_keys) + # skip already seen keys, as new weight might be changed + for k, v in new_unk_keys.items(): + if k in changed_unknown_keys: + continue + changed_unknown_keys[k] = v + + yield None + + finally: + exit_stack.close() + assert hasattr(unet, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() + with torch.no_grad(): + for module_key in changed_keys: + weight = state_dict[module_key] + unet.get_submodule(module_key).weight.copy_( + weight, non_blocking=TorchDevice.get_non_blocking(weight.device) + ) + for module_key, weight in changed_unknown_keys.items(): + unet.get_submodule(module_key).weight.copy_( + weight, non_blocking=TorchDevice.get_non_blocking(weight.device) + ) From 0bc60378d36595bb97a96d44ce716d2cae89fc8e Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 12 Jul 2024 20:43:32 +0300 Subject: [PATCH 02/25] A bit rework conditioning convert to unet kwargs --- .../diffusion/conditioning_data.py | 120 ++++++------------ 1 file changed, 40 insertions(+), 80 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 3a758839ea..f5b02889e9 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -125,125 +125,85 @@ class TextConditioningData: return isinstance(self.cond_text, SDXLConditioningInfo) def to_unet_kwargs(self, unet_kwargs, conditioning_mode): + _, _, h, w = unet_kwargs.sample.shape + device = unet_kwargs.sample.device + dtype = unet_kwargs.sample.dtype + if conditioning_mode == "both": - encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch( - self.uncond_text.embeds, self.cond_text.embeds - ) + conditionings = [self.uncond_text.embeds, self.cond_text.embeds] + c_regions = [self.uncond_regions, self.cond_regions] elif conditioning_mode == "positive": - encoder_hidden_states = self.cond_text.embeds - 
encoder_attention_mask = None - else: # elif conditioning_mode == "negative": - encoder_hidden_states = self.uncond_text.embeds - encoder_attention_mask = None + conditionings = [self.cond_text.embeds] + c_regions = [self.cond_regions] + else: + conditionings = [self.uncond_text.embeds] + c_regions = [self.uncond_regions] + + encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(conditionings) unet_kwargs.encoder_hidden_states = encoder_hidden_states unet_kwargs.encoder_attention_mask = encoder_attention_mask if self.is_sdxl(): - if conditioning_mode == "negative": - added_cond_kwargs = dict( # noqa: C408 - text_embeds=self.cond_text.pooled_embeds, - time_ids=self.cond_text.add_time_ids, - ) - elif conditioning_mode == "positive": - added_cond_kwargs = dict( # noqa: C408 - text_embeds=self.uncond_text.pooled_embeds, - time_ids=self.uncond_text.add_time_ids, - ) - else: # elif conditioning_mode == "both": - added_cond_kwargs = dict( # noqa: C408 - text_embeds=torch.cat( - [ - # TODO: how to pad? just by zeros? or even truncate? - self.uncond_text.pooled_embeds, - self.cond_text.pooled_embeds, - ], - ), - time_ids=torch.cat( - [ - self.uncond_text.add_time_ids, - self.cond_text.add_time_ids, - ], - ), - ) + added_cond_kwargs = dict( # noqa: C408 + text_embeds=torch.cat([c.pooled_embeds for c in conditionings]), + time_ids=torch.cat([c.add_time_ids for c in conditionings]), + ) unet_kwargs.added_cond_kwargs = added_cond_kwargs - if self.cond_regions is not None or self.uncond_regions is not None: - # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings - # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems - # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of - # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly - # awkward to handle both standard conditioning and sequential conditioning further up the stack. - - _tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions - _, _, h, w = _tmp_regions.masks.shape - dtype = self.cond_text.embeds.dtype - device = self.cond_text.embeds.device - - regions = [] - for c, r in [ - (self.uncond_text, self.uncond_regions), - (self.cond_text, self.cond_regions), - ]: + if any(r is not None for r in c_regions): + tmp_regions = [] + for c, r in zip(conditionings, c_regions, strict=True): if r is None: - # Create a dummy mask and range for text conditioning that doesn't have region masks. 
r = TextConditioningRegions( masks=torch.ones((1, 1, h, w), dtype=dtype), ranges=[Range(start=0, end=c.embeds.shape[1])], ) - regions.append(r) + tmp_regions.append(r) if unet_kwargs.cross_attention_kwargs is None: unet_kwargs.cross_attention_kwargs = {} unet_kwargs.cross_attention_kwargs.update( - regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype), + regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype), ) - def _concat_conditionings_for_batch(self, unconditioning, conditioning): + def _concat_conditionings_for_batch(self, conditionings): + def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int): + return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim) + def _pad_conditioning(cond, target_len, encoder_attention_mask): conditioning_attention_mask = torch.ones( (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype ) if cond.shape[1] < max_len: - conditioning_attention_mask = torch.cat( - [ - conditioning_attention_mask, - torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype), - ], + conditioning_attention_mask = _pad_zeros( + conditioning_attention_mask, + pad_shape=(cond.shape[0], max_len - cond.shape[1]), dim=1, ) - cond = torch.cat( - [ - cond, - torch.zeros( - (cond.shape[0], max_len - cond.shape[1], cond.shape[2]), - device=cond.device, - dtype=cond.dtype, - ), - ], + cond = _pad_zeros( + cond, + pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]), dim=1, ) if encoder_attention_mask is None: encoder_attention_mask = conditioning_attention_mask else: - encoder_attention_mask = torch.cat( - [ - encoder_attention_mask, - conditioning_attention_mask, - ] - ) + encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask]) return cond, encoder_attention_mask encoder_attention_mask = None - if unconditioning.shape[1] != conditioning.shape[1]: - max_len = max(unconditioning.shape[1], conditioning.shape[1]) - unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask) - conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask) + max_len = max([c.shape[1] for c in conditionings]) + if any(c.shape[1] != max_len for c in conditionings): + for i in range(len(conditionings)): + conditionings[i], encoder_attention_mask = _pad_conditioning( + conditionings[i], max_len, encoder_attention_mask + ) - return torch.cat([unconditioning, conditioning]), encoder_attention_mask + return torch.cat(conditionings), encoder_attention_mask From 87e96e1be2d3cb9dee1f08c5b254b3089637b555 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 12 Jul 2024 22:01:05 +0300 Subject: [PATCH 03/25] Rename modifiers to callbacks, convert order to int, a bit unify injection points --- invokeai/app/invocations/denoise_latents.py | 2 +- .../stable_diffusion/diffusion_backend.py | 48 ++++---------- .../stable_diffusion/extensions/base.py | 6 +- .../stable_diffusion/extensions_manager.py | 62 ++++++++----------- 4 files changed, 42 insertions(+), 76 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index bec8741936..1bc66e423f 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -781,7 +781,7 @@ class DenoiseLatentsInvocation(BaseInvocation): unet_config = context.models.get_config(self.unet.unet.key) # ext: t2i/ip 
adapter - ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager) + ext_manager.callbacks.pre_unet_load(denoise_ctx, ext_manager) unet_info = context.models.load(self.unet.unet) assert isinstance(unet_info.model, UNet2DConditionModel) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 264fed2fe6..4630d4740d 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -1,8 +1,6 @@ from __future__ import annotations -import PIL.Image import torch -import torchvision.transforms as T from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput from tqdm.auto import tqdm @@ -12,30 +10,6 @@ from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UN from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager -def trim_to_multiple_of(*args, multiple_of=8): - return tuple((x - x % multiple_of) for x in args) - - -def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor: - """ - - :param image: input image - :param normalize: scale the range to [-1, 1] instead of [0, 1] - :param multiple_of: resize the input so both dimensions are a multiple of this - """ - w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of) - transformation = T.Compose( - [ - T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True), - T.ToTensor(), - ] - ) - tensor = transformation(image) - if normalize: - tensor = tensor * 2.0 - 1.0 - return tensor - - class StableDiffusionBackend: def __init__( self, @@ -64,23 +38,23 @@ class StableDiffusionBackend: # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed) # ext: preview[pre_denoise_loop, priority=low] - ext_manager.modifiers.pre_denoise_loop(ctx) + ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager) for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020 # ext: inpaint (apply mask to latents on non-inpaint models) - ext_manager.modifiers.pre_step(ctx) + ext_manager.callbacks.pre_step(ctx, ext_manager) # ext: tiles? 
[override: step] ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager) # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models) # ext: preview[post_step, priority=low] - ext_manager.modifiers.post_step(ctx) + ext_manager.callbacks.post_step(ctx, ext_manager) ctx.latents = ctx.step_output.prev_sample # ext: inpaint[post_denoise_loop] (restore unmasked part) - ext_manager.modifiers.post_denoise_loop(ctx) + ext_manager.callbacks.post_denoise_loop(ctx, ext_manager) return ctx.latents @torch.inference_mode() @@ -95,11 +69,11 @@ class StableDiffusionBackend: # not sure if here needed override ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager) - # ext: override combine_noise - ctx.noise_pred = ext_manager.overrides.combine_noise(self.combine_noise, ctx) + # ext: override apply_cfg + ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx) # ext: cfg_rescale [modify_noise_prediction] - ext_manager.modifiers.modify_noise_prediction(ctx) + ext_manager.callbacks.modify_noise_prediction(ctx, ext_manager) # compute the previous noisy sample x_t -> x_t-1 step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs) @@ -113,7 +87,7 @@ class StableDiffusionBackend: return step_output @staticmethod - def combine_noise(ctx: DenoiseContext) -> torch.Tensor: + def apply_cfg(ctx: DenoiseContext) -> torch.Tensor: guidance_scale = ctx.conditioning_data.guidance_scale if isinstance(guidance_scale, list): guidance_scale = guidance_scale[ctx.step_index] @@ -141,7 +115,7 @@ class StableDiffusionBackend: ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.modifiers.pre_unet_forward(ctx) + ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) # ext: inpaint [pre_unet_forward, priority=low] # or @@ -175,7 +149,7 @@ class StableDiffusionBackend: ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative") # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.modifiers.pre_unet_forward(ctx) + ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) # ext: inpaint [pre_unet_forward, priority=low] # or @@ -203,7 +177,7 @@ class StableDiffusionBackend: ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive") # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.modifiers.pre_unet_forward(ctx) + ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) # ext: inpaint [pre_unet_forward, priority=low] # or diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index d3414eea6f..79227921c3 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -10,14 +10,14 @@ from diffusers import UNet2DConditionModel class InjectionInfo: type: str name: str - order: Optional[str] + order: Optional[int] function: Callable -def modifier(name: str, order: str = "any"): +def callback(name: str, order: int = 0): def _decorator(func): func.__inj_info__ = { - "type": "modifier", + "type": "callback", "name": name, "order": order, } diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 2e6882e0ca..1d4892a982 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -15,29 +15,29 @@ if TYPE_CHECKING: from 
invokeai.backend.stable_diffusion.extensions import ExtensionBase -class ExtModifiersApi(ABC): +class ExtCallbacksApi(ABC): @abstractmethod - def pre_denoise_loop(self, ctx: DenoiseContext): + def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def post_denoise_loop(self, ctx: DenoiseContext): + def post_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def pre_step(self, ctx: DenoiseContext): + def pre_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def post_step(self, ctx: DenoiseContext): + def post_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def modify_noise_prediction(self, ctx: DenoiseContext): + def modify_noise_prediction(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def pre_unet_forward(self, ctx: DenoiseContext): + def pre_unet_forward(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod @@ -51,7 +51,7 @@ class ExtOverridesApi(ABC): pass @abstractmethod - def combine_noise(self, orig_func: Callable, ctx: DenoiseContext): + def apply_cfg(self, orig_func: Callable, ctx: DenoiseContext): pass @@ -63,27 +63,19 @@ class ProxyCallsClass: return partial(self._handler, item) -class ModifierInjectionPoint: +class CallbackInjectionPoint: def __init__(self): - self.first = [] - self.any = [] - self.last = [] + self.handlers = {} - def add(self, func: Callable, order: str): - if order == "first": - self.first.append(func) - elif order == "last": - self.last.append(func) - else: # elif order == "any": - self.any.append(func) + def add(self, func: Callable, order: int): + if order not in self.handlers: + self.handlers[order] = [] + self.handlers[order].append(func) def __call__(self, *args, **kwargs): - for func in self.first: - func(*args, **kwargs) - for func in self.any: - func(*args, **kwargs) - for func in reversed(self.last): - func(*args, **kwargs) + for order in sorted(self.handlers.keys(), reverse=True): + for handler in self.handlers[order]: + handler(*args, **kwargs) class ExtensionsManager: @@ -91,9 +83,9 @@ class ExtensionsManager: self.extensions = [] self._overrides = {} - self._modifiers = {} + self._callbacks = {} - self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier) + self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback) self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override) def add_extension(self, ext: ExtensionBase): @@ -101,23 +93,23 @@ class ExtensionsManager: ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority) self._overrides.clear() - self._modifiers.clear() + self._callbacks.clear() for ext in ordered_extensions: for inj_info in ext.injections: - if inj_info.type == "modifier": - if inj_info.name not in self._modifiers: - self._modifiers[inj_info.name] = ModifierInjectionPoint() - self._modifiers[inj_info.name].add(inj_info.function, inj_info.order) + if inj_info.type == "callback": + if inj_info.name not in self._callbacks: + self._callbacks[inj_info.name] = CallbackInjectionPoint() + self._callbacks[inj_info.name].add(inj_info.function, inj_info.order) else: if inj_info.name in self._overrides: raise Exception(f"Already overloaded - {inj_info.name}") self._overrides[inj_info.name] = inj_info.function - def call_modifier(self, name: str, *args, **kwargs): - if name in self._modifiers: - self._modifiers[name](*args, **kwargs) + def 
call_callback(self, name: str, *args, **kwargs): + if name in self._callbacks: + self._callbacks[name](*args, **kwargs) def call_override(self, name: str, orig_func: Callable, *args, **kwargs): if name in self._overrides: From bd8ae5d896eecdd6bef6fe8101ae7cb650e2b4f7 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 12 Jul 2024 22:01:37 +0300 Subject: [PATCH 04/25] Simplify guidance modes --- .../stable_diffusion/diffusion_backend.py | 92 +++---------------- 1 file changed, 14 insertions(+), 78 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 4630d4740d..561624609b 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -61,18 +61,19 @@ class StableDiffusionBackend: def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput: ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep) + # TODO: conditionings as list if self.sequential_guidance: - conditioning_call = self._apply_standard_conditioning_sequentially + ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative") + ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive") else: - conditioning_call = self._apply_standard_conditioning - - # not sure if here needed override - ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager) + both_noise_pred = self.run_unet(ctx, ext_manager, "both") + ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2) # ext: override apply_cfg ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx) # ext: cfg_rescale [modify_noise_prediction] + # TODO: rename ext_manager.callbacks.modify_noise_prediction(ctx, ext_manager) # compute the previous noisy sample x_t -> x_t-1 @@ -95,15 +96,13 @@ class StableDiffusionBackend: return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) - def _apply_standard_conditioning( - self, ctx: DenoiseContext, ext_manager: ExtensionsManager - ) -> tuple[torch.Tensor, torch.Tensor]: - """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at - the cost of higher memory usage. 
- """ + def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: str): + sample = ctx.latent_model_input + if conditioning_mode == "both": + sample = torch.cat([sample] * 2) ctx.unet_kwargs = UNetKwargs( - sample=torch.cat([ctx.latent_model_input] * 2), + sample=sample, timestep=ctx.timestep, encoder_hidden_states=None, # set later by conditoning cross_attention_kwargs=dict( # noqa: C408 @@ -111,7 +110,7 @@ class StableDiffusionBackend: ), ) - ctx.conditioning_mode = "both" + ctx.conditioning_mode = conditioning_mode ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) # ext: controlnet/ip/t2i [pre_unet_forward] @@ -120,75 +119,12 @@ class StableDiffusionBackend: # ext: inpaint [pre_unet_forward, priority=low] # or # ext: inpaint [override: unet_forward] - both_results = self._unet_forward(**vars(ctx.unet_kwargs)) - negative_next_x, positive_next_x = both_results.chunk(2) - # del locals - del ctx.unet_kwargs - del ctx.conditioning_mode - return negative_next_x, positive_next_x - - def _apply_standard_conditioning_sequentially(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of - slower execution speed. - """ - - ################### - # Negative pass - ################### - - ctx.unet_kwargs = UNetKwargs( - sample=ctx.latent_model_input, - timestep=ctx.timestep, - encoder_hidden_states=None, # set later by conditoning - cross_attention_kwargs=dict( # noqa: C408 - percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, - ), - ) - - ctx.conditioning_mode = "negative" - ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative") - - # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) - - # ext: inpaint [pre_unet_forward, priority=low] - # or - # ext: inpaint [override: unet_forward] - negative_next_x = self._unet_forward(**vars(ctx.unet_kwargs)) + noise_pred = self._unet_forward(**vars(ctx.unet_kwargs)) del ctx.unet_kwargs del ctx.conditioning_mode - # TODO: gc.collect() ? - ################### - # Positive pass - ################### - - ctx.unet_kwargs = UNetKwargs( - sample=ctx.latent_model_input, - timestep=ctx.timestep, - encoder_hidden_states=None, # set later by conditoning - cross_attention_kwargs=dict( # noqa: C408 - percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, - ), - ) - - ctx.conditioning_mode = "positive" - ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive") - - # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) - - # ext: inpaint [pre_unet_forward, priority=low] - # or - # ext: inpaint [override: unet_forward] - positive_next_x = self._unet_forward(**vars(ctx.unet_kwargs)) - - del ctx.unet_kwargs - del ctx.conditioning_mode - # TODO: gc.collect() ? 
- - return negative_next_x, positive_next_x + return noise_pred def _unet_forward(self, **kwargs) -> torch.Tensor: return self.unet(**kwargs).sample From 3a9dda9177946b648b2aca1747635ba014a5ab12 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 12 Jul 2024 22:44:00 +0300 Subject: [PATCH 05/25] Renames --- invokeai/app/invocations/denoise_latents.py | 2 +- .../diffusion/conditioning_data.py | 1 + .../backend/stable_diffusion/diffusion_backend.py | 15 ++++++++++----- .../stable_diffusion/extensions_manager.py | 10 +++++++--- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 1bc66e423f..beced2a283 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -781,7 +781,7 @@ class DenoiseLatentsInvocation(BaseInvocation): unet_config = context.models.get_config(self.unet.unet.key) # ext: t2i/ip adapter - ext_manager.callbacks.pre_unet_load(denoise_ctx, ext_manager) + ext_manager.callbacks.setup(denoise_ctx, ext_manager) unet_info = context.models.load(self.unet.unet) assert isinstance(unet_info.model, UNet2DConditionModel) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index f5b02889e9..802aec0109 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -129,6 +129,7 @@ class TextConditioningData: device = unet_kwargs.sample.device dtype = unet_kwargs.sample.dtype + # TODO: combine regions with conditionings if conditioning_mode == "both": conditionings = [self.uncond_text.embeds, self.cond_text.embeds] c_regions = [self.uncond_regions, self.cond_regions] diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 561624609b..5f02ab93fc 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -61,7 +61,10 @@ class StableDiffusionBackend: def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput: ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep) - # TODO: conditionings as list + # TODO: conditionings as list(conditioning_data.to_unet_kwargs - ready) + # Note: The current handling of conditioning doesn't feel very future-proof. + # This might change in the future as new requirements come up, but for now, + # this is the rough plan. 
if self.sequential_guidance: ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative") ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive") @@ -74,7 +77,7 @@ class StableDiffusionBackend: # ext: cfg_rescale [modify_noise_prediction] # TODO: rename - ext_manager.callbacks.modify_noise_prediction(ctx, ext_manager) + ext_manager.callbacks.post_apply_cfg(ctx, ext_manager) # compute the previous noisy sample x_t -> x_t-1 step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs) @@ -113,14 +116,16 @@ class StableDiffusionBackend: ctx.conditioning_mode = conditioning_mode ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) - # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) + # ext: controlnet/ip/t2i [pre_unet] + ext_manager.callbacks.pre_unet(ctx, ext_manager) - # ext: inpaint [pre_unet_forward, priority=low] + # ext: inpaint [pre_unet, priority=low] # or # ext: inpaint [override: unet_forward] noise_pred = self._unet_forward(**vars(ctx.unet_kwargs)) + ext_manager.callbacks.post_unet(ctx, ext_manager) + del ctx.unet_kwargs del ctx.conditioning_mode diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 1d4892a982..ef88dec1a3 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -16,6 +16,10 @@ if TYPE_CHECKING: class ExtCallbacksApi(ABC): + @abstractmethod + def setup(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + pass + @abstractmethod def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @@ -33,15 +37,15 @@ class ExtCallbacksApi(ABC): pass @abstractmethod - def modify_noise_prediction(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + def pre_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def pre_unet_forward(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + def post_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def pre_unet_load(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + def post_apply_cfg(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass From 7e00526999d65c8ea35f0563423c87f43279b3b8 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 13 Jul 2024 00:28:56 +0300 Subject: [PATCH 06/25] Remove overrides logic for now --- .../stable_diffusion/diffusion_backend.py | 4 ++-- .../stable_diffusion/extensions/base.py | 12 ---------- .../stable_diffusion/extensions_manager.py | 24 +------------------ 3 files changed, 3 insertions(+), 37 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 5f02ab93fc..81fe00d59d 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -45,7 +45,7 @@ class StableDiffusionBackend: ext_manager.callbacks.pre_step(ctx, ext_manager) # ext: tiles? 
[override: step] - ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager) + ctx.step_output = self.step(ctx, ext_manager) # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models) # ext: preview[post_step, priority=low] @@ -73,7 +73,7 @@ class StableDiffusionBackend: ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2) # ext: override apply_cfg - ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx) + ctx.noise_pred = self.apply_cfg(ctx) # ext: cfg_rescale [modify_noise_prediction] # TODO: rename diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 79227921c3..1617fcf1c1 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -26,18 +26,6 @@ def callback(name: str, order: int = 0): return _decorator -def override(name: str): - def _decorator(func): - func.__inj_info__ = { - "type": "override", - "name": name, - "order": None, - } - return func - - return _decorator - - class ExtensionBase: def __init__(self, priority: int): self.priority = priority diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index ef88dec1a3..26e7198fa4 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -49,16 +49,6 @@ class ExtCallbacksApi(ABC): pass -class ExtOverridesApi(ABC): - @abstractmethod - def step(self, orig_func: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager): - pass - - @abstractmethod - def apply_cfg(self, orig_func: Callable, ctx: DenoiseContext): - pass - - class ProxyCallsClass: def __init__(self, handler): self._handler = handler @@ -86,17 +76,13 @@ class ExtensionsManager: def __init__(self): self.extensions = [] - self._overrides = {} self._callbacks = {} - self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback) - self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override) def add_extension(self, ext: ExtensionBase): self.extensions.append(ext) ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority) - self._overrides.clear() self._callbacks.clear() for ext in ordered_extensions: @@ -107,20 +93,12 @@ class ExtensionsManager: self._callbacks[inj_info.name].add(inj_info.function, inj_info.order) else: - if inj_info.name in self._overrides: - raise Exception(f"Already overloaded - {inj_info.name}") - self._overrides[inj_info.name] = inj_info.function + raise Exception(f"Unsupported injection type: {inj_info.type}") def call_callback(self, name: str, *args, **kwargs): if name in self._callbacks: self._callbacks[name](*args, **kwargs) - def call_override(self, name: str, orig_func: Callable, *args, **kwargs): - if name in self._overrides: - return self._overrides[name](orig_func, *args, **kwargs) - else: - return orig_func(*args, **kwargs) - # TODO: is there any need in such high abstarction # @contextmanager # def patch_extensions(self): From e961dd1decfa1b96f4bbad08a233111d17458b67 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 13 Jul 2024 00:44:21 +0300 Subject: [PATCH 07/25] Remove remains of priority logic --- invokeai/backend/stable_diffusion/extensions/base.py | 3 +-- invokeai/backend/stable_diffusion/extensions_manager.py | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git 
a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 1617fcf1c1..3effa77da4 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -27,8 +27,7 @@ def callback(name: str, order: int = 0): class ExtensionBase: - def __init__(self, priority: int): - self.priority = priority + def __init__(self): self.injections: List[InjectionInfo] = [] for func_name in dir(self): func = getattr(self, func_name) diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 26e7198fa4..876fd96d39 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -81,11 +81,10 @@ class ExtensionsManager: def add_extension(self, ext: ExtensionBase): self.extensions.append(ext) - ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority) self._callbacks.clear() - for ext in ordered_extensions: + for ext in self.extensions: for inj_info in ext.injections: if inj_info.type == "callback": if inj_info.name not in self._callbacks: @@ -139,8 +138,7 @@ class ExtensionsManager: changed_keys = set() changed_unknown_keys = {} - ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority) - for ext in ordered_extensions: + for ext in self.extensions: patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet)) if patch_result is None: continue From 499e4d4fded022eaf2f238192c88f2cc165077c3 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 13 Jul 2024 00:45:04 +0300 Subject: [PATCH 08/25] Add preview extension to check logic --- invokeai/app/invocations/denoise_latents.py | 7 +++ .../stable_diffusion/diffusers_pipeline.py | 18 +++--- .../stable_diffusion/extensions/__init__.py | 3 + .../stable_diffusion/extensions/preview.py | 63 +++++++++++++++++++ 4 files changed, 82 insertions(+), 9 deletions(-) create mode 100644 invokeai/backend/stable_diffusion/extensions/preview.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index beced2a283..c0a74756cb 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -57,6 +57,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ) from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend +from invokeai.backend.stable_diffusion.extensions import PreviewExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES @@ -777,6 +778,12 @@ class DenoiseLatentsInvocation(BaseInvocation): scheduler=scheduler, ) + ### preview + def step_callback(state: PipelineIntermediateState) -> None: + context.util.sd_step_callback(state, unet_config.base) + + ext_manager.add_extension(PreviewExt(step_callback)) + # get the unet's config so that we can pass the base to sd_step_callback() unet_config = context.models.get_config(self.unet.unet.key) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index ee464f73e1..216e4d3bd1 100644 --- 
a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -23,19 +23,19 @@ from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData +from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState from invokeai.backend.util.attention import auto_detect_slice_size from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.hotfixes import ControlNetModel - -@dataclass -class PipelineIntermediateState: - step: int - order: int - total_steps: int - timestep: int - latents: torch.Tensor - predicted_original: Optional[torch.Tensor] = None +# @dataclass +# class PipelineIntermediateState: +# step: int +# order: int +# total_steps: int +# timestep: int +# latents: torch.Tensor +# predicted_original: Optional[torch.Tensor] = None @dataclass diff --git a/invokeai/backend/stable_diffusion/extensions/__init__.py b/invokeai/backend/stable_diffusion/extensions/__init__.py index 395d9c1d9a..faf0a1e1ec 100644 --- a/invokeai/backend/stable_diffusion/extensions/__init__.py +++ b/invokeai/backend/stable_diffusion/extensions/__init__.py @@ -3,7 +3,10 @@ Initialization file for the invokeai.backend.stable_diffusion.extensions package """ from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase +from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState, PreviewExt __all__ = [ "ExtensionBase", + "PipelineIntermediateState", + "PreviewExt", ] diff --git a/invokeai/backend/stable_diffusion/extensions/preview.py b/invokeai/backend/stable_diffusion/extensions/preview.py new file mode 100644 index 0000000000..73a1eef3c5 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/preview.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Optional + +import torch + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager + + +# TODO: change event to accept image instead of latents +@dataclass +class PipelineIntermediateState: + step: int + order: int + total_steps: int + timestep: int + latents: torch.Tensor + predicted_original: Optional[torch.Tensor] = None + + +class PreviewExt(ExtensionBase): + def __init__(self, callback: Callable[[PipelineIntermediateState], None]): + super().__init__() + self.callback = callback + + # do last so that all other changes shown + @callback("pre_denoise_loop", order=1000) + def initial_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + self.callback( + PipelineIntermediateState( + step=-1, + order=ctx.scheduler.order, + total_steps=len(ctx.timesteps), + timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it? 
+ latents=ctx.latents, + ) + ) + + # do last so that all other changes shown + @callback("post_step", order=1000) + def step_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + if hasattr(ctx.step_output, "denoised"): + predicted_original = ctx.step_output.denoised + elif hasattr(ctx.step_output, "pred_original_sample"): + predicted_original = ctx.step_output.pred_original_sample + else: + predicted_original = ctx.step_output.prev_sample + + self.callback( + PipelineIntermediateState( + step=ctx.step_index, + order=ctx.scheduler.order, + total_steps=len(ctx.timesteps), + timestep=int(ctx.timestep), # TODO: is there any code which uses it? + latents=ctx.step_output.prev_sample, + predicted_original=predicted_original, # TODO: is there any reason for additional field? + ) + ) From d623bd429b04d3c4b72cd93061bdc92fb285212b Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 16 Jul 2024 00:31:56 +0300 Subject: [PATCH 09/25] Fix condtionings logic --- .../stable_diffusion/diffusion/conditioning_data.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 802aec0109..21fb8d5780 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -131,16 +131,18 @@ class TextConditioningData: # TODO: combine regions with conditionings if conditioning_mode == "both": - conditionings = [self.uncond_text.embeds, self.cond_text.embeds] + conditionings = [self.uncond_text, self.cond_text] c_regions = [self.uncond_regions, self.cond_regions] elif conditioning_mode == "positive": - conditionings = [self.cond_text.embeds] + conditionings = [self.cond_text] c_regions = [self.cond_regions] else: - conditionings = [self.uncond_text.embeds] + conditionings = [self.uncond_text] c_regions = [self.uncond_regions] - encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(conditionings) + encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch( + [c.embeds for c in conditionings] + ) unet_kwargs.encoder_hidden_states = encoder_hidden_states unet_kwargs.encoder_attention_mask = encoder_attention_mask From fd8d1c12d4e8f40f75b3d595c2e0fdfa8f8cca03 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 16 Jul 2024 00:43:32 +0300 Subject: [PATCH 10/25] Remove 'del' operator overload --- .../backend/stable_diffusion/denoise_context.py | 3 --- .../backend/stable_diffusion/diffusion_backend.py | 15 ++++++++------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py index b56f095948..453398a121 100644 --- a/invokeai/backend/stable_diffusion/denoise_context.py +++ b/invokeai/backend/stable_diffusion/denoise_context.py @@ -55,6 +55,3 @@ class DenoiseContext: noise_pred: Optional[torch.Tensor] = None extra: dict = field(default_factory=dict) - - def __delattr__(self, name: str): - setattr(self, name, None) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 81fe00d59d..4c08639ddf 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -82,11 +82,11 @@ class StableDiffusionBackend: # compute the previous noisy sample x_t -> x_t-1 
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs) - # del locals - del ctx.latent_model_input - del ctx.negative_noise_pred - del ctx.positive_noise_pred - del ctx.noise_pred + # clean up locals + ctx.latent_model_input = None + ctx.negative_noise_pred = None + ctx.positive_noise_pred = None + ctx.noise_pred = None return step_output @@ -126,8 +126,9 @@ class StableDiffusionBackend: ext_manager.callbacks.post_unet(ctx, ext_manager) - del ctx.unet_kwargs - del ctx.conditioning_mode + # clean up locals + ctx.unet_kwargs = None + ctx.conditioning_mode = None return noise_pred From 9f088d1bf5bff0d53a584c1580d43a882f8cfe24 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 16 Jul 2024 00:51:25 +0300 Subject: [PATCH 11/25] Multiple small fixes --- invokeai/app/invocations/denoise_latents.py | 7 ++++--- .../backend/stable_diffusion/diffusers_pipeline.py | 11 +---------- .../backend/stable_diffusion/extensions/__init__.py | 12 ------------ 3 files changed, 5 insertions(+), 25 deletions(-) delete mode 100644 invokeai/backend/stable_diffusion/extensions/__init__.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index c0a74756cb..7563c30223 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -57,7 +57,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ) from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend -from invokeai.backend.stable_diffusion.extensions import PreviewExt +from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES @@ -723,7 +723,8 @@ class DenoiseLatentsInvocation(BaseInvocation): @torch.no_grad() @SilenceWarnings() # This quenches the NSFW nag from diffusers. def _new_invoke(self, context: InvocationContext) -> LatentsOutput: - with ExitStack() as exit_stack: + # TODO: remove supression when extensions which use models added + with ExitStack() as exit_stack: # noqa: F841 ext_manager = ExtensionsManager() device = TorchDevice.choose_torch_device() @@ -804,7 +805,7 @@ class DenoiseLatentsInvocation(BaseInvocation): result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - result_latents = result_latents.to("cpu") # TODO: detach? 
+ result_latents = result_latents.detach().to("cpu") TorchDevice.empty_cache() name = context.tensors.save(tensor=result_latents) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 216e4d3bd1..b3a668518b 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -23,20 +23,11 @@ from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData -from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState +from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState from invokeai.backend.util.attention import auto_detect_slice_size from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.hotfixes import ControlNetModel -# @dataclass -# class PipelineIntermediateState: -# step: int -# order: int -# total_steps: int -# timestep: int -# latents: torch.Tensor -# predicted_original: Optional[torch.Tensor] = None - @dataclass class AddsMaskGuidance: diff --git a/invokeai/backend/stable_diffusion/extensions/__init__.py b/invokeai/backend/stable_diffusion/extensions/__init__.py deleted file mode 100644 index faf0a1e1ec..0000000000 --- a/invokeai/backend/stable_diffusion/extensions/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Initialization file for the invokeai.backend.stable_diffusion.extensions package -""" - -from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase -from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState, PreviewExt - -__all__ = [ - "ExtensionBase", - "PipelineIntermediateState", - "PreviewExt", -] From 608cbe3f5c4f5efe2ed507dfc3a81d57eaaa0423 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 16 Jul 2024 19:30:29 +0300 Subject: [PATCH 12/25] Separate inputs in denoise context --- invokeai/app/invocations/denoise_latents.py | 18 +++++++------- .../stable_diffusion/denoise_context.py | 11 ++++++--- .../stable_diffusion/diffusion_backend.py | 24 ++++++++++--------- .../stable_diffusion/extensions/preview.py | 4 ++-- 4 files changed, 33 insertions(+), 24 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 7563c30223..81b92d4fa7 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -40,7 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless -from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs from invokeai.backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, StableDiffusionGeneratorPipeline, @@ -768,13 +768,15 @@ class DenoiseLatentsInvocation(BaseInvocation): ) denoise_ctx = DenoiseContext( - latents=latents, - timesteps=timesteps, - init_timestep=init_timestep, - noise=noise, - seed=seed, - 
scheduler_step_kwargs=scheduler_step_kwargs, - conditioning_data=conditioning_data, + inputs=DenoiseInputs( + orig_latents=latents, + timesteps=timesteps, + init_timestep=init_timestep, + noise=noise, + seed=seed, + scheduler_step_kwargs=scheduler_step_kwargs, + conditioning_data=conditioning_data, + ), unet=None, scheduler=scheduler, ) diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py index 453398a121..2a00052fd1 100644 --- a/invokeai/backend/stable_diffusion/denoise_context.py +++ b/invokeai/backend/stable_diffusion/denoise_context.py @@ -30,8 +30,8 @@ class UNetKwargs: @dataclass -class DenoiseContext: - latents: torch.Tensor +class DenoiseInputs: + orig_latents: torch.Tensor scheduler_step_kwargs: dict[str, Any] conditioning_data: TextConditioningData noise: Optional[torch.Tensor] @@ -39,10 +39,15 @@ class DenoiseContext: timesteps: torch.Tensor init_timestep: torch.Tensor + +@dataclass +class DenoiseContext: + inputs: DenoiseInputs + scheduler: SchedulerMixin unet: Optional[UNet2DConditionModel] = None - orig_latents: Optional[torch.Tensor] = None + latents: Optional[torch.Tensor] = None step_index: Optional[int] = None timestep: Optional[torch.Tensor] = None unet_kwargs: Optional[UNetKwargs] = None diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 4c08639ddf..f8cb92d1d4 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -22,25 +22,27 @@ class StableDiffusionBackend: self.sequential_guidance = config.sequential_guidance def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - if ctx.init_timestep.shape[0] == 0: - return ctx.latents + if ctx.inputs.init_timestep.shape[0] == 0: + return ctx.inputs.orig_latents - ctx.orig_latents = ctx.latents.clone() + ctx.latents = ctx.inputs.orig_latents.clone() - if ctx.noise is not None: + if ctx.inputs.noise is not None: batch_size = ctx.latents.shape[0] # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers - ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size)) + ctx.latents = ctx.scheduler.add_noise( + ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size) + ) # if no work to do, return latents - if ctx.timesteps.shape[0] == 0: + if ctx.inputs.timesteps.shape[0] == 0: return ctx.latents # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed) # ext: preview[pre_denoise_loop, priority=low] ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager) - for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020 + for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020 # ext: inpaint (apply mask to latents on non-inpaint models) ext_manager.callbacks.pre_step(ctx, ext_manager) @@ -80,7 +82,7 @@ class StableDiffusionBackend: ext_manager.callbacks.post_apply_cfg(ctx, ext_manager) # compute the previous noisy sample x_t -> x_t-1 - step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs) + step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs) # clean up locals ctx.latent_model_input = None @@ -92,7 +94,7 @@ class StableDiffusionBackend: @staticmethod def apply_cfg(ctx: DenoiseContext) -> torch.Tensor: - 
guidance_scale = ctx.conditioning_data.guidance_scale + guidance_scale = ctx.inputs.conditioning_data.guidance_scale if isinstance(guidance_scale, list): guidance_scale = guidance_scale[ctx.step_index] @@ -109,12 +111,12 @@ class StableDiffusionBackend: timestep=ctx.timestep, encoder_hidden_states=None, # set later by conditoning cross_attention_kwargs=dict( # noqa: C408 - percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, + percent_through=ctx.step_index / len(ctx.inputs.timesteps), ), ) ctx.conditioning_mode = conditioning_mode - ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) + ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) # ext: controlnet/ip/t2i [pre_unet] ext_manager.callbacks.pre_unet(ctx, ext_manager) diff --git a/invokeai/backend/stable_diffusion/extensions/preview.py b/invokeai/backend/stable_diffusion/extensions/preview.py index 73a1eef3c5..acc55e6172 100644 --- a/invokeai/backend/stable_diffusion/extensions/preview.py +++ b/invokeai/backend/stable_diffusion/extensions/preview.py @@ -35,7 +35,7 @@ class PreviewExt(ExtensionBase): PipelineIntermediateState( step=-1, order=ctx.scheduler.order, - total_steps=len(ctx.timesteps), + total_steps=len(ctx.inputs.timesteps), timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it? latents=ctx.latents, ) @@ -55,7 +55,7 @@ class PreviewExt(ExtensionBase): PipelineIntermediateState( step=ctx.step_index, order=ctx.scheduler.order, - total_steps=len(ctx.timesteps), + total_steps=len(ctx.inputs.timesteps), timestep=int(ctx.timestep), # TODO: is there any code which uses it? latents=ctx.step_output.prev_sample, predicted_original=predicted_original, # TODO: is there any reason for additional field? 
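
As a rough illustration of the extension API the patches above introduce (ExtensionBase, the @callback decorator, and ExtensionsManager with its proxy-based `callbacks` dispatch), the sketch below shows how a third-party extension would plug into the denoise loop. The LoggingExt class is hypothetical and not part of this PR; ExtensionBase, callback, ExtensionsManager, DenoiseContext and the "pre_step" callback name all come from the code in these patches.

    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
    from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


    class LoggingExt(ExtensionBase):
        """Hypothetical extension that logs the timestep of every denoising step."""

        @callback("pre_step")
        def log_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
            # ctx.step_index and ctx.timestep are populated by the denoise loop
            # before the pre_step callbacks are dispatched.
            print(f"step {ctx.step_index}: timestep {int(ctx.timestep)}")


    ext_manager = ExtensionsManager()
    ext_manager.add_extension(LoggingExt())
    # StableDiffusionBackend then dispatches it once per step via:
    #   ext_manager.callbacks.pre_step(ctx, ext_manager)

ExtensionsManager.add_extension() collects every method carrying the metadata set by @callback and routes ext_manager.callbacks.<name>(...) calls to them through call_callback(), so extensions only declare decorated methods and never touch the backend directly.
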
From cec345cb5c680577f2d41d58d5c78a3483dfbc74 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 16 Jul 2024 20:03:29 +0300 Subject: [PATCH 13/25] Change attention processor apply logic --- invokeai/app/invocations/denoise_latents.py | 4 ++- invokeai/backend/model_patcher.py | 23 ++++++++++++-- .../stable_diffusion/denoise_context.py | 3 +- .../stable_diffusion/extensions/base.py | 9 ++++-- .../stable_diffusion/extensions_manager.py | 31 ++----------------- 5 files changed, 36 insertions(+), 34 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 81b92d4fa7..6005bc83e0 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -776,6 +776,7 @@ class DenoiseLatentsInvocation(BaseInvocation): seed=seed, scheduler_step_kwargs=scheduler_step_kwargs, conditioning_data=conditioning_data, + attention_processor_cls=CustomAttnProcessor2_0, ), unet=None, scheduler=scheduler, @@ -797,8 +798,9 @@ class DenoiseLatentsInvocation(BaseInvocation): assert isinstance(unet_info.model, UNet2DConditionModel) with ( unet_info.model_on_device() as (model_state_dict, unet), + ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls), # ext: controlnet - ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0), + ext_manager.patch_extensions(unet), # ext: freeu, seamless, ip adapter, lora ext_manager.patch_unet(model_state_dict, unet), ): diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index 8c7a62c371..d31cb6bdef 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -5,7 +5,7 @@ from __future__ import annotations import pickle from contextlib import contextmanager -from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union import numpy as np import torch @@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras): """ -# TODO: rename smth like ModelPatcher and add TI method? class ModelPatcher: + @staticmethod + @contextmanager + def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]): + """A context manager that patches `unet` with the provided attention processor. + + Args: + unet (UNet2DConditionModel): The UNet model to patch. + processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...). + """ + unet_orig_processors = unet.attn_processors + try: + # create separate instance for each attention, to be able modify each attention separately + new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()} + unet.set_attn_processor(new_attn_processors) + + yield None + + finally: + unet.set_attn_processor(unet_orig_processors) + @staticmethod def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: assert "." 
not in lora_key diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py index 2a00052fd1..26c3b02c3b 100644 --- a/invokeai/backend/stable_diffusion/denoise_context.py +++ b/invokeai/backend/stable_diffusion/denoise_context.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union import torch from diffusers import UNet2DConditionModel @@ -38,6 +38,7 @@ class DenoiseInputs: seed: int timesteps: torch.Tensor init_timestep: torch.Tensor + attention_processor_cls: Type[Any] @dataclass diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 3effa77da4..2aaf49e3b9 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -1,10 +1,15 @@ +from __future__ import annotations + from contextlib import contextmanager from dataclasses import dataclass -from typing import Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional import torch from diffusers import UNet2DConditionModel +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + @dataclass class InjectionInfo: @@ -37,7 +42,7 @@ class ExtensionBase: self.injections.append(InjectionInfo(**func.__inj_info__, function=func)) @contextmanager - def patch_attention_processor(self, attention_processor_cls: object): + def patch_extension(self, context: DenoiseContext): yield None @contextmanager diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 876fd96d39..e747579d8b 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -98,39 +98,14 @@ class ExtensionsManager: if name in self._callbacks: self._callbacks[name](*args, **kwargs) - # TODO: is there any need in such high abstarction - # @contextmanager - # def patch_extensions(self): - # exit_stack = ExitStack() - # try: - # for ext in self.extensions: - # exit_stack.enter_context(ext.patch_extension(self)) - # - # yield None - # - # finally: - # exit_stack.close() - @contextmanager - def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object): - unet_orig_processors = unet.attn_processors - exit_stack = ExitStack() - try: - # just to be sure that attentions have not same processor instance - attn_procs = {} - for name in unet.attn_processors.keys(): - attn_procs[name] = attn_processor_cls() - unet.set_attn_processor(attn_procs) - + def patch_extensions(self, context: DenoiseContext): + with ExitStack() as exit_stack: for ext in self.extensions: - exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls)) + exit_stack.enter_context(ext.patch_extension(context)) yield None - finally: - unet.set_attn_processor(unet_orig_processors) - exit_stack.close() - @contextmanager def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): exit_stack = ExitStack() From b7c6c63005e4144d8589ee7620824c9f9b2ff52c Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 16 Jul 2024 22:52:44 +0300 Subject: [PATCH 14/25] Added some comments --- .../stable_diffusion/denoise_context.py | 59 +++++++++++++++++++ 1 file changed, 59 
insertions(+) diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py index 26c3b02c3b..bcebb2945e 100644 --- a/invokeai/backend/stable_diffusion/denoise_context.py +++ b/invokeai/backend/stable_diffusion/denoise_context.py @@ -31,6 +31,30 @@ class UNetKwargs: @dataclass class DenoiseInputs: + """Initial variables passed to denoise. Supposed to be unchanged. + + Variables: + orig_latents: The latent-space image to denoise. + Shape: [batch, channels, latent_height, latent_width] + - If we are inpainting, this is the initial latent image before noise has been added. + - If we are generating a new image, this should be initialized to zeros. + - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner). + scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method. + conditioning_data: Text conditionging data. + noise: Noise used for two purposes: + Shape: [1 or batch, channels, latent_height, latent_width] + 1. Used by the scheduler to noise the initial `latents` before denoising. + 2. Used to noise the `masked_latents` when inpainting. + `noise` should be None if the `latents` tensor has already been noised. + seed: The seed used to generate the noise for the denoising process. + HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the + same noise used earlier in the pipeline. This should really be handled in a clearer way. + timesteps: The timestep schedule for the denoising process. + init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so + should be populated if you want noise applied *even* if timesteps is empty. + attention_processor_cls: Class of attention processor that is used. + """ + orig_latents: torch.Tensor scheduler_step_kwargs: dict[str, Any] conditioning_data: TextConditioningData @@ -43,6 +67,41 @@ class DenoiseInputs: @dataclass class DenoiseContext: + """Context with all variables in denoise + + Variables: + inputs: Initial variables passed to denoise. Supposed to be unchanged. + scheduler: Scheduler which used to apply noise predictions. + unet: UNet model. + latents: Current state of latent-space image in denoising process. + None until `pre_denoise_loop` callback. + Shape: [batch, channels, latent_height, latent_width] + step_index: Current denoising step index. + None until `pre_step` callback. + timestep: Current denoising step timestep. + None until `pre_step` callback. + unet_kwargs: Arguments which will be passed to U Net model. + Available in `pre_unet`/`post_unet` callbacks, otherwice will be None. + step_output: SchedulerOutput class returned from step function(normally, generated by scheduler). + Supposed to be used only in `post_step` callback, otherwice can be None. + latent_model_input: Scaled version of `latents`, which will be passed to unet_kwargs initialization. + Available in events inside step(between `pre_step` and `post_stop`). + Shape: [batch, channels, latent_height, latent_width] + conditioning_mode: [TMP] Defines on which conditionings current unet call will be runned. + Available in `pre_unet`/`post_unet` callbacks, otherwice will be None. + Can be "negative", "positive" or "both" + negative_noise_pred: [TMP] Noise predictions from negative conditioning. + Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None. 
+ Shape: [batch, channels, latent_height, latent_width] + positive_noise_pred: [TMP] Noise predictions from positive conditioning. + Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None. + Shape: [batch, channels, latent_height, latent_width] + noise_pred: Combined noise prediction from passed conditionings. + Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None. + Shape: [batch, channels, latent_height, latent_width] + extra: Dictionary for extensions to pass extra info about denoise process to other extensions. + """ + inputs: DenoiseInputs scheduler: SchedulerMixin From cd1bc1595a71975ff44b58117149aa8ea3b4d77a Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 03:24:11 +0300 Subject: [PATCH 15/25] Rename sequential as private variable --- invokeai/backend/stable_diffusion/diffusion_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index f8cb92d1d4..d4c784e1d6 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -19,7 +19,7 @@ class StableDiffusionBackend: self.unet = unet self.scheduler = scheduler config = get_config() - self.sequential_guidance = config.sequential_guidance + self._sequential_guidance = config.sequential_guidance def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): if ctx.inputs.init_timestep.shape[0] == 0: @@ -67,7 +67,7 @@ class StableDiffusionBackend: # Note: The current handling of conditioning doesn't feel very future-proof. # This might change in the future as new requirements come up, but for now, # this is the rough plan. 
- if self.sequential_guidance: + if self._sequential_guidance: ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative") ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive") else: From ae6d4fbc78951aa0c3b0ede1a81bfb76e70831d3 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 03:31:26 +0300 Subject: [PATCH 16/25] Move out _concat_conditionings_for_batch submethods --- .../diffusion/conditioning_data.py | 54 ++++++++++--------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 21fb8d5780..80b671df65 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -172,40 +172,46 @@ class TextConditioningData: regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype), ) - def _concat_conditionings_for_batch(self, conditionings): - def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int): - return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim) + @staticmethod + def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int): + return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim) - def _pad_conditioning(cond, target_len, encoder_attention_mask): - conditioning_attention_mask = torch.ones( - (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype + @classmethod + def _pad_conditioning( + cls, + cond: torch.Tensor, + target_len: int, + encoder_attention_mask: Optional[torch.Tensor], + ): + conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype) + + if cond.shape[1] < target_len: + conditioning_attention_mask = cls._pad_zeros( + conditioning_attention_mask, + pad_shape=(cond.shape[0], target_len - cond.shape[1]), + dim=1, ) - if cond.shape[1] < max_len: - conditioning_attention_mask = _pad_zeros( - conditioning_attention_mask, - pad_shape=(cond.shape[0], max_len - cond.shape[1]), - dim=1, - ) + cond = cls._pad_zeros( + cond, + pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]), + dim=1, + ) - cond = _pad_zeros( - cond, - pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]), - dim=1, - ) + if encoder_attention_mask is None: + encoder_attention_mask = conditioning_attention_mask + else: + encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask]) - if encoder_attention_mask is None: - encoder_attention_mask = conditioning_attention_mask - else: - encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask]) - - return cond, encoder_attention_mask + return cond, encoder_attention_mask + @classmethod + def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]): encoder_attention_mask = None max_len = max([c.shape[1] for c in conditionings]) if any(c.shape[1] != max_len for c in conditionings): for i in range(len(conditionings)): - conditionings[i], encoder_attention_mask = _pad_conditioning( + conditionings[i], encoder_attention_mask = cls._pad_conditioning( conditionings[i], max_len, encoder_attention_mask ) From 03e22c257b4084f4d1c18b442ca96914129f34c2 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 03:37:11 +0300 Subject: [PATCH 17/25] Convert conditioning_mode to enum --- .../diffusion/conditioning_data.py | 26 
++++++++++++++----- .../stable_diffusion/diffusion_backend.py | 11 ++++---- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 80b671df65..8a52310e6f 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -1,12 +1,18 @@ +from __future__ import annotations + import math from dataclasses import dataclass -from typing import List, Optional, Union +from enum import Enum +from typing import TYPE_CHECKING, List, Optional, Union import torch -from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData +if TYPE_CHECKING: + from invokeai.backend.ip_adapter.ip_adapter import IPAdapter + from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs + @dataclass class BasicConditioningInfo: @@ -96,6 +102,12 @@ class TextConditioningRegions: assert self.masks.shape[1] == len(self.ranges) +class ConditioningMode(Enum): + Both = "both" + Negative = "negative" + Positive = "positive" + + class TextConditioningData: def __init__( self, @@ -124,21 +136,23 @@ class TextConditioningData: assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo) return isinstance(self.cond_text, SDXLConditioningInfo) - def to_unet_kwargs(self, unet_kwargs, conditioning_mode): + def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode): _, _, h, w = unet_kwargs.sample.shape device = unet_kwargs.sample.device dtype = unet_kwargs.sample.dtype # TODO: combine regions with conditionings - if conditioning_mode == "both": + if conditioning_mode == ConditioningMode.Both: conditionings = [self.uncond_text, self.cond_text] c_regions = [self.uncond_regions, self.cond_regions] - elif conditioning_mode == "positive": + elif conditioning_mode == ConditioningMode.Positive: conditionings = [self.cond_text] c_regions = [self.cond_regions] - else: + elif conditioning_mode == ConditioningMode.Negative: conditionings = [self.uncond_text] c_regions = [self.uncond_regions] + else: + raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}") encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch( [c.embeds for c in conditionings] diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index d4c784e1d6..c1035c2a97 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -7,6 +7,7 @@ from tqdm.auto import tqdm from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager @@ -68,10 +69,10 @@ class StableDiffusionBackend: # This might change in the future as new requirements come up, but for now, # this is the rough plan. 
if self._sequential_guidance: - ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative") - ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive") + ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative) + ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive) else: - both_noise_pred = self.run_unet(ctx, ext_manager, "both") + both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both) ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2) # ext: override apply_cfg @@ -101,9 +102,9 @@ class StableDiffusionBackend: return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) - def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: str): + def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode): sample = ctx.latent_model_input - if conditioning_mode == "both": + if conditioning_mode == ConditioningMode.Both: sample = torch.cat([sample] * 2) ctx.unet_kwargs = UNetKwargs( From 137202b77cadf5f6c9205a376177eaf89516e51d Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 03:40:27 +0300 Subject: [PATCH 18/25] Remove patch_unet logic for now --- .../stable_diffusion/extensions_manager.py | 34 ++----------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index e747579d8b..08004339e9 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -108,35 +108,5 @@ class ExtensionsManager: @contextmanager def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): - exit_stack = ExitStack() - try: - changed_keys = set() - changed_unknown_keys = {} - - for ext in self.extensions: - patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet)) - if patch_result is None: - continue - new_keys, new_unk_keys = patch_result - changed_keys.update(new_keys) - # skip already seen keys, as new weight might be changed - for k, v in new_unk_keys.items(): - if k in changed_unknown_keys: - continue - changed_unknown_keys[k] = v - - yield None - - finally: - exit_stack.close() - assert hasattr(unet, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() - with torch.no_grad(): - for module_key in changed_keys: - weight = state_dict[module_key] - unet.get_submodule(module_key).weight.copy_( - weight, non_blocking=TorchDevice.get_non_blocking(weight.device) - ) - for module_key, weight in changed_unknown_keys.items(): - unet.get_submodule(module_key).weight.copy_( - weight, non_blocking=TorchDevice.get_non_blocking(weight.device) - ) + # TODO: create logic in PR with extension which uses it + yield None From 79e35bd0d360070d5f646cc7d81c69e99c06f7c3 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 03:48:37 +0300 Subject: [PATCH 19/25] Minor fixes --- invokeai/app/invocations/denoise_latents.py | 144 +++++++++--------- invokeai/backend/model_patcher.py | 8 +- .../stable_diffusion/extensions_manager.py | 2 - 3 files changed, 75 insertions(+), 79 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 
6005bc83e0..17a79cca90 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -723,90 +723,88 @@ class DenoiseLatentsInvocation(BaseInvocation): @torch.no_grad() @SilenceWarnings() # This quenches the NSFW nag from diffusers. def _new_invoke(self, context: InvocationContext) -> LatentsOutput: - # TODO: remove supression when extensions which use models added - with ExitStack() as exit_stack: # noqa: F841 - ext_manager = ExtensionsManager() + ext_manager = ExtensionsManager() - device = TorchDevice.choose_torch_device() - dtype = TorchDevice.choose_torch_dtype() + device = TorchDevice.choose_torch_device() + dtype = TorchDevice.choose_torch_dtype() - seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) - latents = latents.to(device=device, dtype=dtype) - if noise is not None: - noise = noise.to(device=device, dtype=dtype) + seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) + latents = latents.to(device=device, dtype=dtype) + if noise is not None: + noise = noise.to(device=device, dtype=dtype) - _, _, latent_height, latent_width = latents.shape + _, _, latent_height, latent_width = latents.shape - conditioning_data = self.get_conditioning_data( - context=context, - positive_conditioning_field=self.positive_conditioning, - negative_conditioning_field=self.negative_conditioning, - cfg_scale=self.cfg_scale, - steps=self.steps, - latent_height=latent_height, - latent_width=latent_width, - device=device, - dtype=dtype, - # TODO: old backend, remove - cfg_rescale_multiplier=self.cfg_rescale_multiplier, - ) + conditioning_data = self.get_conditioning_data( + context=context, + positive_conditioning_field=self.positive_conditioning, + negative_conditioning_field=self.negative_conditioning, + cfg_scale=self.cfg_scale, + steps=self.steps, + latent_height=latent_height, + latent_width=latent_width, + device=device, + dtype=dtype, + # TODO: old backend, remove + cfg_rescale_multiplier=self.cfg_rescale_multiplier, + ) - scheduler = get_scheduler( - context=context, - scheduler_info=self.unet.scheduler, - scheduler_name=self.scheduler, + scheduler = get_scheduler( + context=context, + scheduler_info=self.unet.scheduler, + scheduler_name=self.scheduler, + seed=seed, + ) + + timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( + scheduler, + seed=seed, + device=device, + steps=self.steps, + denoising_start=self.denoising_start, + denoising_end=self.denoising_end, + ) + + denoise_ctx = DenoiseContext( + inputs=DenoiseInputs( + orig_latents=latents, + timesteps=timesteps, + init_timestep=init_timestep, + noise=noise, seed=seed, - ) + scheduler_step_kwargs=scheduler_step_kwargs, + conditioning_data=conditioning_data, + attention_processor_cls=CustomAttnProcessor2_0, + ), + unet=None, + scheduler=scheduler, + ) - timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( - scheduler, - seed=seed, - device=device, - steps=self.steps, - denoising_start=self.denoising_start, - denoising_end=self.denoising_end, - ) + # get the unet's config so that we can pass the base to sd_step_callback() + unet_config = context.models.get_config(self.unet.unet.key) - denoise_ctx = DenoiseContext( - inputs=DenoiseInputs( - orig_latents=latents, - timesteps=timesteps, - init_timestep=init_timestep, - noise=noise, - seed=seed, - scheduler_step_kwargs=scheduler_step_kwargs, - conditioning_data=conditioning_data, - attention_processor_cls=CustomAttnProcessor2_0, - ), - 
unet=None, - scheduler=scheduler, - ) + ### preview + def step_callback(state: PipelineIntermediateState) -> None: + context.util.sd_step_callback(state, unet_config.base) - ### preview - def step_callback(state: PipelineIntermediateState) -> None: - context.util.sd_step_callback(state, unet_config.base) + ext_manager.add_extension(PreviewExt(step_callback)) - ext_manager.add_extension(PreviewExt(step_callback)) + # ext: t2i/ip adapter + ext_manager.callbacks.setup(denoise_ctx, ext_manager) - # get the unet's config so that we can pass the base to sd_step_callback() - unet_config = context.models.get_config(self.unet.unet.key) - - # ext: t2i/ip adapter - ext_manager.callbacks.setup(denoise_ctx, ext_manager) - - unet_info = context.models.load(self.unet.unet) - assert isinstance(unet_info.model, UNet2DConditionModel) - with ( - unet_info.model_on_device() as (model_state_dict, unet), - ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls), - # ext: controlnet - ext_manager.patch_extensions(unet), - # ext: freeu, seamless, ip adapter, lora - ext_manager.patch_unet(model_state_dict, unet), - ): - sd_backend = StableDiffusionBackend(unet, scheduler) - denoise_ctx.unet = unet - result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager) + unet_info = context.models.load(self.unet.unet) + assert isinstance(unet_info.model, UNet2DConditionModel) + with ( + unet_info.model_on_device() as (model_state_dict, unet), + ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls), + # ext: controlnet + ext_manager.patch_extensions(unet), + # ext: freeu, seamless, ip adapter, lora + ext_manager.patch_unet(model_state_dict, unet), + ): + sd_backend = StableDiffusionBackend(unet, scheduler) + denoise_ctx.unet = unet + result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 result_latents = result_latents.detach().to("cpu") diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index d31cb6bdef..b2d6036f63 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -43,11 +43,11 @@ class ModelPatcher: processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...). 
""" unet_orig_processors = unet.attn_processors - try: - # create separate instance for each attention, to be able modify each attention separately - new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()} - unet.set_attn_processor(new_attn_processors) + # create separate instance for each attention, to be able modify each attention separately + unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()} + try: + unet.set_attn_processor(unet_new_processors) yield None finally: diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 08004339e9..213eb5d782 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -8,8 +8,6 @@ from typing import TYPE_CHECKING, Callable, Dict import torch from diffusers import UNet2DConditionModel -from invokeai.backend.util.devices import TorchDevice - if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext from invokeai.backend.stable_diffusion.extensions import ExtensionBase From 2c2ec8f0bcd728750c08576ab2508d8e42bae26c Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 04:20:31 +0300 Subject: [PATCH 20/25] Comments, a bit refactor --- .../stable_diffusion/denoise_context.py | 127 ++++++++++-------- .../diffusion/conditioning_data.py | 42 ++++-- 2 files changed, 98 insertions(+), 71 deletions(-) diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py index bcebb2945e..2b43d3fb0f 100644 --- a/invokeai/backend/stable_diffusion/denoise_context.py +++ b/invokeai/backend/stable_diffusion/denoise_context.py @@ -8,7 +8,7 @@ from diffusers import UNet2DConditionModel from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput if TYPE_CHECKING: - from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData @dataclass @@ -31,92 +31,101 @@ class UNetKwargs: @dataclass class DenoiseInputs: - """Initial variables passed to denoise. Supposed to be unchanged. - - Variables: - orig_latents: The latent-space image to denoise. - Shape: [batch, channels, latent_height, latent_width] - - If we are inpainting, this is the initial latent image before noise has been added. - - If we are generating a new image, this should be initialized to zeros. - - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner). - scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method. - conditioning_data: Text conditionging data. - noise: Noise used for two purposes: - Shape: [1 or batch, channels, latent_height, latent_width] - 1. Used by the scheduler to noise the initial `latents` before denoising. - 2. Used to noise the `masked_latents` when inpainting. - `noise` should be None if the `latents` tensor has already been noised. - seed: The seed used to generate the noise for the denoising process. - HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the - same noise used earlier in the pipeline. This should really be handled in a clearer way. - timesteps: The timestep schedule for the denoising process. - init_timestep: The first timestep in the schedule. 
This is used to determine the initial noise level, so - should be populated if you want noise applied *even* if timesteps is empty. - attention_processor_cls: Class of attention processor that is used. - """ + """Initial variables passed to denoise. Supposed to be unchanged.""" + # The latent-space image to denoise. + # Shape: [batch, channels, latent_height, latent_width] + # - If we are inpainting, this is the initial latent image before noise has been added. + # - If we are generating a new image, this should be initialized to zeros. + # - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner). orig_latents: torch.Tensor + + # kwargs forwarded to the scheduler.step() method. scheduler_step_kwargs: dict[str, Any] + + # Text conditioning data. conditioning_data: TextConditioningData + + # Noise used for two purposes: + # 1. Used by the scheduler to noise the initial `latents` before denoising. + # 2. Used to noise the `masked_latents` when inpainting. + # `noise` should be None if the `latents` tensor has already been noised. + # Shape: [1 or batch, channels, latent_height, latent_width] noise: Optional[torch.Tensor] + + # The seed used to generate the noise for the denoising process. + # HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the + # same noise used earlier in the pipeline. This should really be handled in a clearer way. seed: int + + # The timestep schedule for the denoising process. timesteps: torch.Tensor + + # The first timestep in the schedule. This is used to determine the initial noise level, so + # should be populated if you want noise applied *even* if timesteps is empty. init_timestep: torch.Tensor + + # Class of attention processor that is used. attention_processor_cls: Type[Any] @dataclass class DenoiseContext: - """Context with all variables in denoise - - Variables: - inputs: Initial variables passed to denoise. Supposed to be unchanged. - scheduler: Scheduler which used to apply noise predictions. - unet: UNet model. - latents: Current state of latent-space image in denoising process. - None until `pre_denoise_loop` callback. - Shape: [batch, channels, latent_height, latent_width] - step_index: Current denoising step index. - None until `pre_step` callback. - timestep: Current denoising step timestep. - None until `pre_step` callback. - unet_kwargs: Arguments which will be passed to U Net model. - Available in `pre_unet`/`post_unet` callbacks, otherwice will be None. - step_output: SchedulerOutput class returned from step function(normally, generated by scheduler). - Supposed to be used only in `post_step` callback, otherwice can be None. - latent_model_input: Scaled version of `latents`, which will be passed to unet_kwargs initialization. - Available in events inside step(between `pre_step` and `post_stop`). - Shape: [batch, channels, latent_height, latent_width] - conditioning_mode: [TMP] Defines on which conditionings current unet call will be runned. - Available in `pre_unet`/`post_unet` callbacks, otherwice will be None. - Can be "negative", "positive" or "both" - negative_noise_pred: [TMP] Noise predictions from negative conditioning. - Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None. - Shape: [batch, channels, latent_height, latent_width] - positive_noise_pred: [TMP] Noise predictions from positive conditioning. - Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None. - Shape: [batch, channels, latent_height, latent_width] - noise_pred: Combined noise prediction from passed conditionings. - Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None. - Shape: [batch, channels, latent_height, latent_width] - extra: Dictionary for extensions to pass extra info about denoise process to other extensions. - """ + """Context with all variables in denoise""" + # Initial variables passed to denoise. Supposed to be unchanged. inputs: DenoiseInputs + # Scheduler which is used to apply noise predictions. scheduler: SchedulerMixin + + # UNet model. unet: Optional[UNet2DConditionModel] = None + # Current state of latent-space image in denoising process. + # None until `pre_denoise_loop` callback. + # Shape: [batch, channels, latent_height, latent_width] latents: Optional[torch.Tensor] = None + + # Current denoising step index. + # None until `pre_step` callback. step_index: Optional[int] = None + + # Current denoising step timestep. + # None until `pre_step` callback. timestep: Optional[torch.Tensor] = None + + # Arguments which will be passed to UNet model. + # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. unet_kwargs: Optional[UNetKwargs] = None + + # SchedulerOutput class returned from step function (normally, generated by scheduler). + # Supposed to be used only in `post_step` callback, otherwise can be None. step_output: Optional[SchedulerOutput] = None + # Scaled version of `latents`, which will be passed to unet_kwargs initialization. + # Available in events inside step (between `pre_step` and `post_step`). + # Shape: [batch, channels, latent_height, latent_width] latent_model_input: Optional[torch.Tensor] = None - conditioning_mode: Optional[str] = None + + # [TMP] Defines on which conditionings the current unet call will be run. + # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. conditioning_mode: Optional[ConditioningMode] = None + + # [TMP] Noise predictions from negative conditioning. + # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Shape: [batch, channels, latent_height, latent_width] negative_noise_pred: Optional[torch.Tensor] = None + + # [TMP] Noise predictions from positive conditioning. + # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Shape: [batch, channels, latent_height, latent_width] positive_noise_pred: Optional[torch.Tensor] = None + + # Combined noise prediction from passed conditionings. + # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Shape: [batch, channels, latent_height, latent_width] noise_pred: Optional[torch.Tensor] = None + # Dictionary for extensions to pass extra info about denoise process to other extensions. extra: dict = field(default_factory=dict) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 8a52310e6f..b017454a78 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -137,6 +137,12 @@ class TextConditioningData: return isinstance(self.cond_text, SDXLConditioningInfo) def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode): + """Fills unet arguments with data from provided conditionings. + + Args: + unet_kwargs (UNetKwargs): Object which stores UNet model arguments. + conditioning_mode (ConditioningMode): Describes which conditionings should be used. + """ _, _, h, w = unet_kwargs.sample.shape device = unet_kwargs.sample.device dtype = unet_kwargs.sample.dtype @@ -187,7 +193,7 @@ class TextConditioningData: ) @staticmethod - def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int): + def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor: return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim) @classmethod @@ -195,8 +201,13 @@ class TextConditioningData: cls, cond: torch.Tensor, target_len: int, - encoder_attention_mask: Optional[torch.Tensor], - ): + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Pad the provided conditioning tensor to target_len with zeros and return a mask of the unpadded tokens. + + Args: + cond (torch.Tensor): Conditioning tensor to pad with zeros. + target_len (int): Target length (token count) to which the tensor should be padded. + """ conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype) if cond.shape[1] < target_len: @@ -212,21 +223,28 @@ class TextConditioningData: dim=1, ) - if encoder_attention_mask is None: - encoder_attention_mask = conditioning_attention_mask - else: - encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask]) - - return cond, encoder_attention_mask + return cond, conditioning_attention_mask @classmethod - def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]): + def _concat_conditionings_for_batch( + cls, + conditionings: List[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Concatenate provided conditioning tensors to one batched tensor. + If tensors have different sizes then pad them by zeros and creates + encoder_attention_mask to exclude padding from attention. + + Args: + conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
+ """ encoder_attention_mask = None max_len = max([c.shape[1] for c in conditionings]) if any(c.shape[1] != max_len for c in conditionings): + encoder_attention_masks = [None] * len(conditionings) for i in range(len(conditionings)): - conditionings[i], encoder_attention_mask = cls._pad_conditioning( - conditionings[i], max_len, encoder_attention_mask + conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning( + conditionings[i], max_len ) + encoder_attention_mask = torch.cat(encoder_attention_masks) return torch.cat(conditionings), encoder_attention_mask From 3f79467f7bf153f142ea568b54bc0e0694375806 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 04:24:45 +0300 Subject: [PATCH 21/25] Ruff format --- .../stable_diffusion/diffusion/conditioning_data.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index b017454a78..5fe1483ebc 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -3,7 +3,7 @@ from __future__ import annotations import math from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch @@ -231,7 +231,7 @@ class TextConditioningData: conditionings: List[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Concatenate provided conditioning tensors to one batched tensor. - If tensors have different sizes then pad them by zeros and creates + If tensors have different sizes then pad them by zeros and creates encoder_attention_mask to exclude padding from attention. Args: @@ -242,9 +242,7 @@ class TextConditioningData: if any(c.shape[1] != max_len for c in conditionings): encoder_attention_masks = [None] * len(conditionings) for i in range(len(conditionings)): - conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning( - conditionings[i], max_len - ) + conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len) encoder_attention_mask = torch.cat(encoder_attention_masks) return torch.cat(conditionings), encoder_attention_mask From 2ef3b49a7937d8b0efed6053a71000928b7986bc Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 04:39:15 +0300 Subject: [PATCH 22/25] Add run cancelling logic to extension manager --- invokeai/app/invocations/denoise_latents.py | 2 +- .../stable_diffusion/extensions_manager.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 17a79cca90..5b6d945b4e 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -723,7 +723,7 @@ class DenoiseLatentsInvocation(BaseInvocation): @torch.no_grad() @SilenceWarnings() # This quenches the NSFW nag from diffusers. 
def _new_invoke(self, context: InvocationContext) -> LatentsOutput: - ext_manager = ExtensionsManager() + ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled) device = TorchDevice.choose_torch_device() dtype = TorchDevice.choose_torch_dtype() diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 213eb5d782..481d1dc358 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -3,14 +3,16 @@ from __future__ import annotations from abc import ABC, abstractmethod from contextlib import ExitStack, contextmanager from functools import partial -from typing import TYPE_CHECKING, Callable, Dict +from typing import TYPE_CHECKING, Callable, Dict, List, Optional import torch from diffusers import UNet2DConditionModel +from invokeai.app.services.session_processor.session_processor_common import CanceledException + if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext - from invokeai.backend.stable_diffusion.extensions import ExtensionBase + from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase class ExtCallbacksApi(ABC): @@ -71,10 +73,11 @@ class CallbackInjectionPoint: class ExtensionsManager: - def __init__(self): - self.extensions = [] + def __init__(self, is_canceled: Optional[Callable[[], bool]] = None): + self.extensions: List[ExtensionBase] = [] + self._is_canceled = is_canceled - self._callbacks = {} + self._callbacks: Dict[str, CallbackInjectionPoint] = {} self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback) def add_extension(self, ext: ExtensionBase): @@ -93,6 +96,11 @@ class ExtensionsManager: raise Exception(f"Unsupported injection type: {inj_info.type}") def call_callback(self, name: str, *args, **kwargs): + # TODO: add to patchers too? 
+ # and if so, should it be only in beginning of function or in for loop + if self._is_canceled and self._is_canceled(): + raise CanceledException + if name in self._callbacks: self._callbacks[name](*args, **kwargs) From 0c56d4a581df6f058d741101ec3d96a736d45815 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Thu, 18 Jul 2024 23:49:44 +0300 Subject: [PATCH 23/25] Ryan's suggested changes to extension manager/extensions Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- invokeai/app/invocations/denoise_latents.py | 3 +- .../stable_diffusion/diffusion_backend.py | 15 +-- .../extension_callback_type.py | 12 +++ .../stable_diffusion/extensions/base.py | 46 ++++---- .../stable_diffusion/extensions/preview.py | 10 +- .../stable_diffusion/extensions_manager.py | 102 +++++------------- 6 files changed, 79 insertions(+), 109 deletions(-) create mode 100644 invokeai/backend/stable_diffusion/extension_callback_type.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 5b6d945b4e..ccacc3303c 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -57,6 +57,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ) from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP @@ -790,7 +791,7 @@ class DenoiseLatentsInvocation(BaseInvocation): ext_manager.add_extension(PreviewExt(step_callback)) # ext: t2i/ip adapter - ext_manager.callbacks.setup(denoise_ctx, ext_manager) + ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) unet_info = context.models.load(self.unet.unet) assert isinstance(unet_info.model, UNet2DConditionModel) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index c1035c2a97..806deb5e03 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -8,6 +8,7 @@ from tqdm.auto import tqdm from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager @@ -41,23 +42,23 @@ class StableDiffusionBackend: # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed) # ext: preview[pre_denoise_loop, priority=low] - ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager) + ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx) for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020 # ext: inpaint (apply mask to latents on non-inpaint models) - ext_manager.callbacks.pre_step(ctx, ext_manager) + ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx) # ext: tiles? 
[override: step] ctx.step_output = self.step(ctx, ext_manager) # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models) # ext: preview[post_step, priority=low] - ext_manager.callbacks.post_step(ctx, ext_manager) + ext_manager.run_callback(ExtensionCallbackType.POST_STEP, ctx) ctx.latents = ctx.step_output.prev_sample # ext: inpaint[post_denoise_loop] (restore unmasked part) - ext_manager.callbacks.post_denoise_loop(ctx, ext_manager) + ext_manager.run_callback(ExtensionCallbackType.POST_DENOISE_LOOP, ctx) return ctx.latents @torch.inference_mode() @@ -80,7 +81,7 @@ class StableDiffusionBackend: # ext: cfg_rescale [modify_noise_prediction] # TODO: rename - ext_manager.callbacks.post_apply_cfg(ctx, ext_manager) + ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx) # compute the previous noisy sample x_t -> x_t-1 step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs) @@ -120,14 +121,14 @@ class StableDiffusionBackend: ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) # ext: controlnet/ip/t2i [pre_unet] - ext_manager.callbacks.pre_unet(ctx, ext_manager) + ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx) # ext: inpaint [pre_unet, priority=low] # or # ext: inpaint [override: unet_forward] noise_pred = self._unet_forward(**vars(ctx.unet_kwargs)) - ext_manager.callbacks.post_unet(ctx, ext_manager) + ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx) # clean up locals ctx.unet_kwargs = None diff --git a/invokeai/backend/stable_diffusion/extension_callback_type.py b/invokeai/backend/stable_diffusion/extension_callback_type.py new file mode 100644 index 0000000000..aaefbd7ed0 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extension_callback_type.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class ExtensionCallbackType(Enum): + SETUP = "setup" + PRE_DENOISE_LOOP = "pre_denoise_loop" + POST_DENOISE_LOOP = "post_denoise_loop" + PRE_STEP = "pre_step" + POST_STEP = "post_step" + PRE_UNET = "pre_unet" + POST_UNET = "post_unet" + POST_APPLY_CFG = "post_apply_cfg" diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 2aaf49e3b9..802af86e6d 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -2,44 +2,54 @@ from __future__ import annotations from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List import torch from diffusers import UNet2DConditionModel if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType @dataclass -class InjectionInfo: - type: str - name: str - order: Optional[int] - function: Callable +class CallbackMetadata: + callback_type: ExtensionCallbackType + order: int -def callback(name: str, order: int = 0): - def _decorator(func): - func.__inj_info__ = { - "type": "callback", - "name": name, - "order": order, - } - return func +@dataclass +class CallbackFunctionWithMetadata: + metadata: CallbackMetadata + function: Callable[[DenoiseContext], None] + + +def callback(callback_type: ExtensionCallbackType, order: int = 0): + def _decorator(function): + function._ext_metadata = CallbackMetadata( + 
callback_type=callback_type, + order=order, + ) + return function return _decorator class ExtensionBase: def __init__(self): - self.injections: List[InjectionInfo] = [] + self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} + + # Register all of the callback methods for this instance. for func_name in dir(self): func = getattr(self, func_name) - if not callable(func) or not hasattr(func, "__inj_info__"): - continue + metadata = getattr(func, "_ext_metadata", None) + if metadata is not None and isinstance(metadata, CallbackMetadata): + if metadata.callback_type not in self._callbacks: + self._callbacks[metadata.callback_type] = [] + self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func)) - self.injections.append(InjectionInfo(**func.__inj_info__, function=func)) + def get_callbacks(self): + return self._callbacks @contextmanager def patch_extension(self, context: DenoiseContext): diff --git a/invokeai/backend/stable_diffusion/extensions/preview.py b/invokeai/backend/stable_diffusion/extensions/preview.py index acc55e6172..fd0cc0b9a3 100644 --- a/invokeai/backend/stable_diffusion/extensions/preview.py +++ b/invokeai/backend/stable_diffusion/extensions/preview.py @@ -5,11 +5,11 @@ from typing import TYPE_CHECKING, Callable, Optional import torch +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext - from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager # TODO: change event to accept image instead of latents @@ -29,8 +29,8 @@ class PreviewExt(ExtensionBase): self.callback = callback # do last so that all other changes shown - @callback("pre_denoise_loop", order=1000) - def initial_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) + def initial_preview(self, ctx: DenoiseContext): self.callback( PipelineIntermediateState( step=-1, @@ -42,8 +42,8 @@ class PreviewExt(ExtensionBase): ) # do last so that all other changes shown - @callback("post_step", order=1000) - def step_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + @callback(ExtensionCallbackType.POST_STEP, order=1000) + def step_preview(self, ctx: DenoiseContext): if hasattr(ctx.step_output, "denoised"): predicted_original = ctx.step_output.denoised elif hasattr(ctx.step_output, "pred_original_sample"): diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 481d1dc358..1552fb5dd7 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -1,8 +1,6 @@ from __future__ import annotations -from abc import ABC, abstractmethod from contextlib import ExitStack, contextmanager -from functools import partial from typing import TYPE_CHECKING, Callable, Dict, List, Optional import torch @@ -12,102 +10,50 @@ from invokeai.app.services.session_processor.session_processor_common import Can if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext - from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase - - -class ExtCallbacksApi(ABC): - @abstractmethod - def setup(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - pass - 
- @abstractmethod - def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - pass - - @abstractmethod - def post_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - pass - - @abstractmethod - def pre_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - pass - - @abstractmethod - def post_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - pass - - @abstractmethod - def pre_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - pass - - @abstractmethod - def post_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - pass - - @abstractmethod - def post_apply_cfg(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): - pass - - -class ProxyCallsClass: - def __init__(self, handler): - self._handler = handler - - def __getattr__(self, item): - return partial(self._handler, item) - - -class CallbackInjectionPoint: - def __init__(self): - self.handlers = {} - - def add(self, func: Callable, order: int): - if order not in self.handlers: - self.handlers[order] = [] - self.handlers[order].append(func) - - def __call__(self, *args, **kwargs): - for order in sorted(self.handlers.keys(), reverse=True): - for handler in self.handlers[order]: - handler(*args, **kwargs) + from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType + from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase class ExtensionsManager: def __init__(self, is_canceled: Optional[Callable[[], bool]] = None): - self.extensions: List[ExtensionBase] = [] self._is_canceled = is_canceled - self._callbacks: Dict[str, CallbackInjectionPoint] = {} - self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback) + self._extensions: List[ExtensionBase] = [] + self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} - def add_extension(self, ext: ExtensionBase): - self.extensions.append(ext) + def add_extension(self, extension: ExtensionBase): + self._extensions.append(extension) + self._regenerate_ordered_callbacks() - self._callbacks.clear() + def _regenerate_ordered_callbacks(self): + """Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added.""" + self._ordered_callbacks = {} - for ext in self.extensions: - for inj_info in ext.injections: - if inj_info.type == "callback": - if inj_info.name not in self._callbacks: - self._callbacks[inj_info.name] = CallbackInjectionPoint() - self._callbacks[inj_info.name].add(inj_info.function, inj_info.order) + # Fill the ordered callbacks dictionary. + for extension in self._extensions: + for callback_type, callbacks in extension.get_callbacks().items(): + if callback_type not in self._ordered_callbacks: + self._ordered_callbacks[callback_type] = [] + self._ordered_callbacks[callback_type].extend(callbacks) - else: - raise Exception(f"Unsupported injection type: {inj_info.type}") + # Sort each callback list. + for callback_type, callbacks in self._ordered_callbacks.items(): + self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order) - def call_callback(self, name: str, *args, **kwargs): + def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext): # TODO: add to patchers too? 
# and if so, should it be only in beginning of function or in for loop if self._is_canceled and self._is_canceled(): raise CanceledException - if name in self._callbacks: - self._callbacks[name](*args, **kwargs) + callbacks = self._ordered_callbacks.get(callback_type, []) + for cb in callbacks: + cb.function(ctx) @contextmanager def patch_extensions(self, context: DenoiseContext): with ExitStack() as exit_stack: - for ext in self.extensions: + for ext in self._extensions: exit_stack.enter_context(ext.patch_extension(context)) yield None From 83a86abce28880fda9ada88e818e2d3eb58ebcbd Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 19 Jul 2024 14:05:27 -0400 Subject: [PATCH 24/25] Add unit tests for ExtensionsManager and ExtensionBase. --- .../stable_diffusion/extensions_manager.py | 3 + .../stable_diffusion/extensions/test_base.py | 46 +++++++ .../test_extension_manager.py | 112 ++++++++++++++++++ 3 files changed, 161 insertions(+) create mode 100644 tests/backend/stable_diffusion/extensions/test_base.py create mode 100644 tests/backend/stable_diffusion/test_extension_manager.py diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 1552fb5dd7..f42a065e82 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -18,6 +18,7 @@ class ExtensionsManager: def __init__(self, is_canceled: Optional[Callable[[], bool]] = None): self._is_canceled = is_canceled + # A list of extensions in the order that they were added to the ExtensionsManager. self._extensions: List[ExtensionBase] = [] self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} @@ -38,6 +39,8 @@ class ExtensionsManager: # Sort each callback list. for callback_type, callbacks in self._ordered_callbacks.items(): + # Note that sorted() is stable, so if two callbacks have the same order, the order in which the extensions were + # added will be preserved. self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order) def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext): diff --git a/tests/backend/stable_diffusion/extensions/test_base.py b/tests/backend/stable_diffusion/extensions/test_base.py new file mode 100644 index 0000000000..d024c551a2 --- /dev/null +++ b/tests/backend/stable_diffusion/extensions/test_base.py @@ -0,0 +1,46 @@ +from unittest import mock + +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + + +class MockExtension(ExtensionBase): + """A mock ExtensionBase subclass for testing purposes.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +def test_extension_base_callback_registration(): + """Test that a callback can be successfully registered with an extension.""" + val = 5 + mock_extension = MockExtension(val) + + mock_ctx = mock.MagicMock() + + callbacks = mock_extension.get_callbacks() + pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, []) + assert len(pre_denoise_loop_cbs) == 1 + + # Call the mock callback. + pre_denoise_loop_cbs[0].function(mock_ctx) + + # Confirm that the callback ran.
+ assert mock_ctx.step_index == val + + +def test_extension_base_empty_callback_type(): + """Test that an empty list is returned when no callbacks are registered for a given callback type.""" + mock_extension = MockExtension(5) + + # There should be no callbacks registered for POST_DENOISE_LOOP. + callbacks = mock_extension.get_callbacks() + + post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, []) + assert len(post_denoise_loop_cbs) == 0 diff --git a/tests/backend/stable_diffusion/test_extension_manager.py b/tests/backend/stable_diffusion/test_extension_manager.py new file mode 100644 index 0000000000..889f8316e5 --- /dev/null +++ b/tests/backend/stable_diffusion/test_extension_manager.py @@ -0,0 +1,112 @@ +from unittest import mock + +import pytest + +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback +from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager + + +class MockExtension(ExtensionBase): + """A mock ExtensionBase subclass for testing purposes.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + # Note that order is not specified. It should default to 0. + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +class MockExtensionLate(ExtensionBase): + """A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +def test_extension_manager_run_callback(): + """Test that run_callback runs all callbacks for the given callback type.""" + + em = ExtensionsManager() + mock_extension_1 = MockExtension(1) + em.add_extension(mock_extension_1) + + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + assert mock_ctx.step_index == 1 + + +def test_extension_manager_run_callback_no_callbacks(): + """Test that run_callback does not raise an error when there are no callbacks for the given callback type.""" + em = ExtensionsManager() + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + +@pytest.mark.parametrize( + ["extension_1", "extension_2"], + # Regardless of initialization order, we expect MockExtensionLate to run last. 
+ [(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))], +) +def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase): + """Test that run_callback runs callbacks in the correct order.""" + em = ExtensionsManager() + em.add_extension(extension_1) + em.add_extension(extension_2) + + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + assert mock_ctx.step_index == 2 + + +class MockExtensionStableSort(ExtensionBase): + """A mock extension with three PRE_DENOISE_LOOP callbacks, each with a different order value.""" + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000) + def early(self, ctx: DenoiseContext): + pass + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def middle(self, ctx: DenoiseContext): + pass + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) + def late(self, ctx: DenoiseContext): + pass + + +def test_extension_manager_stable_sort(): + """Test that when two callbacks have the same 'order' value, they are sorted based on the order they were added to + the ExtensionsManager.""" + + em = ExtensionsManager() + + mock_extension_1 = MockExtensionStableSort() + mock_extension_2 = MockExtensionStableSort() + + em.add_extension(mock_extension_1) + em.add_extension(mock_extension_2) + + expected_order = [ + mock_extension_1.early, + mock_extension_2.early, + mock_extension_1.middle, + mock_extension_2.middle, + mock_extension_1.late, + mock_extension_2.late, + ] + + # It's not ideal that we are accessing a private attribute here, but this was the most direct way to assert the + # desired behaviour. + assert [cb.function for cb in em._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]] == expected_order From 39e10d894c44b73caea8c0e36776ff7b1f34b620 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 19 Jul 2024 23:17:01 +0300 Subject: [PATCH 25/25] Add invocation cancellation logic to patchers --- invokeai/backend/stable_diffusion/extensions_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index f42a065e82..1cae2e4219 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -44,8 +44,6 @@ class ExtensionsManager: self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order) def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext): - # TODO: add to patchers too? - # and if so, should it be only in beginning of function or in for loop if self._is_canceled and self._is_canceled(): raise CanceledException @@ -55,6 +53,9 @@ class ExtensionsManager: @contextmanager def patch_extensions(self, context: DenoiseContext): + if self._is_canceled and self._is_canceled(): + raise CanceledException + with ExitStack() as exit_stack: for ext in self._extensions: exit_stack.enter_context(ext.patch_extension(context)) @@ -63,5 +64,8 @@ class ExtensionsManager: @contextmanager def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + if self._is_canceled and self._is_canceled(): + raise CanceledException + # TODO: create logic in PR with extension which uses it yield None
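Note for reviewers: the unit tests in PATCH 24 are the best reference for the new extension API. As a quick orientation, below is a minimal, hypothetical sketch of an extension built on top of this series. It assumes the import paths introduced in these patches (extension_callback_type.py, extensions/base.py, extensions_manager.py); NoisePredLoggerExt is not part of the PR and exists purely for illustration.

# Illustrative sketch only; NoisePredLoggerExt is hypothetical and not included in this PR.
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class NoisePredLoggerExt(ExtensionBase):
    """Hypothetical extension that records the norm of the combined noise prediction after CFG."""

    def __init__(self):
        # super().__init__() registers all @callback-decorated methods on this instance.
        super().__init__()
        self.noise_pred_norms = []

    # order=1000 so this runs after other POST_APPLY_CFG callbacks (sorted() is stable).
    @callback(ExtensionCallbackType.POST_APPLY_CFG, order=1000)
    def log_noise_pred(self, ctx: DenoiseContext):
        # ctx.noise_pred is populated by the time POST_APPLY_CFG callbacks run.
        self.noise_pred_norms.append(float(ctx.noise_pred.norm()))


ext_manager = ExtensionsManager(is_canceled=lambda: False)
ext_manager.add_extension(NoisePredLoggerExt())

# Inside StableDiffusionBackend, the manager then invokes the callback as:
# ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, denoise_ctx)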