Base code from draft PR

This commit is contained in:
Sergey Borisov 2024-07-12 20:31:26 +03:00
parent 712cf00a82
commit 9cc852cf7f
8 changed files with 781 additions and 11 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import os
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@ -39,6 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
@ -53,6 +55,9 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
@ -314,9 +319,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
device: torch.device,
dtype: torch.dtype,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
@ -330,10 +336,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
cond_list, context, device, dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
uncond_list, context, device, dtype
)
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
@ -341,14 +347,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)
if isinstance(cfg_scale, list):
@ -707,9 +713,99 @@ class DenoiseLatentsInvocation(BaseInvocation):
return seed, noise, latents
def invoke(self, context: InvocationContext) -> LatentsOutput:
if os.environ.get("USE_MODULAR_DENOISE", False):
return self._new_invoke(context)
else:
return self._old_invoke(context)
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
with ExitStack() as exit_stack:
ext_manager = ExtensionsManager()
device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
cfg_scale=self.cfg_scale,
steps=self.steps,
latent_height=latent_height,
latent_width=latent_width,
device=device,
dtype=dtype,
# TODO: old backend, remove
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
seed=seed,
device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
)
denoise_ctx = DenoiseContext(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
unet=None,
scheduler=scheduler,
)
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
# ext: t2i/ip adapter
ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager)
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
# ext: controlnet
ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu") # TODO: detach?
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@ -788,7 +884,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,

View File

