Base code from draft PR

Sergey Borisov 2024-07-12 20:31:26 +03:00
parent 712cf00a82
commit 9cc852cf7f
8 changed files with 781 additions and 11 deletions

invokeai/app/invocations/denoise_latents.py

@@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import os
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -39,6 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
    ControlNetData,
    StableDiffusionGeneratorPipeline,
@@ -53,6 +55,9 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    TextConditioningData,
    TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
@@ -314,9 +319,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
        context: InvocationContext,
        positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
        negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
-       unet: UNet2DConditionModel,
        latent_height: int,
        latent_width: int,
+       device: torch.device,
+       dtype: torch.dtype,
        cfg_scale: float | list[float],
        steps: int,
        cfg_rescale_multiplier: float,
@@ -330,10 +336,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
            uncond_list = [uncond_list]

        cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-           cond_list, context, unet.device, unet.dtype
+           cond_list, context, device, dtype
        )
        uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-           uncond_list, context, unet.device, unet.dtype
+           uncond_list, context, device, dtype
        )

        cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
@@ -341,14 +347,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
            masks=cond_text_embedding_masks,
            latent_height=latent_height,
            latent_width=latent_width,
-           dtype=unet.dtype,
+           dtype=dtype,
        )
        uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
            text_conditionings=uncond_text_embeddings,
            masks=uncond_text_embedding_masks,
            latent_height=latent_height,
            latent_width=latent_width,
-           dtype=unet.dtype,
+           dtype=dtype,
        )

        if isinstance(cfg_scale, list):
@@ -707,9 +713,99 @@ class DenoiseLatentsInvocation(BaseInvocation):
        return seed, noise, latents

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        if os.environ.get("USE_MODULAR_DENOISE", False):
            return self._new_invoke(context)
        else:
            return self._old_invoke(context)

    @torch.no_grad()
    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
-   def invoke(self, context: InvocationContext) -> LatentsOutput:
+   def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
        with ExitStack() as exit_stack:
            ext_manager = ExtensionsManager()

            device = TorchDevice.choose_torch_device()
            dtype = TorchDevice.choose_torch_dtype()

            seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
            latents = latents.to(device=device, dtype=dtype)
            if noise is not None:
                noise = noise.to(device=device, dtype=dtype)

            _, _, latent_height, latent_width = latents.shape

            conditioning_data = self.get_conditioning_data(
                context=context,
                positive_conditioning_field=self.positive_conditioning,
                negative_conditioning_field=self.negative_conditioning,
                cfg_scale=self.cfg_scale,
                steps=self.steps,
                latent_height=latent_height,
                latent_width=latent_width,
                device=device,
                dtype=dtype,
                # TODO: old backend, remove
                cfg_rescale_multiplier=self.cfg_rescale_multiplier,
            )

            scheduler = get_scheduler(
                context=context,
                scheduler_info=self.unet.scheduler,
                scheduler_name=self.scheduler,
                seed=seed,
            )

            timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
                scheduler,
                seed=seed,
                device=device,
                steps=self.steps,
                denoising_start=self.denoising_start,
                denoising_end=self.denoising_end,
            )

            denoise_ctx = DenoiseContext(
                latents=latents,
                timesteps=timesteps,
                init_timestep=init_timestep,
                noise=noise,
                seed=seed,
                scheduler_step_kwargs=scheduler_step_kwargs,
                conditioning_data=conditioning_data,
                unet=None,
                scheduler=scheduler,
            )

            # get the unet's config so that we can pass the base to sd_step_callback()
            unet_config = context.models.get_config(self.unet.unet.key)

            # ext: t2i/ip adapter
            ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager)

            unet_info = context.models.load(self.unet.unet)
            assert isinstance(unet_info.model, UNet2DConditionModel)
            with (
                unet_info.model_on_device() as (model_state_dict, unet),
                # ext: controlnet
                ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
                # ext: freeu, seamless, ip adapter, lora
                ext_manager.patch_unet(model_state_dict, unet),
            ):
                sd_backend = StableDiffusionBackend(unet, scheduler)
                denoise_ctx.unet = unet
                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        result_latents = result_latents.to("cpu")  # TODO: detach?
        TorchDevice.empty_cache()

        name = context.tensors.save(tensor=result_latents)
        return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

    @torch.no_grad()
    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
    def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -788,7 +884,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
                context=context,
                positive_conditioning_field=self.positive_conditioning,
                negative_conditioning_field=self.negative_conditioning,
-               unet=unet,
+               device=unet.device,
+               dtype=unet.dtype,
                latent_height=latent_height,
                latent_width=latent_width,
                cfg_scale=self.cfg_scale,
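
The new invoke() wrapper above only takes the modular path when the USE_MODULAR_DENOISE environment variable is set. A minimal sketch of opting in; note that os.environ.get returns a string here, so any non-empty value (even "0") is truthy and enables the new backend:

import os

# Enable the modular denoise path for this process.
os.environ["USE_MODULAR_DENOISE"] = "1"

# Remove the variable entirely to fall back to _old_invoke:
# os.environ.pop("USE_MODULAR_DENOISE", None)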

invokeai/backend/stable_diffusion/denoise_context.py

@@ -0,0 +1,60 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData


@dataclass
class UNetKwargs:
    sample: torch.Tensor
    timestep: Union[torch.Tensor, float, int]
    encoder_hidden_states: torch.Tensor

    class_labels: Optional[torch.Tensor] = None
    timestep_cond: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None
    cross_attention_kwargs: Optional[Dict[str, Any]] = None
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    mid_block_additional_residual: Optional[torch.Tensor] = None
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    encoder_attention_mask: Optional[torch.Tensor] = None
    # return_dict: bool = True


@dataclass
class DenoiseContext:
    latents: torch.Tensor
    scheduler_step_kwargs: dict[str, Any]
    conditioning_data: TextConditioningData
    noise: Optional[torch.Tensor]
    seed: int
    timesteps: torch.Tensor
    init_timestep: torch.Tensor
    scheduler: SchedulerMixin

    unet: Optional[UNet2DConditionModel] = None
    orig_latents: Optional[torch.Tensor] = None
    step_index: Optional[int] = None
    timestep: Optional[torch.Tensor] = None
    unet_kwargs: Optional[UNetKwargs] = None
    step_output: Optional[SchedulerOutput] = None

    latent_model_input: Optional[torch.Tensor] = None
    conditioning_mode: Optional[str] = None
    negative_noise_pred: Optional[torch.Tensor] = None
    positive_noise_pred: Optional[torch.Tensor] = None
    noise_pred: Optional[torch.Tensor] = None

    extra: dict = field(default_factory=dict)

    def __delattr__(self, name: str):
        setattr(self, name, None)
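
Because of the __delattr__ override above, `del` on a DenoiseContext field resets it to None instead of removing the attribute; this is what StableDiffusionBackend.step() relies on when it clears its per-step fields. A toy sketch, where the None placeholders stand in for a real conditioning object and scheduler and are for illustration only:

import torch

ctx = DenoiseContext(
    latents=torch.zeros(1, 4, 8, 8),
    scheduler_step_kwargs={},
    conditioning_data=None,  # placeholder, illustration only
    noise=None,
    seed=0,
    timesteps=torch.tensor([999]),
    init_timestep=torch.tensor([999]),
    scheduler=None,  # placeholder, illustration only
)

ctx.noise_pred = torch.zeros(1, 4, 8, 8)
del ctx.noise_pred             # does not remove the field...
assert ctx.noise_pred is None  # ...it only resets it to None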

invokeai/backend/stable_diffusion/diffusion/conditioning_data.py

@@ -5,6 +5,7 @@ from typing import List, Optional, Union
import torch

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData


@dataclass
@@ -103,7 +104,7 @@
        uncond_regions: Optional[TextConditioningRegions],
        cond_regions: Optional[TextConditioningRegions],
        guidance_scale: Union[float, List[float]],
-       guidance_rescale_multiplier: float = 0,
+       guidance_rescale_multiplier: float = 0,  # TODO: old backend, remove
    ):
        self.uncond_text = uncond_text
        self.cond_text = cond_text
@@ -114,6 +115,7 @@
        # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
        # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
        self.guidance_scale = guidance_scale
        # TODO: old backend, remove
        # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
        self.guidance_rescale_multiplier = guidance_rescale_multiplier
@@ -121,3 +123,127 @@
    def is_sdxl(self):
        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
        return isinstance(self.cond_text, SDXLConditioningInfo)

    def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
        if conditioning_mode == "both":
            encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
                self.uncond_text.embeds, self.cond_text.embeds
            )
        elif conditioning_mode == "positive":
            encoder_hidden_states = self.cond_text.embeds
            encoder_attention_mask = None
        else:  # elif conditioning_mode == "negative":
            encoder_hidden_states = self.uncond_text.embeds
            encoder_attention_mask = None

        unet_kwargs.encoder_hidden_states = encoder_hidden_states
        unet_kwargs.encoder_attention_mask = encoder_attention_mask

        if self.is_sdxl():
            if conditioning_mode == "negative":
                added_cond_kwargs = dict(  # noqa: C408
                    text_embeds=self.cond_text.pooled_embeds,
                    time_ids=self.cond_text.add_time_ids,
                )
            elif conditioning_mode == "positive":
                added_cond_kwargs = dict(  # noqa: C408
                    text_embeds=self.uncond_text.pooled_embeds,
                    time_ids=self.uncond_text.add_time_ids,
                )
            else:  # elif conditioning_mode == "both":
                added_cond_kwargs = dict(  # noqa: C408
                    text_embeds=torch.cat(
                        [
                            # TODO: how to pad? just by zeros? or even truncate?
                            self.uncond_text.pooled_embeds,
                            self.cond_text.pooled_embeds,
                        ],
                    ),
                    time_ids=torch.cat(
                        [
                            self.uncond_text.add_time_ids,
                            self.cond_text.add_time_ids,
                        ],
                    ),
                )

            unet_kwargs.added_cond_kwargs = added_cond_kwargs

        if self.cond_regions is not None or self.uncond_regions is not None:
            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
            _tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions
            _, _, h, w = _tmp_regions.masks.shape
            dtype = self.cond_text.embeds.dtype
            device = self.cond_text.embeds.device

            regions = []
            for c, r in [
                (self.uncond_text, self.uncond_regions),
                (self.cond_text, self.cond_regions),
            ]:
                if r is None:
                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
                    r = TextConditioningRegions(
                        masks=torch.ones((1, 1, h, w), dtype=dtype),
                        ranges=[Range(start=0, end=c.embeds.shape[1])],
                    )
                regions.append(r)

            if unet_kwargs.cross_attention_kwargs is None:
                unet_kwargs.cross_attention_kwargs = {}

            unet_kwargs.cross_attention_kwargs.update(
                regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype),
            )

    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
        def _pad_conditioning(cond, target_len, encoder_attention_mask):
            conditioning_attention_mask = torch.ones(
                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
            )

            if cond.shape[1] < max_len:
                conditioning_attention_mask = torch.cat(
                    [
                        conditioning_attention_mask,
                        torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
                    ],
                    dim=1,
                )

                cond = torch.cat(
                    [
                        cond,
                        torch.zeros(
                            (cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
                            device=cond.device,
                            dtype=cond.dtype,
                        ),
                    ],
                    dim=1,
                )

            if encoder_attention_mask is None:
                encoder_attention_mask = conditioning_attention_mask
            else:
                encoder_attention_mask = torch.cat(
                    [
                        encoder_attention_mask,
                        conditioning_attention_mask,
                    ]
                )

            return cond, encoder_attention_mask

        encoder_attention_mask = None
        if unconditioning.shape[1] != conditioning.shape[1]:
            max_len = max(unconditioning.shape[1], conditioning.shape[1])
            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)

        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
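
For reference, _concat_conditionings_for_batch only pads when the negative and positive embeddings have different sequence lengths (for example one vs. two 77-token CLIP chunks). A rough shape sketch that mirrors the padding logic with toy tensors rather than calling the method itself:

import torch

uncond = torch.randn(1, 77, 768)   # negative prompt, one CLIP chunk
cond = torch.randn(1, 154, 768)    # positive prompt, two CLIP chunks

max_len = max(uncond.shape[1], cond.shape[1])
pad = max_len - uncond.shape[1]

# Zero-pad the shorter embedding and mark the padded positions with 0 in its
# attention mask, as _pad_conditioning does above.
uncond_padded = torch.cat([uncond, torch.zeros(1, pad, 768)], dim=1)
uncond_mask = torch.cat([torch.ones(1, 77), torch.zeros(1, pad)], dim=1)

batch = torch.cat([uncond_padded, cond])  # fed to the UNet as a single batch
assert batch.shape == (2, 154, 768)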

invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py

@@ -1,9 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    TextConditioningRegions,
-)
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+        TextConditioningRegions,
+    )


class RegionalPromptData:

invokeai/backend/stable_diffusion/diffusion_backend.py

@@ -0,0 +1,220 @@
from __future__ import annotations

import PIL.Image
import torch
import torchvision.transforms as T
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from tqdm.auto import tqdm

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


def trim_to_multiple_of(*args, multiple_of=8):
    return tuple((x - x % multiple_of) for x in args)


def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor:
    """
    :param image: input image
    :param normalize: scale the range to [-1, 1] instead of [0, 1]
    :param multiple_of: resize the input so both dimensions are a multiple of this
    """
    w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
    transformation = T.Compose(
        [
            T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True),
            T.ToTensor(),
        ]
    )
    tensor = transformation(image)
    if normalize:
        tensor = tensor * 2.0 - 1.0
    return tensor


class StableDiffusionBackend:
    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
    ):
        self.unet = unet
        self.scheduler = scheduler
        config = get_config()
        self.sequential_guidance = config.sequential_guidance

    def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        if ctx.init_timestep.shape[0] == 0:
            return ctx.latents

        ctx.orig_latents = ctx.latents.clone()

        if ctx.noise is not None:
            batch_size = ctx.latents.shape[0]
            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
            ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size))

        # if no work to do, return latents
        if ctx.timesteps.shape[0] == 0:
            return ctx.latents

        # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it's needed)
        # ext: preview[pre_denoise_loop, priority=low]
        ext_manager.modifiers.pre_denoise_loop(ctx)

        for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)):  # noqa: B020
            # ext: inpaint (apply mask to latents on non-inpaint models)
            ext_manager.modifiers.pre_step(ctx)

            # ext: tiles? [override: step]
            ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager)

            # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
            # ext: preview[post_step, priority=low]
            ext_manager.modifiers.post_step(ctx)

            ctx.latents = ctx.step_output.prev_sample

        # ext: inpaint[post_denoise_loop] (restore unmasked part)
        ext_manager.modifiers.post_denoise_loop(ctx)
        return ctx.latents

    @torch.inference_mode()
    def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
        ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)

        if self.sequential_guidance:
            conditioning_call = self._apply_standard_conditioning_sequentially
        else:
            conditioning_call = self._apply_standard_conditioning

        # not sure if an override is needed here
        ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager)

        # ext: override combine_noise
        ctx.noise_pred = ext_manager.overrides.combine_noise(self.combine_noise, ctx)

        # ext: cfg_rescale [modify_noise_prediction]
        ext_manager.modifiers.modify_noise_prediction(ctx)

        # compute the previous noisy sample x_t -> x_t-1
        step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs)

        # del locals
        del ctx.latent_model_input
        del ctx.negative_noise_pred
        del ctx.positive_noise_pred
        del ctx.noise_pred

        return step_output

    @staticmethod
    def combine_noise(ctx: DenoiseContext) -> torch.Tensor:
        guidance_scale = ctx.conditioning_data.guidance_scale
        if isinstance(guidance_scale, list):
            guidance_scale = guidance_scale[ctx.step_index]

        return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
        # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

    def _apply_standard_conditioning(
        self, ctx: DenoiseContext, ext_manager: ExtensionsManager
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
        the cost of higher memory usage.
        """

        ctx.unet_kwargs = UNetKwargs(
            sample=torch.cat([ctx.latent_model_input] * 2),
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / len(ctx.timesteps),  # ctx.total_steps,
            ),
        )

        ctx.conditioning_mode = "both"
        ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)

        # ext: controlnet/ip/t2i [pre_unet_forward]
        ext_manager.modifiers.pre_unet_forward(ctx)

        # ext: inpaint [pre_unet_forward, priority=low]
        # or
        # ext: inpaint [override: unet_forward]
        both_results = self._unet_forward(**vars(ctx.unet_kwargs))
        negative_next_x, positive_next_x = both_results.chunk(2)

        # del locals
        del ctx.unet_kwargs
        del ctx.conditioning_mode

        return negative_next_x, positive_next_x

    def _apply_standard_conditioning_sequentially(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
        slower execution speed.
        """

        ###################
        # Negative pass
        ###################

        ctx.unet_kwargs = UNetKwargs(
            sample=ctx.latent_model_input,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / len(ctx.timesteps),  # ctx.total_steps,
            ),
        )

        ctx.conditioning_mode = "negative"
        ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative")

        # ext: controlnet/ip/t2i [pre_unet_forward]
        ext_manager.modifiers.pre_unet_forward(ctx)

        # ext: inpaint [pre_unet_forward, priority=low]
        # or
        # ext: inpaint [override: unet_forward]
        negative_next_x = self._unet_forward(**vars(ctx.unet_kwargs))

        del ctx.unet_kwargs
        del ctx.conditioning_mode
        # TODO: gc.collect() ?

        ###################
        # Positive pass
        ###################

        ctx.unet_kwargs = UNetKwargs(
            sample=ctx.latent_model_input,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / len(ctx.timesteps),  # ctx.total_steps,
            ),
        )

        ctx.conditioning_mode = "positive"
        ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive")

        # ext: controlnet/ip/t2i [pre_unet_forward]
        ext_manager.modifiers.pre_unet_forward(ctx)

        # ext: inpaint [pre_unet_forward, priority=low]
        # or
        # ext: inpaint [override: unet_forward]
        positive_next_x = self._unet_forward(**vars(ctx.unet_kwargs))

        del ctx.unet_kwargs
        del ctx.conditioning_mode
        # TODO: gc.collect() ?

        return negative_next_x, positive_next_x

    def _unet_forward(self, **kwargs) -> torch.Tensor:
        return self.unet(**kwargs).sample
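
The two image helpers at the top of this file snap dimensions down to the latent grid; a quick worked example of trim_to_multiple_of, with values chosen arbitrarily:

assert trim_to_multiple_of(513, 769) == (512, 768)                 # rounded down to the 8-pixel grid
assert trim_to_multiple_of(512, 768) == (512, 768)                 # already aligned, unchanged
assert trim_to_multiple_of(100, 200, multiple_of=64) == (64, 192)  # custom grid size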

invokeai/backend/stable_diffusion/extensions/__init__.py

@@ -0,0 +1,9 @@
"""
Initialization file for the invokeai.backend.stable_diffusion.extensions package
"""

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

__all__ = [
    "ExtensionBase",
]

invokeai/backend/stable_diffusion/extensions/base.py

@@ -0,0 +1,58 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

import torch
from diffusers import UNet2DConditionModel


@dataclass
class InjectionInfo:
    type: str
    name: str
    order: Optional[str]
    function: Callable


def modifier(name: str, order: str = "any"):
    def _decorator(func):
        func.__inj_info__ = {
            "type": "modifier",
            "name": name,
            "order": order,
        }
        return func

    return _decorator


def override(name: str):
    def _decorator(func):
        func.__inj_info__ = {
            "type": "override",
            "name": name,
            "order": None,
        }
        return func

    return _decorator


class ExtensionBase:
    def __init__(self, priority: int):
        self.priority = priority
        self.injections: List[InjectionInfo] = []
        for func_name in dir(self):
            func = getattr(self, func_name)
            if not callable(func) or not hasattr(func, "__inj_info__"):
                continue

            self.injections.append(InjectionInfo(**func.__inj_info__, function=func))

    @contextmanager
    def patch_attention_processor(self, attention_processor_cls: object):
        yield None

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        yield None
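
To show how @modifier, @override and ExtensionBase are meant to fit together, here is a hypothetical extension; the class name and its behaviour are invented for illustration, only the decorators and hook names come from this commit:

class PreviewExtension(ExtensionBase):
    """Hypothetical extension that inspects latents after every scheduler step."""

    def __init__(self, priority: int = 100):
        super().__init__(priority=priority)

    @modifier("post_step", order="last")
    def _on_post_step(self, ctx):
        # Called after each scheduler step; ctx is the shared DenoiseContext.
        print(f"step {ctx.step_index}: latents norm {ctx.step_output.prev_sample.norm():.3f}")

    @override("combine_noise")
    def _combine_noise(self, orig_func, ctx):
        # Overrides receive the original callable and may delegate to it or replace it.
        return orig_func(ctx)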

invokeai/backend/stable_diffusion/extensions_manager.py

@@ -0,0 +1,195 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from contextlib import ExitStack, contextmanager
from functools import partial
from typing import TYPE_CHECKING, Callable, Dict

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.util.devices import TorchDevice

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extensions import ExtensionBase


class ExtModifiersApi(ABC):
    @abstractmethod
    def pre_denoise_loop(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def post_denoise_loop(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def pre_step(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def post_step(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def modify_noise_prediction(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def pre_unet_forward(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def pre_unet_load(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        pass


class ExtOverridesApi(ABC):
    @abstractmethod
    def step(self, orig_func: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        pass

    @abstractmethod
    def combine_noise(self, orig_func: Callable, ctx: DenoiseContext):
        pass


class ProxyCallsClass:
    def __init__(self, handler):
        self._handler = handler

    def __getattr__(self, item):
        return partial(self._handler, item)


class ModifierInjectionPoint:
    def __init__(self):
        self.first = []
        self.any = []
        self.last = []

    def add(self, func: Callable, order: str):
        if order == "first":
            self.first.append(func)
        elif order == "last":
            self.last.append(func)
        else:  # elif order == "any":
            self.any.append(func)

    def __call__(self, *args, **kwargs):
        for func in self.first:
            func(*args, **kwargs)
        for func in self.any:
            func(*args, **kwargs)
        for func in reversed(self.last):
            func(*args, **kwargs)


class ExtensionsManager:
    def __init__(self):
        self.extensions = []

        self._overrides = {}
        self._modifiers = {}

        self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier)
        self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override)

    def add_extension(self, ext: ExtensionBase):
        self.extensions.append(ext)
        ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)

        self._overrides.clear()
        self._modifiers.clear()

        for ext in ordered_extensions:
            for inj_info in ext.injections:
                if inj_info.type == "modifier":
                    if inj_info.name not in self._modifiers:
                        self._modifiers[inj_info.name] = ModifierInjectionPoint()
                    self._modifiers[inj_info.name].add(inj_info.function, inj_info.order)

                else:
                    if inj_info.name in self._overrides:
                        raise Exception(f"Already overloaded - {inj_info.name}")
                    self._overrides[inj_info.name] = inj_info.function

    def call_modifier(self, name: str, *args, **kwargs):
        if name in self._modifiers:
            self._modifiers[name](*args, **kwargs)

    def call_override(self, name: str, orig_func: Callable, *args, **kwargs):
        if name in self._overrides:
            return self._overrides[name](orig_func, *args, **kwargs)
        else:
            return orig_func(*args, **kwargs)

    # TODO: is there any need for such a high level of abstraction?
    # @contextmanager
    # def patch_extensions(self):
    #     exit_stack = ExitStack()
    #     try:
    #         for ext in self.extensions:
    #             exit_stack.enter_context(ext.patch_extension(self))
    #
    #         yield None
    #
    #     finally:
    #         exit_stack.close()

    @contextmanager
    def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object):
        unet_orig_processors = unet.attn_processors
        exit_stack = ExitStack()
        try:
            # make sure each attention layer gets its own processor instance
            attn_procs = {}
            for name in unet.attn_processors.keys():
                attn_procs[name] = attn_processor_cls()
            unet.set_attn_processor(attn_procs)

            for ext in self.extensions:
                exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls))

            yield None

        finally:
            unet.set_attn_processor(unet_orig_processors)
            exit_stack.close()

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        exit_stack = ExitStack()
        try:
            changed_keys = set()
            changed_unknown_keys = {}

            ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
            for ext in ordered_extensions:
                patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet))
                if patch_result is None:
                    continue
                new_keys, new_unk_keys = patch_result
                changed_keys.update(new_keys)
                # skip already-seen keys, as their recorded weights may already have been changed by an earlier patch
                for k, v in new_unk_keys.items():
                    if k in changed_unknown_keys:
                        continue
                    changed_unknown_keys[k] = v

            yield None

        finally:
            exit_stack.close()
            assert hasattr(unet, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
            with torch.no_grad():
                for module_key in changed_keys:
                    weight = state_dict[module_key]
                    unet.get_submodule(module_key).weight.copy_(
                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
                    )
                for module_key, weight in changed_unknown_keys.items():
                    unet.get_submodule(module_key).weight.copy_(
                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
                    )
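
Finally, a sketch of how the pieces connect at runtime, assuming `ctx` is a prepared DenoiseContext and `backend` a StableDiffusionBackend as built in _new_invoke above, and reusing the hypothetical PreviewExtension shown earlier:

ext_manager = ExtensionsManager()
ext_manager.add_extension(PreviewExtension(priority=100))

# Attribute access on .modifiers goes through ProxyCallsClass, so this resolves to
# call_modifier("pre_denoise_loop", ctx) and runs every registered hook in first/any/last order.
ext_manager.modifiers.pre_denoise_loop(ctx)

# Overrides wrap an original callable; if nothing registered an override for "step",
# the original backend.step is called unchanged.
step_output = ext_manager.overrides.step(backend.step, ctx, ext_manager)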