Ryan's suggested changes to extension manager/extensions

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov
2024-07-18 23:49:44 +03:00
parent 710dc6b487
commit 0c56d4a581
6 changed files with 79 additions and 109 deletions

View File

@ -8,6 +8,7 @@ from tqdm.auto import tqdm
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@ -41,23 +42,23 @@ class StableDiffusionBackend:
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
# ext: preview[pre_denoise_loop, priority=low]
ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)
ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
# ext: inpaint (apply mask to latents on non-inpaint models)
ext_manager.callbacks.pre_step(ctx, ext_manager)
ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)
# ext: tiles? [override: step]
ctx.step_output = self.step(ctx, ext_manager)
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
# ext: preview[post_step, priority=low]
ext_manager.callbacks.post_step(ctx, ext_manager)
ext_manager.run_callback(ExtensionCallbackType.POST_STEP, ctx)
ctx.latents = ctx.step_output.prev_sample
# ext: inpaint[post_denoise_loop] (restore unmasked part)
ext_manager.callbacks.post_denoise_loop(ctx, ext_manager)
ext_manager.run_callback(ExtensionCallbackType.POST_DENOISE_LOOP, ctx)
return ctx.latents
@torch.inference_mode()
@ -80,7 +81,7 @@ class StableDiffusionBackend:
# ext: cfg_rescale [modify_noise_prediction]
# TODO: rename
ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
# compute the previous noisy sample x_t -> x_t-1
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
@ -120,14 +121,14 @@ class StableDiffusionBackend:
ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
# ext: controlnet/ip/t2i [pre_unet]
ext_manager.callbacks.pre_unet(ctx, ext_manager)
ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)
# ext: inpaint [pre_unet, priority=low]
# or
# ext: inpaint [override: unet_forward]
noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
ext_manager.callbacks.post_unet(ctx, ext_manager)
ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)
# clean up locals
ctx.unet_kwargs = None