@ -0,0 +1,60 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
@dataclass
class UNetKwargs:
sample: torch.Tensor
timestep: Union[torch.Tensor, float, int]
encoder_hidden_states: torch.Tensor
class_labels: Optional[torch.Tensor] = None
timestep_cond: Optional[torch.Tensor] = None
attention_mask: Optional[torch.Tensor] = None
cross_attention_kwargs: Optional[Dict[str, Any]] = None
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
mid_block_additional_residual: Optional[torch.Tensor] = None
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
encoder_attention_mask: Optional[torch.Tensor] = None
# return_dict: bool = True
@dataclass
class DenoiseContext:
latents: torch.Tensor
scheduler_step_kwargs: dict[str, Any]
conditioning_data: TextConditioningData
noise: Optional[torch.Tensor]
seed: int
timesteps: torch.Tensor
init_timestep: torch.Tensor
scheduler: SchedulerMixin
unet: Optional[UNet2DConditionModel] = None
orig_latents: Optional[torch.Tensor] = None
step_index: Optional[int] = None
timestep: Optional[torch.Tensor] = None
unet_kwargs: Optional[UNetKwargs] = None
step_output: Optional[SchedulerOutput] = None
latent_model_input: Optional[torch.Tensor] = None
conditioning_mode: Optional[str] = None
negative_noise_pred: Optional[torch.Tensor] = None
positive_noise_pred: Optional[torch.Tensor] = None
noise_pred: Optional[torch.Tensor] = None
extra: dict = field(default_factory=dict)
def __delattr__(self, name: str):
setattr(self, name, None)

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Union
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@dataclass
@ -103,7 +104,7 @@ class TextConditioningData:
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0,
guidance_rescale_multiplier: float = 0, # TODO: old backend, remove
):
self.uncond_text = uncond_text
self.cond_text = cond_text
@ -114,6 +115,7 @@ class TextConditioningData:
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
self.guidance_scale = guidance_scale
# TODO: old backend, remove
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
self.guidance_rescale_multiplier = guidance_rescale_multiplier
@ -121,3 +123,127 @@ class TextConditioningData:
def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo)
def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
if conditioning_mode == "both":
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
self.uncond_text.embeds, self.cond_text.embeds
)
elif conditioning_mode == "positive":
encoder_hidden_states = self.cond_text.embeds
encoder_attention_mask = None
else: # elif conditioning_mode == "negative":
encoder_hidden_states = self.uncond_text.embeds
encoder_attention_mask = None
unet_kwargs.encoder_hidden_states = encoder_hidden_states
unet_kwargs.encoder_attention_mask = encoder_attention_mask
if self.is_sdxl():
if conditioning_mode == "negative":
added_cond_kwargs = dict( # noqa: C408
text_embeds=self.cond_text.pooled_embeds,
time_ids=self.cond_text.add_time_ids,
)
elif conditioning_mode == "positive":
added_cond_kwargs = dict( # noqa: C408
text_embeds=self.uncond_text.pooled_embeds,
time_ids=self.uncond_text.add_time_ids,
)
else: # elif conditioning_mode == "both":
added_cond_kwargs = dict( # noqa: C408
text_embeds=torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
self.uncond_text.pooled_embeds,
self.cond_text.pooled_embeds,
],
),
time_ids=torch.cat(
[
self.uncond_text.add_time_ids,
self.cond_text.add_time_ids,
],
),
)
unet_kwargs.added_cond_kwargs = added_cond_kwargs
if self.cond_regions is not None or self.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
_tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions
_, _, h, w = _tmp_regions.masks.shape
dtype = self.cond_text.embeds.dtype
device = self.cond_text.embeds.device
regions = []
for c, r in [
(self.uncond_text, self.uncond_regions),
(self.cond_text, self.cond_regions),
]:
if r is None:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=dtype),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
regions.append(r)
if unet_kwargs.cross_attention_kwargs is None:
unet_kwargs.cross_attention_kwargs = {}
unet_kwargs.cross_attention_kwargs.update(
regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype),
)
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones(
(cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
)
if cond.shape[1] < max_len:
conditioning_attention_mask = torch.cat(
[
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
],
dim=1,
)
cond = torch.cat(
[
cond,
torch.zeros(
(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
device=cond.device,
dtype=cond.dtype,
),
],
dim=1,
)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat(
[
encoder_attention_mask,
conditioning_attention_mask,
]
)
return cond, encoder_attention_mask
encoder_attention_mask = None
if unconditioning.shape[1] != conditioning.shape[1]:
max_len = max(unconditioning.shape[1], conditioning.shape[1])
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
return torch.cat([unconditioning, conditioning]), encoder_attention_mask

View File

@ -1,9 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
class RegionalPromptData:

View File

@ -0,0 +1,220 @@
from __future__ import annotations
import PIL.Image
import torch
import torchvision.transforms as T
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from tqdm.auto import tqdm
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
def trim_to_multiple_of(*args, multiple_of=8):
return tuple((x - x % multiple_of) for x in args)
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor:
"""
:param image: input image
:param normalize: scale the range to [-1, 1] instead of [0, 1]
:param multiple_of: resize the input so both dimensions are a multiple of this
"""
w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
transformation = T.Compose(
[
T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True),
T.ToTensor(),
]
)
tensor = transformation(image)
if normalize:
tensor = tensor * 2.0 - 1.0
return tensor
class StableDiffusionBackend:
def __init__(
self,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
):
self.unet = unet
self.scheduler = scheduler
config = get_config()
self.sequential_guidance = config.sequential_guidance
def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
if ctx.init_timestep.shape[0] == 0:
return ctx.latents
ctx.orig_latents = ctx.latents.clone()
if ctx.noise is not None:
batch_size = ctx.latents.shape[0]
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size))
# if no work to do, return latents
if ctx.timesteps.shape[0] == 0:
return ctx.latents
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
# ext: preview[pre_denoise_loop, priority=low]
ext_manager.modifiers.pre_denoise_loop(ctx)
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020
# ext: inpaint (apply mask to latents on non-inpaint models)
ext_manager.modifiers.pre_step(ctx)
# ext: tiles? [override: step]
ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager)
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
# ext: preview[post_step, priority=low]
ext_manager.modifiers.post_step(ctx)
ctx.latents = ctx.step_output.prev_sample
# ext: inpaint[post_denoise_loop] (restore unmasked part)
ext_manager.modifiers.post_denoise_loop(ctx)
return ctx.latents
@torch.inference_mode()
def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)
if self.sequential_guidance:
conditioning_call = self._apply_standard_conditioning_sequentially
else:
conditioning_call = self._apply_standard_conditioning
# not sure if here needed override
ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager)
# ext: override combine_noise
ctx.noise_pred = ext_manager.overrides.combine_noise(self.combine_noise, ctx)
# ext: cfg_rescale [modify_noise_prediction]
ext_manager.modifiers.modify_noise_prediction(ctx)
# compute the previous noisy sample x_t -> x_t-1
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs)
# del locals
del ctx.latent_model_input
del ctx.negative_noise_pred
del ctx.positive_noise_pred
del ctx.noise_pred
return step_output
@staticmethod
def combine_noise(ctx: DenoiseContext) -> torch.Tensor:
guidance_scale = ctx.conditioning_data.guidance_scale
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[ctx.step_index]
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
def _apply_standard_conditioning(
self, ctx: DenoiseContext, ext_manager: ExtensionsManager
) -> tuple[torch.Tensor, torch.Tensor]:
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage.
"""
ctx.unet_kwargs = UNetKwargs(
sample=torch.cat([ctx.latent_model_input] * 2),
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning
cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps,
),
)
ctx.conditioning_mode = "both"
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
# ext: controlnet/ip/t2i [pre_unet_forward]
ext_manager.modifiers.pre_unet_forward(ctx)
# ext: inpaint [pre_unet_forward, priority=low]
# or
# ext: inpaint [override: unet_forward]
both_results = self._unet_forward(**vars(ctx.unet_kwargs))
negative_next_x, positive_next_x = both_results.chunk(2)
# del locals
del ctx.unet_kwargs
del ctx.conditioning_mode
return negative_next_x, positive_next_x
def _apply_standard_conditioning_sequentially(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed.
"""
###################
# Negative pass
###################
ctx.unet_kwargs = UNetKwargs(
sample=ctx.latent_model_input,
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning
cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps,
),
)
ctx.conditioning_mode = "negative"
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative")
# ext: controlnet/ip/t2i [pre_unet_forward]
ext_manager.modifiers.pre_unet_forward(ctx)
# ext: inpaint [pre_unet_forward, priority=low]
# or
# ext: inpaint [override: unet_forward]
negative_next_x = self._unet_forward(**vars(ctx.unet_kwargs))
del ctx.unet_kwargs
del ctx.conditioning_mode
# TODO: gc.collect() ?
###################
# Positive pass
###################
ctx.unet_kwargs = UNetKwargs(
sample=ctx.latent_model_input,
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning
cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps,
),
)
ctx.conditioning_mode = "positive"
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive")
# ext: controlnet/ip/t2i [pre_unet_forward]
ext_manager.modifiers.pre_unet_forward(ctx)
# ext: inpaint [pre_unet_forward, priority=low]
# or
# ext: inpaint [override: unet_forward]
positive_next_x = self._unet_forward(**vars(ctx.unet_kwargs))
del ctx.unet_kwargs
del ctx.conditioning_mode
# TODO: gc.collect() ?
return negative_next_x, positive_next_x
def _unet_forward(self, **kwargs) -> torch.Tensor:
return self.unet(**kwargs).sample

