Rename modifiers to callbacks, convert order to int, and slightly unify injection points

This commit is contained in:
Sergey Borisov 2024-07-12 22:01:05 +03:00
parent 0bc60378d3
commit 87e96e1be2
4 changed files with 42 additions and 76 deletions

View File

@ -781,7 +781,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
unet_config = context.models.get_config(self.unet.unet.key)
# ext: t2i/ip adapter
ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager)
ext_manager.callbacks.pre_unet_load(denoise_ctx, ext_manager)
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)

View File

@ -1,8 +1,6 @@
from __future__ import annotations
import PIL.Image
import torch
import torchvision.transforms as T
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from tqdm.auto import tqdm
@ -12,30 +10,6 @@ from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UN
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
def trim_to_multiple_of(*args: int, multiple_of: int = 8) -> tuple[int, ...]:
    """Reduce each value by its remainder modulo ``multiple_of``.

    :param args: values to trim (e.g. an image's width and height)
    :param multiple_of: divisor that every returned value is a multiple of
    :return: a tuple with one trimmed value per positional argument, in the
        same order; an empty tuple when no values are given
    """
    # x - x % m strips the remainder, leaving an exact multiple of m.
    return tuple(value - value % multiple_of for value in args)
def image_resized_to_grid_as_tensor(
    image: PIL.Image.Image, normalize: bool = True, multiple_of: int = 8
) -> torch.FloatTensor:
    """Resize an image onto a ``multiple_of`` grid and convert it to a tensor.

    :param image: input image
    :param normalize: scale the range to [-1, 1] instead of [0, 1]
    :param multiple_of: resize the input so both dimensions are a multiple of this
    :return: the resized image as a tensor, rescaled to [-1, 1] when
        ``normalize`` is true
    """
    # PIL reports size as (width, height); Resize is given (height, width).
    w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
    transformation = T.Compose(
        [
            T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True),
            T.ToTensor(),
        ]
    )
    tensor = transformation(image)
    if normalize:
        # Map the [0, 1] output of ToTensor onto [-1, 1].
        tensor = tensor * 2.0 - 1.0
    return tensor
class StableDiffusionBackend:
def __init__(
self,
@ -64,23 +38,23 @@ class StableDiffusionBackend:
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it is needed)
# ext: preview[pre_denoise_loop, priority=low]
ext_manager.modifiers.pre_denoise_loop(ctx)
ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020
# ext: inpaint (apply mask to latents on non-inpaint models)
ext_manager.modifiers.pre_step(ctx)
ext_manager.callbacks.pre_step(ctx, ext_manager)
# ext: tiles? [override: step]
ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager)
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
# ext: preview[post_step, priority=low]
ext_manager.modifiers.post_step(ctx)
ext_manager.callbacks.post_step(ctx, ext_manager)
ctx.latents = ctx.step_output.prev_sample
# ext: inpaint[post_denoise_loop] (restore unmasked part)
ext_manager.modifiers.post_denoise_loop(ctx)
ext_manager.callbacks.post_denoise_loop(ctx, ext_manager)
return ctx.latents
@torch.inference_mode()
@ -95,11 +69,11 @@ class StableDiffusionBackend:
# not sure if an override is needed here
ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager)
# ext: override combine_noise
ctx.noise_pred = ext_manager.overrides.combine_noise(self.combine_noise, ctx)
# ext: override apply_cfg
ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx)
# ext: cfg_rescale [modify_noise_prediction]
ext_manager.modifiers.modify_noise_prediction(ctx)
ext_manager.callbacks.modify_noise_prediction(ctx, ext_manager)
# compute the previous noisy sample x_t -> x_t-1
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs)
@ -113,7 +87,7 @@ class StableDiffusionBackend:
return step_output
@staticmethod
def combine_noise(ctx: DenoiseContext) -> torch.Tensor:
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
guidance_scale = ctx.conditioning_data.guidance_scale
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[ctx.step_index]
@ -141,7 +115,7 @@ class StableDiffusionBackend:
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
# ext: controlnet/ip/t2i [pre_unet_forward]
ext_manager.modifiers.pre_unet_forward(ctx)
ext_manager.callbacks.pre_unet_forward(ctx, ext_manager)
# ext: inpaint [pre_unet_forward, priority=low]
# or
@ -175,7 +149,7 @@ class StableDiffusionBackend:
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative")
# ext: controlnet/ip/t2i [pre_unet_forward]
ext_manager.modifiers.pre_unet_forward(ctx)
ext_manager.callbacks.pre_unet_forward(ctx, ext_manager)
# ext: inpaint [pre_unet_forward, priority=low]
# or
@ -203,7 +177,7 @@ class StableDiffusionBackend:
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive")
# ext: controlnet/ip/t2i [pre_unet_forward]
ext_manager.modifiers.pre_unet_forward(ctx)
ext_manager.callbacks.pre_unet_forward(ctx, ext_manager)
# ext: inpaint [pre_unet_forward, priority=low]
# or

