Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Base of modular backend (#6606)
## Summary

Base code of the new modular backend from #6577. Contains normal generation and regional prompt support. A preview extension is also included to test that the extension logic works.

## Related Issues / Discussions

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without the `USE_MODULAR_DENOISE` environment variable set. Currently only normal and regional conditionings are supported, so just generate some images and compare them with the output of `main`.

## Merge Plan

Discuss the injection point names a bit more? For example, if the unet becomes overridable in the future, the current `pre_unet`/`post_unet` naming implies the override would be named `unet`, which feels a bit odd. Also, for `apply_cfg`, a future implementation could ignore or not use CFG, in which case `combine_noise_predictions`/`combine_noise` seems more suitable.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
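For illustration only (not part of this diff), here is a minimal sketch of how an extension hooks into the new callback mechanism, using the `ExtensionBase` class, `callback` decorator, `ExtensionCallbackType` enum, and `DenoiseContext` introduced in this PR. The extension name and its print-based behaviour are hypothetical:

```python
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback


class StepLoggerExt(ExtensionBase):
    """Hypothetical extension: report progress after every denoising step."""

    @callback(ExtensionCallbackType.POST_STEP)
    def log_step(self, ctx: DenoiseContext):
        # ctx is the shared DenoiseContext; step_index and timestep are set by the denoise loop.
        print(f"step {ctx.step_index + 1}/{len(ctx.inputs.timesteps)} (timestep {int(ctx.timestep)})")


# Registration mirrors how PreviewExt is added in DenoiseLatentsInvocation._new_invoke:
# ext_manager.add_extension(StepLoggerExt())
```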
This commit is contained in: commit 473f4cc1c3
@@ -1,5 +1,6 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 import inspect
+import os
 from contextlib import ExitStack
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

@@ -39,6 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
     StableDiffusionGeneratorPipeline,
@@ -53,6 +55,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
+from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
 from invokeai.backend.util.devices import TorchDevice
@@ -314,9 +321,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         context: InvocationContext,
         positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
         negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
-        unet: UNet2DConditionModel,
         latent_height: int,
         latent_width: int,
+        device: torch.device,
+        dtype: torch.dtype,
         cfg_scale: float | list[float],
         steps: int,
         cfg_rescale_multiplier: float,
@@ -330,10 +338,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
             uncond_list = [uncond_list]

         cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-            cond_list, context, unet.device, unet.dtype
+            cond_list, context, device, dtype
         )
         uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-            uncond_list, context, unet.device, unet.dtype
+            uncond_list, context, device, dtype
         )

         cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
@@ -341,14 +349,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
             masks=cond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
-            dtype=unet.dtype,
+            dtype=dtype,
         )
         uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
             text_conditionings=uncond_text_embeddings,
             masks=uncond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
-            dtype=unet.dtype,
+            dtype=dtype,
         )

         if isinstance(cfg_scale, list):
@@ -707,9 +715,108 @@ class DenoiseLatentsInvocation(BaseInvocation):

         return seed, noise, latents

+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        if os.environ.get("USE_MODULAR_DENOISE", False):
+            return self._new_invoke(context)
+        else:
+            return self._old_invoke(context)
+
     @torch.no_grad()
     @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
+    def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
+        ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
+
+        device = TorchDevice.choose_torch_device()
+        dtype = TorchDevice.choose_torch_dtype()
+
+        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
+        latents = latents.to(device=device, dtype=dtype)
+        if noise is not None:
+            noise = noise.to(device=device, dtype=dtype)
+
+        _, _, latent_height, latent_width = latents.shape
+
+        conditioning_data = self.get_conditioning_data(
+            context=context,
+            positive_conditioning_field=self.positive_conditioning,
+            negative_conditioning_field=self.negative_conditioning,
+            cfg_scale=self.cfg_scale,
+            steps=self.steps,
+            latent_height=latent_height,
+            latent_width=latent_width,
+            device=device,
+            dtype=dtype,
+            # TODO: old backend, remove
+            cfg_rescale_multiplier=self.cfg_rescale_multiplier,
+        )
+
+        scheduler = get_scheduler(
+            context=context,
+            scheduler_info=self.unet.scheduler,
+            scheduler_name=self.scheduler,
+            seed=seed,
+        )
+
+        timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+            scheduler,
+            seed=seed,
+            device=device,
+            steps=self.steps,
+            denoising_start=self.denoising_start,
+            denoising_end=self.denoising_end,
+        )
+
+        denoise_ctx = DenoiseContext(
+            inputs=DenoiseInputs(
+                orig_latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                attention_processor_cls=CustomAttnProcessor2_0,
+            ),
+            unet=None,
+            scheduler=scheduler,
+        )
+
+        # get the unet's config so that we can pass the base to sd_step_callback()
+        unet_config = context.models.get_config(self.unet.unet.key)
+
+        ### preview
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
+
+        ext_manager.add_extension(PreviewExt(step_callback))
+
+        # ext: t2i/ip adapter
+        ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
+
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+            # ext: controlnet
+            ext_manager.patch_extensions(unet),
+            # ext: freeu, seamless, ip adapter, lora
+            ext_manager.patch_unet(model_state_dict, unet),
+        ):
+            sd_backend = StableDiffusionBackend(unet, scheduler)
+            denoise_ctx.unet = unet
+            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.detach().to("cpu")
+        TorchDevice.empty_cache()
+
+        name = context.tensors.save(tensor=result_latents)
+        return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
+
+    @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
+    def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

         mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -788,7 +895,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
             context=context,
             positive_conditioning_field=self.positive_conditioning,
             negative_conditioning_field=self.negative_conditioning,
-            unet=unet,
+            device=unet.device,
+            dtype=unet.dtype,
             latent_height=latent_height,
             latent_width=latent_width,
             cfg_scale=self.cfg_scale,
@@ -5,7 +5,7 @@ from __future__ import annotations

 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch
@@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
 """


-# TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
+    @staticmethod
+    @contextmanager
+    def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
+        """A context manager that patches `unet` with the provided attention processor.
+
+        Args:
+            unet (UNet2DConditionModel): The UNet model to patch.
+            processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
+        """
+        unet_orig_processors = unet.attn_processors
+
+        # create separate instance for each attention, to be able modify each attention separately
+        unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
+        try:
+            unet.set_attn_processor(unet_new_processors)
+            yield None
+
+        finally:
+            unet.set_attn_processor(unet_orig_processors)
+
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
         assert "." not in lora_key
invokeai/backend/stable_diffusion/denoise_context.py (new file, 131 lines)
@@ -0,0 +1,131 @@

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData


@dataclass
class UNetKwargs:
    sample: torch.Tensor
    timestep: Union[torch.Tensor, float, int]
    encoder_hidden_states: torch.Tensor

    class_labels: Optional[torch.Tensor] = None
    timestep_cond: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None
    cross_attention_kwargs: Optional[Dict[str, Any]] = None
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    mid_block_additional_residual: Optional[torch.Tensor] = None
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    encoder_attention_mask: Optional[torch.Tensor] = None
    # return_dict: bool = True


@dataclass
class DenoiseInputs:
    """Initial variables passed to denoise. Supposed to be unchanged."""

    # The latent-space image to denoise.
    # Shape: [batch, channels, latent_height, latent_width]
    # - If we are inpainting, this is the initial latent image before noise has been added.
    # - If we are generating a new image, this should be initialized to zeros.
    # - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
    orig_latents: torch.Tensor

    # kwargs forwarded to the scheduler.step() method.
    scheduler_step_kwargs: dict[str, Any]

    # Text conditioning data.
    conditioning_data: TextConditioningData

    # Noise used for two purposes:
    # 1. Used by the scheduler to noise the initial `latents` before denoising.
    # 2. Used to noise the `masked_latents` when inpainting.
    # `noise` should be None if the `latents` tensor has already been noised.
    # Shape: [1 or batch, channels, latent_height, latent_width]
    noise: Optional[torch.Tensor]

    # The seed used to generate the noise for the denoising process.
    # HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
    # same noise used earlier in the pipeline. This should really be handled in a clearer way.
    seed: int

    # The timestep schedule for the denoising process.
    timesteps: torch.Tensor

    # The first timestep in the schedule. This is used to determine the initial noise level, so
    # should be populated if you want noise applied *even* if timesteps is empty.
    init_timestep: torch.Tensor

    # Class of attention processor that is used.
    attention_processor_cls: Type[Any]


@dataclass
class DenoiseContext:
    """Context with all variables in denoise"""

    # Initial variables passed to denoise. Supposed to be unchanged.
    inputs: DenoiseInputs

    # Scheduler which used to apply noise predictions.
    scheduler: SchedulerMixin

    # UNet model.
    unet: Optional[UNet2DConditionModel] = None

    # Current state of latent-space image in denoising process.
    # None until `pre_denoise_loop` callback.
    # Shape: [batch, channels, latent_height, latent_width]
    latents: Optional[torch.Tensor] = None

    # Current denoising step index.
    # None until `pre_step` callback.
    step_index: Optional[int] = None

    # Current denoising step timestep.
    # None until `pre_step` callback.
    timestep: Optional[torch.Tensor] = None

    # Arguments which will be passed to UNet model.
    # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
    unet_kwargs: Optional[UNetKwargs] = None

    # SchedulerOutput class returned from step function (normally, generated by scheduler).
    # Supposed to be used only in `post_step` callback, otherwise can be None.
    step_output: Optional[SchedulerOutput] = None

    # Scaled version of `latents`, which will be passed to unet_kwargs initialization.
    # Available in events inside step (between `pre_step` and `post_step`).
    # Shape: [batch, channels, latent_height, latent_width]
    latent_model_input: Optional[torch.Tensor] = None

    # [TMP] Defines on which conditionings the current unet call will be run.
    # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
    conditioning_mode: Optional[ConditioningMode] = None

    # [TMP] Noise predictions from negative conditioning.
    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
    # Shape: [batch, channels, latent_height, latent_width]
    negative_noise_pred: Optional[torch.Tensor] = None

    # [TMP] Noise predictions from positive conditioning.
    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
    # Shape: [batch, channels, latent_height, latent_width]
    positive_noise_pred: Optional[torch.Tensor] = None

    # Combined noise prediction from passed conditionings.
    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
    # Shape: [batch, channels, latent_height, latent_width]
    noise_pred: Optional[torch.Tensor] = None

    # Dictionary for extensions to pass extra info about denoise process to other extensions.
    extra: dict = field(default_factory=dict)
@@ -23,21 +23,12 @@ from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
+from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
 from invokeai.backend.util.attention import auto_detect_slice_size
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.hotfixes import ControlNetModel


-@dataclass
-class PipelineIntermediateState:
-    step: int
-    order: int
-    total_steps: int
-    timestep: int
-    latents: torch.Tensor
-    predicted_original: Optional[torch.Tensor] = None
-
-
 @dataclass
 class AddsMaskGuidance:
     mask: torch.Tensor
@@ -1,10 +1,17 @@
+from __future__ import annotations
+
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from enum import Enum
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch

-from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
+
+if TYPE_CHECKING:
+    from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+    from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs


 @dataclass
@@ -95,6 +102,12 @@ class TextConditioningRegions:
         assert self.masks.shape[1] == len(self.ranges)


+class ConditioningMode(Enum):
+    Both = "both"
+    Negative = "negative"
+    Positive = "positive"
+
+
 class TextConditioningData:
     def __init__(
         self,
@@ -103,7 +116,7 @@ class TextConditioningData:
         uncond_regions: Optional[TextConditioningRegions],
         cond_regions: Optional[TextConditioningRegions],
         guidance_scale: Union[float, List[float]],
-        guidance_rescale_multiplier: float = 0,
+        guidance_rescale_multiplier: float = 0,  # TODO: old backend, remove
     ):
         self.uncond_text = uncond_text
         self.cond_text = cond_text
@@ -114,6 +127,7 @@
         # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
         # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
         self.guidance_scale = guidance_scale
+        # TODO: old backend, remove
         # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
         # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
         self.guidance_rescale_multiplier = guidance_rescale_multiplier
@@ -121,3 +135,114 @@
     def is_sdxl(self):
         assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
         return isinstance(self.cond_text, SDXLConditioningInfo)
+
+    def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
+        """Fills unet arguments with data from provided conditionings.
+
+        Args:
+            unet_kwargs (UNetKwargs): Object which stores UNet model arguments.
+            conditioning_mode (ConditioningMode): Describes which conditionings should be used.
+        """
+        _, _, h, w = unet_kwargs.sample.shape
+        device = unet_kwargs.sample.device
+        dtype = unet_kwargs.sample.dtype
+
+        # TODO: combine regions with conditionings
+        if conditioning_mode == ConditioningMode.Both:
+            conditionings = [self.uncond_text, self.cond_text]
+            c_regions = [self.uncond_regions, self.cond_regions]
+        elif conditioning_mode == ConditioningMode.Positive:
+            conditionings = [self.cond_text]
+            c_regions = [self.cond_regions]
+        elif conditioning_mode == ConditioningMode.Negative:
+            conditionings = [self.uncond_text]
+            c_regions = [self.uncond_regions]
+        else:
+            raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")
+
+        encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
+            [c.embeds for c in conditionings]
+        )
+
+        unet_kwargs.encoder_hidden_states = encoder_hidden_states
+        unet_kwargs.encoder_attention_mask = encoder_attention_mask
+
+        if self.is_sdxl():
+            added_cond_kwargs = dict(  # noqa: C408
+                text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
+                time_ids=torch.cat([c.add_time_ids for c in conditionings]),
+            )
+
+            unet_kwargs.added_cond_kwargs = added_cond_kwargs
+
+        if any(r is not None for r in c_regions):
+            tmp_regions = []
+            for c, r in zip(conditionings, c_regions, strict=True):
+                if r is None:
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=dtype),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                    )
+                tmp_regions.append(r)
+
+            if unet_kwargs.cross_attention_kwargs is None:
+                unet_kwargs.cross_attention_kwargs = {}
+
+            unet_kwargs.cross_attention_kwargs.update(
+                regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
+            )
+
+    @staticmethod
+    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
+        return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
+
+    @classmethod
+    def _pad_conditioning(
+        cls,
+        cond: torch.Tensor,
+        target_len: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Pad provided conditioning tensor to target_len by zeros and returns mask of unpadded bytes.
+
+        Args:
+            cond (torch.Tensor): Conditioning tensor which to pads by zeros.
+            target_len (int): To which length(tokens count) pad tensor.
+        """
+        conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+        if cond.shape[1] < target_len:
+            conditioning_attention_mask = cls._pad_zeros(
+                conditioning_attention_mask,
+                pad_shape=(cond.shape[0], target_len - cond.shape[1]),
+                dim=1,
+            )
+
+            cond = cls._pad_zeros(
+                cond,
+                pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
+                dim=1,
+            )
+
+        return cond, conditioning_attention_mask
+
+    @classmethod
+    def _concat_conditionings_for_batch(
+        cls,
+        conditionings: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Concatenate provided conditioning tensors to one batched tensor.
+        If tensors have different sizes then pad them by zeros and creates
+        encoder_attention_mask to exclude padding from attention.
+
+        Args:
+            conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
+        """
+        encoder_attention_mask = None
+        max_len = max([c.shape[1] for c in conditionings])
+        if any(c.shape[1] != max_len for c in conditionings):
+            encoder_attention_masks = [None] * len(conditionings)
+            for i in range(len(conditionings)):
+                conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
+            encoder_attention_mask = torch.cat(encoder_attention_masks)
+
+        return torch.cat(conditionings), encoder_attention_mask
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import torch
 import torch.nn.functional as F

-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    TextConditioningRegions,
-)
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+        TextConditioningRegions,
+    )


 class RegionalPromptData:
invokeai/backend/stable_diffusion/diffusion_backend.py (new file, 140 lines)
@@ -0,0 +1,140 @@

from __future__ import annotations

import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from tqdm.auto import tqdm

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class StableDiffusionBackend:
    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
    ):
        self.unet = unet
        self.scheduler = scheduler
        config = get_config()
        self._sequential_guidance = config.sequential_guidance

    def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        if ctx.inputs.init_timestep.shape[0] == 0:
            return ctx.inputs.orig_latents

        ctx.latents = ctx.inputs.orig_latents.clone()

        if ctx.inputs.noise is not None:
            batch_size = ctx.latents.shape[0]
            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
            ctx.latents = ctx.scheduler.add_noise(
                ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
            )

        # if no work to do, return latents
        if ctx.inputs.timesteps.shape[0] == 0:
            return ctx.latents

        # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
        # ext: preview[pre_denoise_loop, priority=low]
        ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)

        for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)):  # noqa: B020
            # ext: inpaint (apply mask to latents on non-inpaint models)
            ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)

            # ext: tiles? [override: step]
            ctx.step_output = self.step(ctx, ext_manager)

            # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
            # ext: preview[post_step, priority=low]
            ext_manager.run_callback(ExtensionCallbackType.POST_STEP, ctx)

            ctx.latents = ctx.step_output.prev_sample

        # ext: inpaint[post_denoise_loop] (restore unmasked part)
        ext_manager.run_callback(ExtensionCallbackType.POST_DENOISE_LOOP, ctx)
        return ctx.latents

    @torch.inference_mode()
    def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
        ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)

        # TODO: conditionings as list(conditioning_data.to_unet_kwargs - ready)
        # Note: The current handling of conditioning doesn't feel very future-proof.
        # This might change in the future as new requirements come up, but for now,
        # this is the rough plan.
        if self._sequential_guidance:
            ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative)
            ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive)
        else:
            both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
            ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)

        # ext: override apply_cfg
        ctx.noise_pred = self.apply_cfg(ctx)

        # ext: cfg_rescale [modify_noise_prediction]
        # TODO: rename
        ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)

        # compute the previous noisy sample x_t -> x_t-1
        step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)

        # clean up locals
        ctx.latent_model_input = None
        ctx.negative_noise_pred = None
        ctx.positive_noise_pred = None
        ctx.noise_pred = None

        return step_output

    @staticmethod
    def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
        guidance_scale = ctx.inputs.conditioning_data.guidance_scale
        if isinstance(guidance_scale, list):
            guidance_scale = guidance_scale[ctx.step_index]

        return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
        # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

    def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
        sample = ctx.latent_model_input
        if conditioning_mode == ConditioningMode.Both:
            sample = torch.cat([sample] * 2)

        ctx.unet_kwargs = UNetKwargs(
            sample=sample,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / len(ctx.inputs.timesteps),
            ),
        )

        ctx.conditioning_mode = conditioning_mode
        ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)

        # ext: controlnet/ip/t2i [pre_unet]
        ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)

        # ext: inpaint [pre_unet, priority=low]
        # or
        # ext: inpaint [override: unet_forward]
        noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))

        ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)

        # clean up locals
        ctx.unet_kwargs = None
        ctx.conditioning_mode = None

        return noise_pred

    def _unet_forward(self, **kwargs) -> torch.Tensor:
        return self.unet(**kwargs).sample
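Aside (not part of the commit, variable names below are illustrative): `apply_cfg` above uses `torch.lerp`, which is just the usual classifier-free guidance combination written differently, since `lerp(neg, pos, g)` computes `neg + g * (pos - neg)`. A quick numerical check of the equivalence:

```python
import torch

# Illustrative shapes only: [batch, channels, latent_height, latent_width]
neg = torch.randn(1, 4, 8, 8)  # negative (unconditioned) noise prediction
pos = torch.randn(1, 4, 8, 8)  # positive (conditioned) noise prediction
g = 7.5                        # guidance scale

cfg = neg + g * (pos - neg)          # standard CFG formula
via_lerp = torch.lerp(neg, pos, g)   # what StableDiffusionBackend.apply_cfg computes

assert torch.allclose(cfg, via_lerp)
```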
invokeai/backend/stable_diffusion/extension_callback_type.py (new file, 12 lines)
@@ -0,0 +1,12 @@

from enum import Enum


class ExtensionCallbackType(Enum):
    SETUP = "setup"
    PRE_DENOISE_LOOP = "pre_denoise_loop"
    POST_DENOISE_LOOP = "post_denoise_loop"
    PRE_STEP = "pre_step"
    POST_STEP = "post_step"
    PRE_UNET = "pre_unet"
    POST_UNET = "post_unet"
    POST_APPLY_CFG = "post_apply_cfg"
invokeai/backend/stable_diffusion/extensions/base.py (new file, 60 lines)
@@ -0,0 +1,60 @@

from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List

import torch
from diffusers import UNet2DConditionModel

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType


@dataclass
class CallbackMetadata:
    callback_type: ExtensionCallbackType
    order: int


@dataclass
class CallbackFunctionWithMetadata:
    metadata: CallbackMetadata
    function: Callable[[DenoiseContext], None]


def callback(callback_type: ExtensionCallbackType, order: int = 0):
    def _decorator(function):
        function._ext_metadata = CallbackMetadata(
            callback_type=callback_type,
            order=order,
        )
        return function

    return _decorator


class ExtensionBase:
    def __init__(self):
        self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}

        # Register all of the callback methods for this instance.
        for func_name in dir(self):
            func = getattr(self, func_name)
            metadata = getattr(func, "_ext_metadata", None)
            if metadata is not None and isinstance(metadata, CallbackMetadata):
                if metadata.callback_type not in self._callbacks:
                    self._callbacks[metadata.callback_type] = []
                self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))

    def get_callbacks(self):
        return self._callbacks

    @contextmanager
    def patch_extension(self, context: DenoiseContext):
        yield None

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        yield None
invokeai/backend/stable_diffusion/extensions/preview.py (new file, 63 lines)
@@ -0,0 +1,63 @@

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional

import torch

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


# TODO: change event to accept image instead of latents
@dataclass
class PipelineIntermediateState:
    step: int
    order: int
    total_steps: int
    timestep: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None


class PreviewExt(ExtensionBase):
    def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
        super().__init__()
        self.callback = callback

    # do last so that all other changes shown
    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def initial_preview(self, ctx: DenoiseContext):
        self.callback(
            PipelineIntermediateState(
                step=-1,
                order=ctx.scheduler.order,
                total_steps=len(ctx.inputs.timesteps),
                timestep=int(ctx.scheduler.config.num_train_timesteps),  # TODO: is there any code which uses it?
                latents=ctx.latents,
            )
        )

    # do last so that all other changes shown
    @callback(ExtensionCallbackType.POST_STEP, order=1000)
    def step_preview(self, ctx: DenoiseContext):
        if hasattr(ctx.step_output, "denoised"):
            predicted_original = ctx.step_output.denoised
        elif hasattr(ctx.step_output, "pred_original_sample"):
            predicted_original = ctx.step_output.pred_original_sample
        else:
            predicted_original = ctx.step_output.prev_sample

        self.callback(
            PipelineIntermediateState(
                step=ctx.step_index,
                order=ctx.scheduler.order,
                total_steps=len(ctx.inputs.timesteps),
                timestep=int(ctx.timestep),  # TODO: is there any code which uses it?
                latents=ctx.step_output.prev_sample,
                predicted_original=predicted_original,  # TODO: is there any reason for additional field?
            )
        )
invokeai/backend/stable_diffusion/extensions_manager.py (new file, 71 lines)
@@ -0,0 +1,71 @@

from __future__ import annotations

from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional

import torch
from diffusers import UNet2DConditionModel

from invokeai.app.services.session_processor.session_processor_common import CanceledException

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
    from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase


class ExtensionsManager:
    def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
        self._is_canceled = is_canceled

        # A list of extensions in the order that they were added to the ExtensionsManager.
        self._extensions: List[ExtensionBase] = []
        self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}

    def add_extension(self, extension: ExtensionBase):
        self._extensions.append(extension)
        self._regenerate_ordered_callbacks()

    def _regenerate_ordered_callbacks(self):
        """Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
        self._ordered_callbacks = {}

        # Fill the ordered callbacks dictionary.
        for extension in self._extensions:
            for callback_type, callbacks in extension.get_callbacks().items():
                if callback_type not in self._ordered_callbacks:
                    self._ordered_callbacks[callback_type] = []
                self._ordered_callbacks[callback_type].extend(callbacks)

        # Sort each callback list.
        for callback_type, callbacks in self._ordered_callbacks.items():
            # Note that sorted() is stable, so if two callbacks have the same order, the order in which their
            # extensions were added will be preserved.
            self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)

    def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
        if self._is_canceled and self._is_canceled():
            raise CanceledException

        callbacks = self._ordered_callbacks.get(callback_type, [])
        for cb in callbacks:
            cb.function(ctx)

    @contextmanager
    def patch_extensions(self, context: DenoiseContext):
        if self._is_canceled and self._is_canceled():
            raise CanceledException

        with ExitStack() as exit_stack:
            for ext in self._extensions:
                exit_stack.enter_context(ext.patch_extension(context))

            yield None

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        if self._is_canceled and self._is_canceled():
            raise CanceledException

        # TODO: create logic in PR with extension which uses it
        yield None
tests/backend/stable_diffusion/extensions/test_base.py (new file, 46 lines)
@@ -0,0 +1,46 @@

from unittest import mock

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback


class MockExtension(ExtensionBase):
    """A mock ExtensionBase subclass for testing purposes."""

    def __init__(self, x: int):
        super().__init__()
        self._x = x

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def set_step_index(self, ctx: DenoiseContext):
        ctx.step_index = self._x


def test_extension_base_callback_registration():
    """Test that a callback can be successfully registered with an extension."""
    val = 5
    mock_extension = MockExtension(val)

    mock_ctx = mock.MagicMock()

    callbacks = mock_extension.get_callbacks()
    pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
    assert len(pre_denoise_loop_cbs) == 1

    # Call the mock callback.
    pre_denoise_loop_cbs[0].function(mock_ctx)

    # Confirm that the callback ran.
    assert mock_ctx.step_index == val


def test_extension_base_empty_callback_type():
    """Test that an empty list is returned when no callbacks are registered for a given callback type."""
    mock_extension = MockExtension(5)

    # There should be no callbacks registered for POST_DENOISE_LOOP.
    callbacks = mock_extension.get_callbacks()

    post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
    assert len(post_denoise_loop_cbs) == 0
tests/backend/stable_diffusion/test_extension_manager.py (new file, 112 lines)
@@ -0,0 +1,112 @@

from unittest import mock

import pytest

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class MockExtension(ExtensionBase):
    """A mock ExtensionBase subclass for testing purposes."""

    def __init__(self, x: int):
        super().__init__()
        self._x = x

    # Note that order is not specified. It should default to 0.
    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def set_step_index(self, ctx: DenoiseContext):
        ctx.step_index = self._x


class MockExtensionLate(ExtensionBase):
    """A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback."""

    def __init__(self, x: int):
        super().__init__()
        self._x = x

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def set_step_index(self, ctx: DenoiseContext):
        ctx.step_index = self._x


def test_extension_manager_run_callback():
    """Test that run_callback runs all callbacks for the given callback type."""

    em = ExtensionsManager()
    mock_extension_1 = MockExtension(1)
    em.add_extension(mock_extension_1)

    mock_ctx = mock.MagicMock()
    em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)

    assert mock_ctx.step_index == 1


def test_extension_manager_run_callback_no_callbacks():
    """Test that run_callback does not raise an error when there are no callbacks for the given callback type."""
    em = ExtensionsManager()
    mock_ctx = mock.MagicMock()
    em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)


@pytest.mark.parametrize(
    ["extension_1", "extension_2"],
    # Regardless of initialization order, we expect MockExtensionLate to run last.
    [(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))],
)
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
    """Test that run_callback runs callbacks in the correct order."""
    em = ExtensionsManager()
    em.add_extension(extension_1)
    em.add_extension(extension_2)

    mock_ctx = mock.MagicMock()
    em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)

    assert mock_ctx.step_index == 2


class MockExtensionStableSort(ExtensionBase):
    """A mock extension with three PRE_DENOISE_LOOP callbacks, each with a different order value."""

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000)
    def early(self, ctx: DenoiseContext):
        pass

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def middle(self, ctx: DenoiseContext):
        pass

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def late(self, ctx: DenoiseContext):
        pass


def test_extension_manager_stable_sort():
    """Test that when two callbacks have the same 'order' value, they are sorted based on the order they were added to
    the ExtensionsManager."""

    em = ExtensionsManager()

    mock_extension_1 = MockExtensionStableSort()
    mock_extension_2 = MockExtensionStableSort()

    em.add_extension(mock_extension_1)
    em.add_extension(mock_extension_2)

    expected_order = [
        mock_extension_1.early,
        mock_extension_2.early,
        mock_extension_1.middle,
        mock_extension_2.middle,
        mock_extension_1.late,
        mock_extension_2.late,
    ]

    # It's not ideal that we are accessing a private attribute here, but this was the most direct way to assert the
    # desired behaviour.
    assert [cb.function for cb in em._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]] == expected_order