View File

@ -0,0 +1,9 @@
"""
Initialization file for the invokeai.backend.stable_diffusion.extensions package
"""
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
__all__ = [
"ExtensionBase",
]

View File

@ -0,0 +1,58 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
import torch
from diffusers import UNet2DConditionModel
@dataclass
class InjectionInfo:
type: str
name: str
order: Optional[str]
function: Callable
def modifier(name: str, order: str = "any"):
def _decorator(func):
func.__inj_info__ = {
"type": "modifier",
"name": name,
"order": order,
}
return func
return _decorator
def override(name: str):
def _decorator(func):
func.__inj_info__ = {
"type": "override",
"name": name,
"order": None,
}
return func
return _decorator
class ExtensionBase:
def __init__(self, priority: int):
self.priority = priority
self.injections: List[InjectionInfo] = []
for func_name in dir(self):
func = getattr(self, func_name)
if not callable(func) or not hasattr(func, "__inj_info__"):
continue
self.injections.append(InjectionInfo(**func.__inj_info__, function=func))
@contextmanager
def patch_attention_processor(self, attention_processor_cls: object):
yield None
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
yield None

View File

@ -0,0 +1,195 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import ExitStack, contextmanager
from functools import partial
from typing import TYPE_CHECKING, Callable, Dict
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
class ExtModifiersApi(ABC):
@abstractmethod
def pre_denoise_loop(self, ctx: DenoiseContext):
pass
@abstractmethod
def post_denoise_loop(self, ctx: DenoiseContext):
pass
@abstractmethod
def pre_step(self, ctx: DenoiseContext):
pass
@abstractmethod
def post_step(self, ctx: DenoiseContext):
pass
@abstractmethod
def modify_noise_prediction(self, ctx: DenoiseContext):
pass
@abstractmethod
def pre_unet_forward(self, ctx: DenoiseContext):
pass
@abstractmethod
def pre_unet_load(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
class ExtOverridesApi(ABC):
@abstractmethod
def step(self, orig_func: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
def combine_noise(self, orig_func: Callable, ctx: DenoiseContext):
pass
class ProxyCallsClass:
def __init__(self, handler):
self._handler = handler
def __getattr__(self, item):
return partial(self._handler, item)
class ModifierInjectionPoint:
def __init__(self):
self.first = []
self.any = []
self.last = []
def add(self, func: Callable, order: str):
if order == "first":
self.first.append(func)
elif order == "last":
self.last.append(func)
else: # elif order == "any":
self.any.append(func)
def __call__(self, *args, **kwargs):
for func in self.first:
func(*args, **kwargs)
for func in self.any:
func(*args, **kwargs)
for func in reversed(self.last):
func(*args, **kwargs)
class ExtensionsManager:
def __init__(self):
self.extensions = []
self._overrides = {}
self._modifiers = {}
self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier)
self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override)
def add_extension(self, ext: ExtensionBase):
self.extensions.append(ext)
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
self._overrides.clear()
self._modifiers.clear()
for ext in ordered_extensions:
for inj_info in ext.injections:
if inj_info.type == "modifier":
if inj_info.name not in self._modifiers:
self._modifiers[inj_info.name] = ModifierInjectionPoint()
self._modifiers[inj_info.name].add(inj_info.function, inj_info.order)
else:
if inj_info.name in self._overrides:
raise Exception(f"Already overloaded - {inj_info.name}")
self._overrides[inj_info.name] = inj_info.function
def call_modifier(self, name: str, *args, **kwargs):
if name in self._modifiers:
self._modifiers[name](*args, **kwargs)
def call_override(self, name: str, orig_func: Callable, *args, **kwargs):
if name in self._overrides:
return self._overrides[name](orig_func, *args, **kwargs)
else:
return orig_func(*args, **kwargs)
# TODO: is there any need in such high abstarction
# @contextmanager
# def patch_extensions(self):
# exit_stack = ExitStack()
# try:
# for ext in self.extensions:
# exit_stack.enter_context(ext.patch_extension(self))
#
# yield None
#
# finally:
# exit_stack.close()
@contextmanager
def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object):
unet_orig_processors = unet.attn_processors
exit_stack = ExitStack()
try:
# just to be sure that attentions have not same processor instance
attn_procs = {}
for name in unet.attn_processors.keys():
attn_procs[name] = attn_processor_cls()
unet.set_attn_processor(attn_procs)
for ext in self.extensions:
exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls))
yield None
finally:
unet.set_attn_processor(unet_orig_processors)
exit_stack.close()
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
exit_stack = ExitStack()
try:
changed_keys = set()
changed_unknown_keys = {}
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
for ext in ordered_extensions:
patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet))
if patch_result is None:
continue
new_keys, new_unk_keys = patch_result
changed_keys.update(new_keys)
# skip already seen keys, as new weight might be changed
for k, v in new_unk_keys.items():
if k in changed_unknown_keys:
continue
changed_unknown_keys[k] = v
yield None
finally:
exit_stack.close()
assert hasattr(unet, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key in changed_keys:
weight = state_dict[module_key]
unet.get_submodule(module_key).weight.copy_(
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
)
for module_key, weight in changed_unknown_keys.items():
unet.get_submodule(module_key).weight.copy_(
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
)