View File

@ -10,14 +10,14 @@ from diffusers import UNet2DConditionModel
class InjectionInfo:
type: str
name: str
order: Optional[str]
order: Optional[int]
function: Callable
def modifier(name: str, order: str = "any"):
def callback(name: str, order: int = 0):
def _decorator(func):
func.__inj_info__ = {
"type": "modifier",
"type": "callback",
"name": name,
"order": order,
}

View File

@ -15,29 +15,29 @@ if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
class ExtModifiersApi(ABC):
class ExtCallbacksApi(ABC):
@abstractmethod
def pre_denoise_loop(self, ctx: DenoiseContext):
def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
def post_denoise_loop(self, ctx: DenoiseContext):
def post_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
def pre_step(self, ctx: DenoiseContext):
def pre_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
def post_step(self, ctx: DenoiseContext):
def post_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
def modify_noise_prediction(self, ctx: DenoiseContext):
def modify_noise_prediction(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
def pre_unet_forward(self, ctx: DenoiseContext):
def pre_unet_forward(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
@ -51,7 +51,7 @@ class ExtOverridesApi(ABC):
pass
@abstractmethod
def combine_noise(self, orig_func: Callable, ctx: DenoiseContext):
def apply_cfg(self, orig_func: Callable, ctx: DenoiseContext):
pass
@ -63,27 +63,19 @@ class ProxyCallsClass:
return partial(self._handler, item)
class ModifierInjectionPoint:
class CallbackInjectionPoint:
def __init__(self):
self.first = []
self.any = []
self.last = []
self.handlers = {}
def add(self, func: Callable, order: str):
if order == "first":
self.first.append(func)
elif order == "last":
self.last.append(func)
else: # elif order == "any":
self.any.append(func)
def add(self, func: Callable, order: int):
if order not in self.handlers:
self.handlers[order] = []
self.handlers[order].append(func)
def __call__(self, *args, **kwargs):
for func in self.first:
func(*args, **kwargs)
for func in self.any:
func(*args, **kwargs)
for func in reversed(self.last):
func(*args, **kwargs)
for order in sorted(self.handlers.keys(), reverse=True):
for handler in self.handlers[order]:
handler(*args, **kwargs)
class ExtensionsManager:
@ -91,9 +83,9 @@ class ExtensionsManager:
self.extensions = []
self._overrides = {}
self._modifiers = {}
self._callbacks = {}
self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier)
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override)
def add_extension(self, ext: ExtensionBase):
@ -101,23 +93,23 @@ class ExtensionsManager:
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
self._overrides.clear()
self._modifiers.clear()
self._callbacks.clear()
for ext in ordered_extensions:
for inj_info in ext.injections:
if inj_info.type == "modifier":
if inj_info.name not in self._modifiers:
self._modifiers[inj_info.name] = ModifierInjectionPoint()
self._modifiers[inj_info.name].add(inj_info.function, inj_info.order)
if inj_info.type == "callback":
if inj_info.name not in self._callbacks:
self._callbacks[inj_info.name] = CallbackInjectionPoint()
self._callbacks[inj_info.name].add(inj_info.function, inj_info.order)
else:
if inj_info.name in self._overrides:
raise Exception(f"Already overloaded - {inj_info.name}")
self._overrides[inj_info.name] = inj_info.function
def call_modifier(self, name: str, *args, **kwargs):
if name in self._modifiers:
self._modifiers[name](*args, **kwargs)
def call_callback(self, name: str, *args, **kwargs):
if name in self._callbacks:
self._callbacks[name](*args, **kwargs)
def call_override(self, name: str, orig_func: Callable, *args, **kwargs):
if name in self._overrides: