2024-07-12 21:45:04 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import TYPE_CHECKING, Callable, Optional
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
2024-07-18 20:49:44 +00:00
|
|
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
2024-07-12 21:45:04 +00:00
|
|
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: change event to accept image instead of latents
|
|
|
|
@dataclass
class PipelineIntermediateState:
    """Progress snapshot passed to a denoising preview callback.

    ``PreviewExt`` in this module builds one of these before the denoise loop
    (with ``step=-1``) and after every scheduler step.
    """

    # Index of the current step; -1 is used for the pre-loop snapshot.
    step: int
    # Scheduler order, copied from ``scheduler.order``.
    order: int
    # Number of scheduled timesteps (``len(inputs.timesteps)``).
    total_steps: int
    # Timestep value associated with this snapshot, stored as a plain int.
    timestep: int
    # Current latents.
    latents: torch.Tensor
    # Optional extra tensor for the preview; presumably the scheduler's estimate
    # of the fully denoised latents — confirm against consumers. Defaults to None.
    predicted_original: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
|
|
|
|
class PreviewExt(ExtensionBase):
    """Extension that reports denoising progress through a user-supplied callback.

    ``callback`` receives a ``PipelineIntermediateState`` once before the
    denoise loop starts and once after every scheduler step.
    """

    def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
        """
        Args:
            callback: Called with a ``PipelineIntermediateState`` describing the
                current progress. Its return value is ignored.
        """
        super().__init__()
        self.callback = callback

    # do last so that all other changes shown
    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def initial_preview(self, ctx: DenoiseContext):
        """Report the state before the first denoising step (``step=-1``)."""
        self.callback(
            PipelineIntermediateState(
                step=-1,
                order=ctx.scheduler.order,
                total_steps=len(ctx.inputs.timesteps),
                timestep=int(ctx.scheduler.config.num_train_timesteps),  # TODO: is there any code which uses it?
                latents=ctx.latents,
            )
        )

    # do last so that all other changes shown
    @callback(ExtensionCallbackType.POST_STEP, order=1000)
    def step_preview(self, ctx: DenoiseContext):
        """Report the state after a scheduler step.

        ``predicted_original`` is the first non-None value among the step
        output's ``denoised``, ``pred_original_sample`` and ``prev_sample``
        attributes. An attribute that exists but holds ``None`` is skipped
        rather than propagated; a plain ``hasattr`` check would pass it through.
        """
        step_output = ctx.step_output

        predicted_original = None
        for attr_name in ("denoised", "pred_original_sample", "prev_sample"):
            predicted_original = getattr(step_output, attr_name, None)
            if predicted_original is not None:
                break

        self.callback(
            PipelineIntermediateState(
                step=ctx.step_index,
                order=ctx.scheduler.order,
                total_steps=len(ctx.inputs.timesteps),
                timestep=int(ctx.timestep),  # TODO: is there any code which uses it?
                latents=step_output.prev_sample,
                predicted_original=predicted_original,  # TODO: is there any reason for additional field?
            )
        )
|