mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename modifiers to callbacks, convert order to int, a bit unify injection points
This commit is contained in:
parent
0bc60378d3
commit
87e96e1be2
@ -781,7 +781,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager)
|
||||
ext_manager.callbacks.pre_unet_load(denoise_ctx, ext_manager)
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
|
@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from tqdm.auto import tqdm
|
||||
@ -12,30 +10,6 @@ from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UN
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
|
||||
|
||||
def trim_to_multiple_of(*args, multiple_of=8):
|
||||
return tuple((x - x % multiple_of) for x in args)
|
||||
|
||||
|
||||
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor:
|
||||
"""
|
||||
|
||||
:param image: input image
|
||||
:param normalize: scale the range to [-1, 1] instead of [0, 1]
|
||||
:param multiple_of: resize the input so both dimensions are a multiple of this
|
||||
"""
|
||||
w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
|
||||
transformation = T.Compose(
|
||||
[
|
||||
T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True),
|
||||
T.ToTensor(),
|
||||
]
|
||||
)
|
||||
tensor = transformation(image)
|
||||
if normalize:
|
||||
tensor = tensor * 2.0 - 1.0
|
||||
return tensor
|
||||
|
||||
|
||||
class StableDiffusionBackend:
|
||||
def __init__(
|
||||
self,
|
||||
@ -64,23 +38,23 @@ class StableDiffusionBackend:
|
||||
|
||||
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
|
||||
# ext: preview[pre_denoise_loop, priority=low]
|
||||
ext_manager.modifiers.pre_denoise_loop(ctx)
|
||||
ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)
|
||||
|
||||
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020
|
||||
# ext: inpaint (apply mask to latents on non-inpaint models)
|
||||
ext_manager.modifiers.pre_step(ctx)
|
||||
ext_manager.callbacks.pre_step(ctx, ext_manager)
|
||||
|
||||
# ext: tiles? [override: step]
|
||||
ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager)
|
||||
|
||||
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
|
||||
# ext: preview[post_step, priority=low]
|
||||
ext_manager.modifiers.post_step(ctx)
|
||||
ext_manager.callbacks.post_step(ctx, ext_manager)
|
||||
|
||||
ctx.latents = ctx.step_output.prev_sample
|
||||
|
||||
# ext: inpaint[post_denoise_loop] (restore unmasked part)
|
||||
ext_manager.modifiers.post_denoise_loop(ctx)
|
||||
ext_manager.callbacks.post_denoise_loop(ctx, ext_manager)
|
||||
return ctx.latents
|
||||
|
||||
@torch.inference_mode()
|
||||
@ -95,11 +69,11 @@ class StableDiffusionBackend:
|
||||
# not sure if here needed override
|
||||
ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager)
|
||||
|
||||
# ext: override combine_noise
|
||||
ctx.noise_pred = ext_manager.overrides.combine_noise(self.combine_noise, ctx)
|
||||
# ext: override apply_cfg
|
||||
ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx)
|
||||
|
||||
# ext: cfg_rescale [modify_noise_prediction]
|
||||
ext_manager.modifiers.modify_noise_prediction(ctx)
|
||||
ext_manager.callbacks.modify_noise_prediction(ctx, ext_manager)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs)
|
||||
@ -113,7 +87,7 @@ class StableDiffusionBackend:
|
||||
return step_output
|
||||
|
||||
@staticmethod
|
||||
def combine_noise(ctx: DenoiseContext) -> torch.Tensor:
|
||||
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
|
||||
guidance_scale = ctx.conditioning_data.guidance_scale
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance_scale = guidance_scale[ctx.step_index]
|
||||
@ -141,7 +115,7 @@ class StableDiffusionBackend:
|
||||
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
|
||||
|
||||
# ext: controlnet/ip/t2i [pre_unet_forward]
|
||||
ext_manager.modifiers.pre_unet_forward(ctx)
|
||||
ext_manager.callbacks.pre_unet_forward(ctx, ext_manager)
|
||||
|
||||
# ext: inpaint [pre_unet_forward, priority=low]
|
||||
# or
|
||||
@ -175,7 +149,7 @@ class StableDiffusionBackend:
|
||||
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative")
|
||||
|
||||
# ext: controlnet/ip/t2i [pre_unet_forward]
|
||||
ext_manager.modifiers.pre_unet_forward(ctx)
|
||||
ext_manager.callbacks.pre_unet_forward(ctx, ext_manager)
|
||||
|
||||
# ext: inpaint [pre_unet_forward, priority=low]
|
||||
# or
|
||||
@ -203,7 +177,7 @@ class StableDiffusionBackend:
|
||||
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive")
|
||||
|
||||
# ext: controlnet/ip/t2i [pre_unet_forward]
|
||||
ext_manager.modifiers.pre_unet_forward(ctx)
|
||||
ext_manager.callbacks.pre_unet_forward(ctx, ext_manager)
|
||||
|
||||
# ext: inpaint [pre_unet_forward, priority=low]
|
||||
# or
|
||||
|
@ -10,14 +10,14 @@ from diffusers import UNet2DConditionModel
|
||||
class InjectionInfo:
|
||||
type: str
|
||||
name: str
|
||||
order: Optional[str]
|
||||
order: Optional[int]
|
||||
function: Callable
|
||||
|
||||
|
||||
def modifier(name: str, order: str = "any"):
|
||||
def callback(name: str, order: int = 0):
|
||||
def _decorator(func):
|
||||
func.__inj_info__ = {
|
||||
"type": "modifier",
|
||||
"type": "callback",
|
||||
"name": name,
|
||||
"order": order,
|
||||
}
|
||||
|
@ -15,29 +15,29 @@ if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
|
||||
|
||||
|
||||
class ExtModifiersApi(ABC):
|
||||
class ExtCallbacksApi(ABC):
|
||||
@abstractmethod
|
||||
def pre_denoise_loop(self, ctx: DenoiseContext):
|
||||
def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_denoise_loop(self, ctx: DenoiseContext):
|
||||
def post_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_step(self, ctx: DenoiseContext):
|
||||
def pre_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_step(self, ctx: DenoiseContext):
|
||||
def post_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def modify_noise_prediction(self, ctx: DenoiseContext):
|
||||
def modify_noise_prediction(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_unet_forward(self, ctx: DenoiseContext):
|
||||
def pre_unet_forward(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -51,7 +51,7 @@ class ExtOverridesApi(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def combine_noise(self, orig_func: Callable, ctx: DenoiseContext):
|
||||
def apply_cfg(self, orig_func: Callable, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
|
||||
@ -63,27 +63,19 @@ class ProxyCallsClass:
|
||||
return partial(self._handler, item)
|
||||
|
||||
|
||||
class ModifierInjectionPoint:
|
||||
class CallbackInjectionPoint:
|
||||
def __init__(self):
|
||||
self.first = []
|
||||
self.any = []
|
||||
self.last = []
|
||||
self.handlers = {}
|
||||
|
||||
def add(self, func: Callable, order: str):
|
||||
if order == "first":
|
||||
self.first.append(func)
|
||||
elif order == "last":
|
||||
self.last.append(func)
|
||||
else: # elif order == "any":
|
||||
self.any.append(func)
|
||||
def add(self, func: Callable, order: int):
|
||||
if order not in self.handlers:
|
||||
self.handlers[order] = []
|
||||
self.handlers[order].append(func)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
for func in self.first:
|
||||
func(*args, **kwargs)
|
||||
for func in self.any:
|
||||
func(*args, **kwargs)
|
||||
for func in reversed(self.last):
|
||||
func(*args, **kwargs)
|
||||
for order in sorted(self.handlers.keys(), reverse=True):
|
||||
for handler in self.handlers[order]:
|
||||
handler(*args, **kwargs)
|
||||
|
||||
|
||||
class ExtensionsManager:
|
||||
@ -91,9 +83,9 @@ class ExtensionsManager:
|
||||
self.extensions = []
|
||||
|
||||
self._overrides = {}
|
||||
self._modifiers = {}
|
||||
self._callbacks = {}
|
||||
|
||||
self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier)
|
||||
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
|
||||
self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override)
|
||||
|
||||
def add_extension(self, ext: ExtensionBase):
|
||||
@ -101,23 +93,23 @@ class ExtensionsManager:
|
||||
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
|
||||
|
||||
self._overrides.clear()
|
||||
self._modifiers.clear()
|
||||
self._callbacks.clear()
|
||||
|
||||
for ext in ordered_extensions:
|
||||
for inj_info in ext.injections:
|
||||
if inj_info.type == "modifier":
|
||||
if inj_info.name not in self._modifiers:
|
||||
self._modifiers[inj_info.name] = ModifierInjectionPoint()
|
||||
self._modifiers[inj_info.name].add(inj_info.function, inj_info.order)
|
||||
if inj_info.type == "callback":
|
||||
if inj_info.name not in self._callbacks:
|
||||
self._callbacks[inj_info.name] = CallbackInjectionPoint()
|
||||
self._callbacks[inj_info.name].add(inj_info.function, inj_info.order)
|
||||
|
||||
else:
|
||||
if inj_info.name in self._overrides:
|
||||
raise Exception(f"Already overloaded - {inj_info.name}")
|
||||
self._overrides[inj_info.name] = inj_info.function
|
||||
|
||||
def call_modifier(self, name: str, *args, **kwargs):
|
||||
if name in self._modifiers:
|
||||
self._modifiers[name](*args, **kwargs)
|
||||
def call_callback(self, name: str, *args, **kwargs):
|
||||
if name in self._callbacks:
|
||||
self._callbacks[name](*args, **kwargs)
|
||||
|
||||
def call_override(self, name: str, orig_func: Callable, *args, **kwargs):
|
||||
if name in self._overrides:
|
||||
|
Loading…
Reference in New Issue
Block a user