mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Base of modular backend (#6606)
## Summary Base code of new modular backend from #6577. Contains normal generation and regional prompts support. Also preview extension included to test if extensions logic works. ## Related Issues / Discussions https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d ## QA Instructions Run with and without set `USE_MODULAR_DENOISE` environment. Currently only normal and regional conditionings supported, so just generate some images and compare with main output. ## Merge Plan Discuss a bit more about injection point names? As if for example in future unet will be overridable, current `pre_unet`/`post_unet` assumes to name override as `unet` what feels a bit odd. Also `apply_cfg` - future implementation could ignore/not use cfg, so in this case `combine_noise_predictions`/`combine_noise` seems more suitable. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_
This commit is contained in:
commit
473f4cc1c3
@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import inspect
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
@ -39,6 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
@ -53,6 +55,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
TextConditioningData,
|
||||
TextConditioningRegions,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@ -314,9 +321,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||
unet: UNet2DConditionModel,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
cfg_scale: float | list[float],
|
||||
steps: int,
|
||||
cfg_rescale_multiplier: float,
|
||||
@ -330,10 +338,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
uncond_list = [uncond_list]
|
||||
|
||||
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||
cond_list, context, unet.device, unet.dtype
|
||||
cond_list, context, device, dtype
|
||||
)
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||
uncond_list, context, unet.device, unet.dtype
|
||||
uncond_list, context, device, dtype
|
||||
)
|
||||
|
||||
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||
@ -341,14 +349,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
masks=cond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
dtype=dtype,
|
||||
)
|
||||
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||
text_conditionings=uncond_text_embeddings,
|
||||
masks=uncond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if isinstance(cfg_scale, list):
|
||||
@ -707,9 +715,108 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return seed, noise, latents
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
if os.environ.get("USE_MODULAR_DENOISE", False):
|
||||
return self._new_invoke(context)
|
||||
else:
|
||||
return self._old_invoke(context)
|
||||
|
||||
@torch.no_grad()
|
||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
|
||||
|
||||
device = TorchDevice.choose_torch_device()
|
||||
dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=dtype)
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
# TODO: old backend, remove
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
scheduler,
|
||||
seed=seed,
|
||||
device=device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
denoise_ctx = DenoiseContext(
|
||||
inputs=DenoiseInputs(
|
||||
orig_latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
attention_processor_cls=CustomAttnProcessor2_0,
|
||||
),
|
||||
unet=None,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
### preview
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
|
||||
ext_manager.add_extension(PreviewExt(step_callback))
|
||||
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
unet_info.model_on_device() as (model_state_dict, unet),
|
||||
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
||||
# ext: controlnet
|
||||
ext_manager.patch_extensions(unet),
|
||||
# ext: freeu, seamless, ip adapter, lora
|
||||
ext_manager.patch_unet(model_state_dict, unet),
|
||||
):
|
||||
sd_backend = StableDiffusionBackend(unet, scheduler)
|
||||
denoise_ctx.unet = unet
|
||||
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
result_latents = result_latents.detach().to("cpu")
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||
|
||||
@torch.no_grad()
|
||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
@ -788,7 +895,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
device=unet.device,
|
||||
dtype=unet.dtype,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
"""
|
||||
|
||||
|
||||
# TODO: rename smth like ModelPatcher and add TI method?
|
||||
class ModelPatcher:
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
|
||||
"""A context manager that patches `unet` with the provided attention processor.
|
||||
|
||||
Args:
|
||||
unet (UNet2DConditionModel): The UNet model to patch.
|
||||
processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
|
||||
"""
|
||||
unet_orig_processors = unet.attn_processors
|
||||
|
||||
# create separate instance for each attention, to be able modify each attention separately
|
||||
unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
|
||||
try:
|
||||
unet.set_attn_processor(unet_new_processors)
|
||||
yield None
|
||||
|
||||
finally:
|
||||
unet.set_attn_processor(unet_orig_processors)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
131
invokeai/backend/stable_diffusion/denoise_context.py
Normal file
131
invokeai/backend/stable_diffusion/denoise_context.py
Normal file
@ -0,0 +1,131 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNetKwargs:
|
||||
sample: torch.Tensor
|
||||
timestep: Union[torch.Tensor, float, int]
|
||||
encoder_hidden_states: torch.Tensor
|
||||
|
||||
class_labels: Optional[torch.Tensor] = None
|
||||
timestep_cond: Optional[torch.Tensor] = None
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None
|
||||
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None
|
||||
# return_dict: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DenoiseInputs:
|
||||
"""Initial variables passed to denoise. Supposed to be unchanged."""
|
||||
|
||||
# The latent-space image to denoise.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
# - If we are inpainting, this is the initial latent image before noise has been added.
|
||||
# - If we are generating a new image, this should be initialized to zeros.
|
||||
# - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
|
||||
orig_latents: torch.Tensor
|
||||
|
||||
# kwargs forwarded to the scheduler.step() method.
|
||||
scheduler_step_kwargs: dict[str, Any]
|
||||
|
||||
# Text conditionging data.
|
||||
conditioning_data: TextConditioningData
|
||||
|
||||
# Noise used for two purposes:
|
||||
# 1. Used by the scheduler to noise the initial `latents` before denoising.
|
||||
# 2. Used to noise the `masked_latents` when inpainting.
|
||||
# `noise` should be None if the `latents` tensor has already been noised.
|
||||
# Shape: [1 or batch, channels, latent_height, latent_width]
|
||||
noise: Optional[torch.Tensor]
|
||||
|
||||
# The seed used to generate the noise for the denoising process.
|
||||
# HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
|
||||
# same noise used earlier in the pipeline. This should really be handled in a clearer way.
|
||||
seed: int
|
||||
|
||||
# The timestep schedule for the denoising process.
|
||||
timesteps: torch.Tensor
|
||||
|
||||
# The first timestep in the schedule. This is used to determine the initial noise level, so
|
||||
# should be populated if you want noise applied *even* if timesteps is empty.
|
||||
init_timestep: torch.Tensor
|
||||
|
||||
# Class of attention processor that is used.
|
||||
attention_processor_cls: Type[Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DenoiseContext:
|
||||
"""Context with all variables in denoise"""
|
||||
|
||||
# Initial variables passed to denoise. Supposed to be unchanged.
|
||||
inputs: DenoiseInputs
|
||||
|
||||
# Scheduler which used to apply noise predictions.
|
||||
scheduler: SchedulerMixin
|
||||
|
||||
# UNet model.
|
||||
unet: Optional[UNet2DConditionModel] = None
|
||||
|
||||
# Current state of latent-space image in denoising process.
|
||||
# None until `pre_denoise_loop` callback.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
latents: Optional[torch.Tensor] = None
|
||||
|
||||
# Current denoising step index.
|
||||
# None until `pre_step` callback.
|
||||
step_index: Optional[int] = None
|
||||
|
||||
# Current denoising step timestep.
|
||||
# None until `pre_step` callback.
|
||||
timestep: Optional[torch.Tensor] = None
|
||||
|
||||
# Arguments which will be passed to UNet model.
|
||||
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
||||
unet_kwargs: Optional[UNetKwargs] = None
|
||||
|
||||
# SchedulerOutput class returned from step function(normally, generated by scheduler).
|
||||
# Supposed to be used only in `post_step` callback, otherwise can be None.
|
||||
step_output: Optional[SchedulerOutput] = None
|
||||
|
||||
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
|
||||
# Available in events inside step(between `pre_step` and `post_stop`).
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
latent_model_input: Optional[torch.Tensor] = None
|
||||
|
||||
# [TMP] Defines on which conditionings current unet call will be runned.
|
||||
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
||||
conditioning_mode: Optional[ConditioningMode] = None
|
||||
|
||||
# [TMP] Noise predictions from negative conditioning.
|
||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
negative_noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
# [TMP] Noise predictions from positive conditioning.
|
||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
positive_noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
# Combined noise prediction from passed conditionings.
|
||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
# Dictionary for extensions to pass extra info about denoise process to other extensions.
|
||||
extra: dict = field(default_factory=dict)
|
@ -23,21 +23,12 @@ from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
|
||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
step: int
|
||||
order: int
|
||||
total_steps: int
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskGuidance:
|
||||
mask: torch.Tensor
|
||||
|
@ -1,10 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -95,6 +102,12 @@ class TextConditioningRegions:
|
||||
assert self.masks.shape[1] == len(self.ranges)
|
||||
|
||||
|
||||
class ConditioningMode(Enum):
|
||||
Both = "both"
|
||||
Negative = "negative"
|
||||
Positive = "positive"
|
||||
|
||||
|
||||
class TextConditioningData:
|
||||
def __init__(
|
||||
self,
|
||||
@ -103,7 +116,7 @@ class TextConditioningData:
|
||||
uncond_regions: Optional[TextConditioningRegions],
|
||||
cond_regions: Optional[TextConditioningRegions],
|
||||
guidance_scale: Union[float, List[float]],
|
||||
guidance_rescale_multiplier: float = 0,
|
||||
guidance_rescale_multiplier: float = 0, # TODO: old backend, remove
|
||||
):
|
||||
self.uncond_text = uncond_text
|
||||
self.cond_text = cond_text
|
||||
@ -114,6 +127,7 @@ class TextConditioningData:
|
||||
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
self.guidance_scale = guidance_scale
|
||||
# TODO: old backend, remove
|
||||
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
|
||||
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
self.guidance_rescale_multiplier = guidance_rescale_multiplier
|
||||
@ -121,3 +135,114 @@ class TextConditioningData:
|
||||
def is_sdxl(self):
|
||||
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
return isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
|
||||
def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
|
||||
"""Fills unet arguments with data from provided conditionings.
|
||||
|
||||
Args:
|
||||
unet_kwargs (UNetKwargs): Object which stores UNet model arguments.
|
||||
conditioning_mode (ConditioningMode): Describes which conditionings should be used.
|
||||
"""
|
||||
_, _, h, w = unet_kwargs.sample.shape
|
||||
device = unet_kwargs.sample.device
|
||||
dtype = unet_kwargs.sample.dtype
|
||||
|
||||
# TODO: combine regions with conditionings
|
||||
if conditioning_mode == ConditioningMode.Both:
|
||||
conditionings = [self.uncond_text, self.cond_text]
|
||||
c_regions = [self.uncond_regions, self.cond_regions]
|
||||
elif conditioning_mode == ConditioningMode.Positive:
|
||||
conditionings = [self.cond_text]
|
||||
c_regions = [self.cond_regions]
|
||||
elif conditioning_mode == ConditioningMode.Negative:
|
||||
conditionings = [self.uncond_text]
|
||||
c_regions = [self.uncond_regions]
|
||||
else:
|
||||
raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")
|
||||
|
||||
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||
[c.embeds for c in conditionings]
|
||||
)
|
||||
|
||||
unet_kwargs.encoder_hidden_states = encoder_hidden_states
|
||||
unet_kwargs.encoder_attention_mask = encoder_attention_mask
|
||||
|
||||
if self.is_sdxl():
|
||||
added_cond_kwargs = dict( # noqa: C408
|
||||
text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
|
||||
time_ids=torch.cat([c.add_time_ids for c in conditionings]),
|
||||
)
|
||||
|
||||
unet_kwargs.added_cond_kwargs = added_cond_kwargs
|
||||
|
||||
if any(r is not None for r in c_regions):
|
||||
tmp_regions = []
|
||||
for c, r in zip(conditionings, c_regions, strict=True):
|
||||
if r is None:
|
||||
r = TextConditioningRegions(
|
||||
masks=torch.ones((1, 1, h, w), dtype=dtype),
|
||||
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
||||
)
|
||||
tmp_regions.append(r)
|
||||
|
||||
if unet_kwargs.cross_attention_kwargs is None:
|
||||
unet_kwargs.cross_attention_kwargs = {}
|
||||
|
||||
unet_kwargs.cross_attention_kwargs.update(
|
||||
regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
|
||||
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
|
||||
|
||||
@classmethod
|
||||
def _pad_conditioning(
|
||||
cls,
|
||||
cond: torch.Tensor,
|
||||
target_len: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Pad provided conditioning tensor to target_len by zeros and returns mask of unpadded bytes.
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor): Conditioning tensor which to pads by zeros.
|
||||
target_len (int): To which length(tokens count) pad tensor.
|
||||
"""
|
||||
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||
|
||||
if cond.shape[1] < target_len:
|
||||
conditioning_attention_mask = cls._pad_zeros(
|
||||
conditioning_attention_mask,
|
||||
pad_shape=(cond.shape[0], target_len - cond.shape[1]),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
cond = cls._pad_zeros(
|
||||
cond,
|
||||
pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
return cond, conditioning_attention_mask
|
||||
|
||||
@classmethod
|
||||
def _concat_conditionings_for_batch(
|
||||
cls,
|
||||
conditionings: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Concatenate provided conditioning tensors to one batched tensor.
|
||||
If tensors have different sizes then pad them by zeros and creates
|
||||
encoder_attention_mask to exclude padding from attention.
|
||||
|
||||
Args:
|
||||
conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
|
||||
"""
|
||||
encoder_attention_mask = None
|
||||
max_len = max([c.shape[1] for c in conditionings])
|
||||
if any(c.shape[1] != max_len for c in conditionings):
|
||||
encoder_attention_masks = [None] * len(conditionings)
|
||||
for i in range(len(conditionings)):
|
||||
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
|
||||
encoder_attention_mask = torch.cat(encoder_attention_masks)
|
||||
|
||||
return torch.cat(conditionings), encoder_attention_mask
|
||||
|
@ -1,6 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
TextConditioningRegions,
|
||||
)
|
||||
|
140
invokeai/backend/stable_diffusion/diffusion_backend.py
Normal file
140
invokeai/backend/stable_diffusion/diffusion_backend.py
Normal file
@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
|
||||
|
||||
class StableDiffusionBackend:
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
):
|
||||
self.unet = unet
|
||||
self.scheduler = scheduler
|
||||
config = get_config()
|
||||
self._sequential_guidance = config.sequential_guidance
|
||||
|
||||
def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
if ctx.inputs.init_timestep.shape[0] == 0:
|
||||
return ctx.inputs.orig_latents
|
||||
|
||||
ctx.latents = ctx.inputs.orig_latents.clone()
|
||||
|
||||
if ctx.inputs.noise is not None:
|
||||
batch_size = ctx.latents.shape[0]
|
||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
ctx.latents = ctx.scheduler.add_noise(
|
||||
ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
|
||||
)
|
||||
|
||||
# if no work to do, return latents
|
||||
if ctx.inputs.timesteps.shape[0] == 0:
|
||||
return ctx.latents
|
||||
|
||||
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
|
||||
# ext: preview[pre_denoise_loop, priority=low]
|
||||
ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)
|
||||
|
||||
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
|
||||
# ext: inpaint (apply mask to latents on non-inpaint models)
|
||||
ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)
|
||||
|
||||
# ext: tiles? [override: step]
|
||||
ctx.step_output = self.step(ctx, ext_manager)
|
||||
|
||||
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
|
||||
# ext: preview[post_step, priority=low]
|
||||
ext_manager.run_callback(ExtensionCallbackType.POST_STEP, ctx)
|
||||
|
||||
ctx.latents = ctx.step_output.prev_sample
|
||||
|
||||
# ext: inpaint[post_denoise_loop] (restore unmasked part)
|
||||
ext_manager.run_callback(ExtensionCallbackType.POST_DENOISE_LOOP, ctx)
|
||||
return ctx.latents
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
|
||||
ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)
|
||||
|
||||
# TODO: conditionings as list(conditioning_data.to_unet_kwargs - ready)
|
||||
# Note: The current handling of conditioning doesn't feel very future-proof.
|
||||
# This might change in the future as new requirements come up, but for now,
|
||||
# this is the rough plan.
|
||||
if self._sequential_guidance:
|
||||
ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative)
|
||||
ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive)
|
||||
else:
|
||||
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
|
||||
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
||||
|
||||
# ext: override apply_cfg
|
||||
ctx.noise_pred = self.apply_cfg(ctx)
|
||||
|
||||
# ext: cfg_rescale [modify_noise_prediction]
|
||||
# TODO: rename
|
||||
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
||||
|
||||
# clean up locals
|
||||
ctx.latent_model_input = None
|
||||
ctx.negative_noise_pred = None
|
||||
ctx.positive_noise_pred = None
|
||||
ctx.noise_pred = None
|
||||
|
||||
return step_output
|
||||
|
||||
@staticmethod
|
||||
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
|
||||
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance_scale = guidance_scale[ctx.step_index]
|
||||
|
||||
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
||||
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
||||
|
||||
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
|
||||
sample = ctx.latent_model_input
|
||||
if conditioning_mode == ConditioningMode.Both:
|
||||
sample = torch.cat([sample] * 2)
|
||||
|
||||
ctx.unet_kwargs = UNetKwargs(
|
||||
sample=sample,
|
||||
timestep=ctx.timestep,
|
||||
encoder_hidden_states=None, # set later by conditoning
|
||||
cross_attention_kwargs=dict( # noqa: C408
|
||||
percent_through=ctx.step_index / len(ctx.inputs.timesteps),
|
||||
),
|
||||
)
|
||||
|
||||
ctx.conditioning_mode = conditioning_mode
|
||||
ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
|
||||
|
||||
# ext: controlnet/ip/t2i [pre_unet]
|
||||
ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)
|
||||
|
||||
# ext: inpaint [pre_unet, priority=low]
|
||||
# or
|
||||
# ext: inpaint [override: unet_forward]
|
||||
noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
|
||||
|
||||
ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)
|
||||
|
||||
# clean up locals
|
||||
ctx.unet_kwargs = None
|
||||
ctx.conditioning_mode = None
|
||||
|
||||
return noise_pred
|
||||
|
||||
def _unet_forward(self, **kwargs) -> torch.Tensor:
|
||||
return self.unet(**kwargs).sample
|
12
invokeai/backend/stable_diffusion/extension_callback_type.py
Normal file
12
invokeai/backend/stable_diffusion/extension_callback_type.py
Normal file
@ -0,0 +1,12 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ExtensionCallbackType(Enum):
|
||||
SETUP = "setup"
|
||||
PRE_DENOISE_LOOP = "pre_denoise_loop"
|
||||
POST_DENOISE_LOOP = "post_denoise_loop"
|
||||
PRE_STEP = "pre_step"
|
||||
POST_STEP = "post_step"
|
||||
PRE_UNET = "pre_unet"
|
||||
POST_UNET = "post_unet"
|
||||
POST_APPLY_CFG = "post_apply_cfg"
|
60
invokeai/backend/stable_diffusion/extensions/base.py
Normal file
60
invokeai/backend/stable_diffusion/extensions/base.py
Normal file
@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallbackMetadata:
|
||||
callback_type: ExtensionCallbackType
|
||||
order: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallbackFunctionWithMetadata:
|
||||
metadata: CallbackMetadata
|
||||
function: Callable[[DenoiseContext], None]
|
||||
|
||||
|
||||
def callback(callback_type: ExtensionCallbackType, order: int = 0):
|
||||
def _decorator(function):
|
||||
function._ext_metadata = CallbackMetadata(
|
||||
callback_type=callback_type,
|
||||
order=order,
|
||||
)
|
||||
return function
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
class ExtensionBase:
|
||||
def __init__(self):
|
||||
self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
|
||||
|
||||
# Register all of the callback methods for this instance.
|
||||
for func_name in dir(self):
|
||||
func = getattr(self, func_name)
|
||||
metadata = getattr(func, "_ext_metadata", None)
|
||||
if metadata is not None and isinstance(metadata, CallbackMetadata):
|
||||
if metadata.callback_type not in self._callbacks:
|
||||
self._callbacks[metadata.callback_type] = []
|
||||
self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
|
||||
|
||||
def get_callbacks(self):
|
||||
return self._callbacks
|
||||
|
||||
@contextmanager
|
||||
def patch_extension(self, context: DenoiseContext):
|
||||
yield None
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||
yield None
|
63
invokeai/backend/stable_diffusion/extensions/preview.py
Normal file
63
invokeai/backend/stable_diffusion/extensions/preview.py
Normal file
@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
|
||||
|
||||
# TODO: change event to accept image instead of latents
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
step: int
|
||||
order: int
|
||||
total_steps: int
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class PreviewExt(ExtensionBase):
|
||||
def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
|
||||
super().__init__()
|
||||
self.callback = callback
|
||||
|
||||
# do last so that all other changes shown
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
|
||||
def initial_preview(self, ctx: DenoiseContext):
|
||||
self.callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=ctx.scheduler.order,
|
||||
total_steps=len(ctx.inputs.timesteps),
|
||||
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
|
||||
latents=ctx.latents,
|
||||
)
|
||||
)
|
||||
|
||||
# do last so that all other changes shown
|
||||
@callback(ExtensionCallbackType.POST_STEP, order=1000)
|
||||
def step_preview(self, ctx: DenoiseContext):
|
||||
if hasattr(ctx.step_output, "denoised"):
|
||||
predicted_original = ctx.step_output.denoised
|
||||
elif hasattr(ctx.step_output, "pred_original_sample"):
|
||||
predicted_original = ctx.step_output.pred_original_sample
|
||||
else:
|
||||
predicted_original = ctx.step_output.prev_sample
|
||||
|
||||
self.callback(
|
||||
PipelineIntermediateState(
|
||||
step=ctx.step_index,
|
||||
order=ctx.scheduler.order,
|
||||
total_steps=len(ctx.inputs.timesteps),
|
||||
timestep=int(ctx.timestep), # TODO: is there any code which uses it?
|
||||
latents=ctx.step_output.prev_sample,
|
||||
predicted_original=predicted_original, # TODO: is there any reason for additional field?
|
||||
)
|
||||
)
|
71
invokeai/backend/stable_diffusion/extensions_manager.py
Normal file
71
invokeai/backend/stable_diffusion/extensions_manager.py
Normal file
@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
|
||||
|
||||
|
||||
class ExtensionsManager:
|
||||
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
|
||||
self._is_canceled = is_canceled
|
||||
|
||||
# A list of extensions in the order that they were added to the ExtensionsManager.
|
||||
self._extensions: List[ExtensionBase] = []
|
||||
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
|
||||
|
||||
def add_extension(self, extension: ExtensionBase):
|
||||
self._extensions.append(extension)
|
||||
self._regenerate_ordered_callbacks()
|
||||
|
||||
def _regenerate_ordered_callbacks(self):
|
||||
"""Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
|
||||
self._ordered_callbacks = {}
|
||||
|
||||
# Fill the ordered callbacks dictionary.
|
||||
for extension in self._extensions:
|
||||
for callback_type, callbacks in extension.get_callbacks().items():
|
||||
if callback_type not in self._ordered_callbacks:
|
||||
self._ordered_callbacks[callback_type] = []
|
||||
self._ordered_callbacks[callback_type].extend(callbacks)
|
||||
|
||||
# Sort each callback list.
|
||||
for callback_type, callbacks in self._ordered_callbacks.items():
|
||||
# Note that sorted() is stable, so if two callbacks have the same order, the order that they extensions were
|
||||
# added will be preserved.
|
||||
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
|
||||
|
||||
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
|
||||
if self._is_canceled and self._is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
callbacks = self._ordered_callbacks.get(callback_type, [])
|
||||
for cb in callbacks:
|
||||
cb.function(ctx)
|
||||
|
||||
@contextmanager
|
||||
def patch_extensions(self, context: DenoiseContext):
|
||||
if self._is_canceled and self._is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
for ext in self._extensions:
|
||||
exit_stack.enter_context(ext.patch_extension(context))
|
||||
|
||||
yield None
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||
if self._is_canceled and self._is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# TODO: create logic in PR with extension which uses it
|
||||
yield None
|
46
tests/backend/stable_diffusion/extensions/test_base.py
Normal file
46
tests/backend/stable_diffusion/extensions/test_base.py
Normal file
@ -0,0 +1,46 @@
|
||||
from unittest import mock
|
||||
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
|
||||
class MockExtension(ExtensionBase):
|
||||
"""A mock ExtensionBase subclass for testing purposes."""
|
||||
|
||||
def __init__(self, x: int):
|
||||
super().__init__()
|
||||
self._x = x
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def set_step_index(self, ctx: DenoiseContext):
|
||||
ctx.step_index = self._x
|
||||
|
||||
|
||||
def test_extension_base_callback_registration():
|
||||
"""Test that a callback can be successfully registered with an extension."""
|
||||
val = 5
|
||||
mock_extension = MockExtension(val)
|
||||
|
||||
mock_ctx = mock.MagicMock()
|
||||
|
||||
callbacks = mock_extension.get_callbacks()
|
||||
pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
|
||||
assert len(pre_denoise_loop_cbs) == 1
|
||||
|
||||
# Call the mock callback.
|
||||
pre_denoise_loop_cbs[0].function(mock_ctx)
|
||||
|
||||
# Confirm that the callback ran.
|
||||
assert mock_ctx.step_index == val
|
||||
|
||||
|
||||
def test_extension_base_empty_callback_type():
|
||||
"""Test that an empty list is returned when no callbacks are registered for a given callback type."""
|
||||
mock_extension = MockExtension(5)
|
||||
|
||||
# There should be no callbacks registered for POST_DENOISE_LOOP.
|
||||
callbacks = mock_extension.get_callbacks()
|
||||
|
||||
post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
|
||||
assert len(post_denoise_loop_cbs) == 0
|
112
tests/backend/stable_diffusion/test_extension_manager.py
Normal file
112
tests/backend/stable_diffusion/test_extension_manager.py
Normal file
@ -0,0 +1,112 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
|
||||
|
||||
class MockExtension(ExtensionBase):
|
||||
"""A mock ExtensionBase subclass for testing purposes."""
|
||||
|
||||
def __init__(self, x: int):
|
||||
super().__init__()
|
||||
self._x = x
|
||||
|
||||
# Note that order is not specified. It should default to 0.
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def set_step_index(self, ctx: DenoiseContext):
|
||||
ctx.step_index = self._x
|
||||
|
||||
|
||||
class MockExtensionLate(ExtensionBase):
|
||||
"""A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback."""
|
||||
|
||||
def __init__(self, x: int):
|
||||
super().__init__()
|
||||
self._x = x
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
|
||||
def set_step_index(self, ctx: DenoiseContext):
|
||||
ctx.step_index = self._x
|
||||
|
||||
|
||||
def test_extension_manager_run_callback():
|
||||
"""Test that run_callback runs all callbacks for the given callback type."""
|
||||
|
||||
em = ExtensionsManager()
|
||||
mock_extension_1 = MockExtension(1)
|
||||
em.add_extension(mock_extension_1)
|
||||
|
||||
mock_ctx = mock.MagicMock()
|
||||
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||
|
||||
assert mock_ctx.step_index == 1
|
||||
|
||||
|
||||
def test_extension_manager_run_callback_no_callbacks():
|
||||
"""Test that run_callback does not raise an error when there are no callbacks for the given callback type."""
|
||||
em = ExtensionsManager()
|
||||
mock_ctx = mock.MagicMock()
|
||||
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["extension_1", "extension_2"],
|
||||
# Regardless of initialization order, we expect MockExtensionLate to run last.
|
||||
[(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))],
|
||||
)
|
||||
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
|
||||
"""Test that run_callback runs callbacks in the correct order."""
|
||||
em = ExtensionsManager()
|
||||
em.add_extension(extension_1)
|
||||
em.add_extension(extension_2)
|
||||
|
||||
mock_ctx = mock.MagicMock()
|
||||
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||
|
||||
assert mock_ctx.step_index == 2
|
||||
|
||||
|
||||
class MockExtensionStableSort(ExtensionBase):
|
||||
"""A mock extension with three PRE_DENOISE_LOOP callbacks, each with a different order value."""
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000)
|
||||
def early(self, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def middle(self, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
|
||||
def late(self, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
|
||||
def test_extension_manager_stable_sort():
|
||||
"""Test that when two callbacks have the same 'order' value, they are sorted based on the order they were added to
|
||||
the ExtensionsManager."""
|
||||
|
||||
em = ExtensionsManager()
|
||||
|
||||
mock_extension_1 = MockExtensionStableSort()
|
||||
mock_extension_2 = MockExtensionStableSort()
|
||||
|
||||
em.add_extension(mock_extension_1)
|
||||
em.add_extension(mock_extension_2)
|
||||
|
||||
expected_order = [
|
||||
mock_extension_1.early,
|
||||
mock_extension_2.early,
|
||||
mock_extension_1.middle,
|
||||
mock_extension_2.middle,
|
||||
mock_extension_1.late,
|
||||
mock_extension_2.late,
|
||||
]
|
||||
|
||||
# It's not ideal that we are accessing a private attribute here, but this was the most direct way to assert the
|
||||
# desired behaviour.
|
||||
assert [cb.function for cb in em._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]] == expected_order
|
Loading…
Reference in New Issue
Block a user