Add preview extension to check logic

This commit is contained in:
Sergey Borisov 2024-07-13 00:45:04 +03:00
parent e961dd1dec
commit 499e4d4fde
4 changed files with 82 additions and 9 deletions

View File

@@ -57,6 +57,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extensions import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -777,6 +778,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
            scheduler=scheduler,
        )

        ### preview
        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, unet_config.base)

        ext_manager.add_extension(PreviewExt(step_callback))

        # get the unet's config so that we can pass the base to sd_step_callback()
        unet_config = context.models.get_config(self.unet.unet.key)
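
Note: step_callback closes over unet_config, which is assigned only after the function definition. This is safe because Python resolves closed-over names when the callback is invoked (later, during denoising), not when the function is defined. A minimal, purely illustrative sketch of that late-binding behavior (names below are hypothetical):

# Illustrative only: a closure looks up enclosing-scope names at call time,
# so assigning after the def (but before the first call) is fine.
def make_callback():
    def cb() -> str:
        return config  # "config" is resolved when cb() runs
    config = "assigned after cb was defined"
    return cb

assert make_callback()() == "assigned after cb was defined"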

View File

@@ -23,19 +23,19 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel

@dataclass
class PipelineIntermediateState:
    step: int
    order: int
    total_steps: int
    timestep: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None

# @dataclass
# class PipelineIntermediateState:
#     step: int
#     order: int
#     total_steps: int
#     timestep: int
#     latents: torch.Tensor
#     predicted_original: Optional[torch.Tensor] = None

@dataclass

View File

@@ -3,7 +3,10 @@ Initialization file for the invokeai.backend.stable_diffusion.extensions package
"""

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState, PreviewExt

__all__ = [
    "ExtensionBase",
    "PipelineIntermediateState",
    "PreviewExt",
]
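
With these re-exports, callers can import both names from the package root instead of the preview submodule; that is the form the denoise_latents hunk above uses:

from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState, PreviewExt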

View File

@@ -0,0 +1,63 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional

import torch

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


# TODO: change event to accept image instead of latents
@dataclass
class PipelineIntermediateState:
    step: int
    order: int
    total_steps: int
    timestep: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None

class PreviewExt(ExtensionBase):
    def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
        super().__init__()
        self.callback = callback

    # run last (order=1000) so that changes made by all other extensions are reflected in the preview
    @callback("pre_denoise_loop", order=1000)
    def initial_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
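        # step=-1: this previews the initial latents, before any denoising step has run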
        self.callback(
            PipelineIntermediateState(
                step=-1,
                order=ctx.scheduler.order,
                total_steps=len(ctx.timesteps),
                timestep=int(ctx.scheduler.config.num_train_timesteps),  # TODO: is there any code which uses it?
                latents=ctx.latents,
            )
        )

    # run last (order=1000) so that changes made by all other extensions are reflected in the preview
    @callback("post_step", order=1000)
    def step_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
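        # different schedulers expose the predicted original sample under different attribute names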
        if hasattr(ctx.step_output, "denoised"):
            predicted_original = ctx.step_output.denoised
        elif hasattr(ctx.step_output, "pred_original_sample"):
            predicted_original = ctx.step_output.pred_original_sample
        else:
            predicted_original = ctx.step_output.prev_sample

        self.callback(
            PipelineIntermediateState(
                step=ctx.step_index,
                order=ctx.scheduler.order,
                total_steps=len(ctx.timesteps),
                timestep=int(ctx.timestep),  # TODO: is there any code which uses it?
                latents=ctx.step_output.prev_sample,
                predicted_original=predicted_original,  # TODO: is there any reason for additional field?
            )
        )
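
ExtensionBase and the @callback decorator come from extensions/base.py, which this commit does not touch. A minimal sketch of how such a tag-and-collect registration mechanism might work, assuming the decorator simply stores an event name and ordering key on the method (all names and details below are hypothetical, not the actual InvokeAI implementation):

from typing import Callable, Dict, List


def callback(event: str, order: int = 0):
    """Tag a method as a handler for `event`; lower `order` runs earlier."""
    def decorator(func: Callable) -> Callable:
        func._callback_event = event  # hypothetical tag attributes
        func._callback_order = order
        return func
    return decorator


class ExtensionBase:
    def get_callbacks(self) -> Dict[str, List[Callable]]:
        """Collect tagged methods, grouped by event and sorted by order."""
        handlers: Dict[str, List[Callable]] = {}
        for name in dir(self):
            method = getattr(self, name)
            if callable(method) and hasattr(method, "_callback_event"):
                handlers.setdefault(method._callback_event, []).append(method)
        for hooks in handlers.values():
            # order=1000 on PreviewExt's handlers runs them after default (order=0) ones
            hooks.sort(key=lambda m: m._callback_order)
        return handlers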