InvokeAI/invokeai/backend/stable_diffusion/extensions/preview.py

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional

import torch

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


# TODO: change event to accept image instead of latents
@dataclass
class PipelineIntermediateState:
    step: int
    order: int
    total_steps: int
    timestep: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None


class PreviewExt(ExtensionBase):
    def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
        super().__init__()
        self.callback = callback

    # Run last (order=1000) so the preview reflects all other extensions' changes.
    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def initial_preview(self, ctx: DenoiseContext):
        self.callback(
            PipelineIntermediateState(
                step=-1,
                order=ctx.scheduler.order,
                total_steps=len(ctx.inputs.timesteps),
                timestep=int(ctx.scheduler.config.num_train_timesteps),  # TODO: is there any code which uses it?
                latents=ctx.latents,
            )
        )

    # Run last (order=1000) so the preview reflects all other extensions' changes.
    @callback(ExtensionCallbackType.POST_STEP, order=1000)
    def step_preview(self, ctx: DenoiseContext):
        # Different schedulers expose the predicted clean sample under different
        # attribute names; fall back to the previous sample if neither is present.
        if hasattr(ctx.step_output, "denoised"):
            predicted_original = ctx.step_output.denoised
        elif hasattr(ctx.step_output, "pred_original_sample"):
            predicted_original = ctx.step_output.pred_original_sample
        else:
            predicted_original = ctx.step_output.prev_sample

        self.callback(
            PipelineIntermediateState(
                step=ctx.step_index,
                order=ctx.scheduler.order,
                total_steps=len(ctx.inputs.timesteps),
                timestep=int(ctx.timestep),  # TODO: is there any code which uses it?
                latents=ctx.step_output.prev_sample,
                predicted_original=predicted_original,  # TODO: is there any reason for additional field?
            )
        )
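

# --- Illustrative usage sketch (not part of the original module) ---
# PreviewExt only needs a callable that accepts a PipelineIntermediateState; the
# denoise pipeline invokes it once before the loop and again after every step.
# The snippet below is a minimal sketch of that callback contract using dummy
# values; how the extension is actually registered with the denoise pipeline is
# not shown here and is left as an assumption.
if __name__ == "__main__":

    def _print_progress(state: PipelineIntermediateState) -> None:
        # A real callback would typically decode state.latents (or
        # state.predicted_original) into a preview image and emit a progress event.
        print(f"step {state.step + 1}/{state.total_steps}")

    preview_ext = PreviewExt(_print_progress)

    # Mimic what the pipeline does after a denoising step, with placeholder values.
    preview_ext.callback(
        PipelineIntermediateState(
            step=0,
            order=1,
            total_steps=20,
            timestep=999,
            latents=torch.zeros(1, 4, 64, 64),
        )
    )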