From 87e96e1be2d3cb9dee1f08c5b254b3089637b555 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 12 Jul 2024 22:01:05 +0300 Subject: [PATCH] Rename modifiers to callbacks, convert order to int, slightly unify injection points --- invokeai/app/invocations/denoise_latents.py | 2 +- .../stable_diffusion/diffusion_backend.py | 48 ++++---------- .../stable_diffusion/extensions/base.py | 6 +- .../stable_diffusion/extensions_manager.py | 62 ++++++++----------- 4 files changed, 42 insertions(+), 76 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index bec8741936..1bc66e423f 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -781,7 +781,7 @@ class DenoiseLatentsInvocation(BaseInvocation): unet_config = context.models.get_config(self.unet.unet.key) # ext: t2i/ip adapter - ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager) + ext_manager.callbacks.pre_unet_load(denoise_ctx, ext_manager) unet_info = context.models.load(self.unet.unet) assert isinstance(unet_info.model, UNet2DConditionModel) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 264fed2fe6..4630d4740d 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -1,8 +1,6 @@ from __future__ import annotations -import PIL.Image import torch -import torchvision.transforms as T from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput from tqdm.auto import tqdm @@ -12,30 +10,6 @@ from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UN from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager -def trim_to_multiple_of(*args, multiple_of=8): - return tuple((x - x % multiple_of) for x in 
args) - - -def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor: - """ - - :param image: input image - :param normalize: scale the range to [-1, 1] instead of [0, 1] - :param multiple_of: resize the input so both dimensions are a multiple of this - """ - w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of) - transformation = T.Compose( - [ - T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True), - T.ToTensor(), - ] - ) - tensor = transformation(image) - if normalize: - tensor = tensor * 2.0 - 1.0 - return tensor - - class StableDiffusionBackend: def __init__( self, @@ -64,23 +38,23 @@ class StableDiffusionBackend: # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed) # ext: preview[pre_denoise_loop, priority=low] - ext_manager.modifiers.pre_denoise_loop(ctx) + ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager) for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020 # ext: inpaint (apply mask to latents on non-inpaint models) - ext_manager.modifiers.pre_step(ctx) + ext_manager.callbacks.pre_step(ctx, ext_manager) # ext: tiles? 
[override: step] ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager) # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models) # ext: preview[post_step, priority=low] - ext_manager.modifiers.post_step(ctx) + ext_manager.callbacks.post_step(ctx, ext_manager) ctx.latents = ctx.step_output.prev_sample # ext: inpaint[post_denoise_loop] (restore unmasked part) - ext_manager.modifiers.post_denoise_loop(ctx) + ext_manager.callbacks.post_denoise_loop(ctx, ext_manager) return ctx.latents @torch.inference_mode() @@ -95,11 +69,11 @@ class StableDiffusionBackend: # not sure if here needed override ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager) - # ext: override combine_noise - ctx.noise_pred = ext_manager.overrides.combine_noise(self.combine_noise, ctx) + # ext: override apply_cfg + ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx) # ext: cfg_rescale [modify_noise_prediction] - ext_manager.modifiers.modify_noise_prediction(ctx) + ext_manager.callbacks.modify_noise_prediction(ctx, ext_manager) # compute the previous noisy sample x_t -> x_t-1 step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs) @@ -113,7 +87,7 @@ class StableDiffusionBackend: return step_output @staticmethod - def combine_noise(ctx: DenoiseContext) -> torch.Tensor: + def apply_cfg(ctx: DenoiseContext) -> torch.Tensor: guidance_scale = ctx.conditioning_data.guidance_scale if isinstance(guidance_scale, list): guidance_scale = guidance_scale[ctx.step_index] @@ -141,7 +115,7 @@ class StableDiffusionBackend: ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.modifiers.pre_unet_forward(ctx) + ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) # ext: inpaint [pre_unet_forward, priority=low] # or @@ -175,7 +149,7 @@ class StableDiffusionBackend: 
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative") # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.modifiers.pre_unet_forward(ctx) + ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) # ext: inpaint [pre_unet_forward, priority=low] # or @@ -203,7 +177,7 @@ class StableDiffusionBackend: ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive") # ext: controlnet/ip/t2i [pre_unet_forward] - ext_manager.modifiers.pre_unet_forward(ctx) + ext_manager.callbacks.pre_unet_forward(ctx, ext_manager) # ext: inpaint [pre_unet_forward, priority=low] # or diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index d3414eea6f..79227921c3 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -10,14 +10,14 @@ from diffusers import UNet2DConditionModel class InjectionInfo: type: str name: str - order: Optional[str] + order: Optional[int] function: Callable -def modifier(name: str, order: str = "any"): +def callback(name: str, order: int = 0): def _decorator(func): func.__inj_info__ = { - "type": "modifier", + "type": "callback", "name": name, "order": order, } diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 2e6882e0ca..1d4892a982 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -15,29 +15,29 @@ if TYPE_CHECKING: from invokeai.backend.stable_diffusion.extensions import ExtensionBase -class ExtModifiersApi(ABC): +class ExtCallbacksApi(ABC): @abstractmethod - def pre_denoise_loop(self, ctx: DenoiseContext): + def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def post_denoise_loop(self, ctx: DenoiseContext): + def post_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): 
pass @abstractmethod - def pre_step(self, ctx: DenoiseContext): + def pre_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def post_step(self, ctx: DenoiseContext): + def post_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def modify_noise_prediction(self, ctx: DenoiseContext): + def modify_noise_prediction(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod - def pre_unet_forward(self, ctx: DenoiseContext): + def pre_unet_forward(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): pass @abstractmethod @@ -51,7 +51,7 @@ class ExtOverridesApi(ABC): pass @abstractmethod - def combine_noise(self, orig_func: Callable, ctx: DenoiseContext): + def apply_cfg(self, orig_func: Callable, ctx: DenoiseContext): pass @@ -63,27 +63,19 @@ class ProxyCallsClass: return partial(self._handler, item) -class ModifierInjectionPoint: +class CallbackInjectionPoint: def __init__(self): - self.first = [] - self.any = [] - self.last = [] + self.handlers = {} - def add(self, func: Callable, order: str): - if order == "first": - self.first.append(func) - elif order == "last": - self.last.append(func) - else: # elif order == "any": - self.any.append(func) + def add(self, func: Callable, order: int): + if order not in self.handlers: + self.handlers[order] = [] + self.handlers[order].append(func) def __call__(self, *args, **kwargs): - for func in self.first: - func(*args, **kwargs) - for func in self.any: - func(*args, **kwargs) - for func in reversed(self.last): - func(*args, **kwargs) + for order in sorted(self.handlers.keys(), reverse=True): + for handler in self.handlers[order]: + handler(*args, **kwargs) class ExtensionsManager: @@ -91,9 +83,9 @@ class ExtensionsManager: self.extensions = [] self._overrides = {} - self._modifiers = {} + self._callbacks = {} - self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier) + self.callbacks: ExtCallbacksApi = 
ProxyCallsClass(self.call_callback) self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override) def add_extension(self, ext: ExtensionBase): @@ -101,23 +93,23 @@ class ExtensionsManager: ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority) self._overrides.clear() - self._modifiers.clear() + self._callbacks.clear() for ext in ordered_extensions: for inj_info in ext.injections: - if inj_info.type == "modifier": - if inj_info.name not in self._modifiers: - self._modifiers[inj_info.name] = ModifierInjectionPoint() - self._modifiers[inj_info.name].add(inj_info.function, inj_info.order) + if inj_info.type == "callback": + if inj_info.name not in self._callbacks: + self._callbacks[inj_info.name] = CallbackInjectionPoint() + self._callbacks[inj_info.name].add(inj_info.function, inj_info.order) else: if inj_info.name in self._overrides: raise Exception(f"Already overloaded - {inj_info.name}") self._overrides[inj_info.name] = inj_info.function - def call_modifier(self, name: str, *args, **kwargs): - if name in self._modifiers: - self._modifiers[name](*args, **kwargs) + def call_callback(self, name: str, *args, **kwargs): + if name in self._callbacks: + self._callbacks[name](*args, **kwargs) def call_override(self, name: str, orig_func: Callable, *args, **kwargs): if name in self._overrides: