mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add preview extension to check logic
This commit is contained in:
parent
e961dd1dec
commit
499e4d4fde
@ -57,6 +57,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||
from invokeai.backend.stable_diffusion.extensions import PreviewExt
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
@ -777,6 +778,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
### preview
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
|
||||
ext_manager.add_extension(PreviewExt(step_callback))
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
|
@ -23,19 +23,19 @@ from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState
|
||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
step: int
|
||||
order: int
|
||||
total_steps: int
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
# @dataclass
|
||||
# class PipelineIntermediateState:
|
||||
# step: int
|
||||
# order: int
|
||||
# total_steps: int
|
||||
# timestep: int
|
||||
# latents: torch.Tensor
|
||||
# predicted_original: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -3,7 +3,10 @@ Initialization file for the invokeai.backend.stable_diffusion.extensions package
|
||||
"""
|
||||
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState, PreviewExt
|
||||
|
||||
__all__ = [
|
||||
"ExtensionBase",
|
||||
"PipelineIntermediateState",
|
||||
"PreviewExt",
|
||||
]
|
||||
|
63
invokeai/backend/stable_diffusion/extensions/preview.py
Normal file
63
invokeai/backend/stable_diffusion/extensions/preview.py
Normal file
@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
|
||||
|
||||
# TODO: change event to accept image instead of latents
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
step: int
|
||||
order: int
|
||||
total_steps: int
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class PreviewExt(ExtensionBase):
|
||||
def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
|
||||
super().__init__()
|
||||
self.callback = callback
|
||||
|
||||
# do last so that all other changes shown
|
||||
@callback("pre_denoise_loop", order=1000)
|
||||
def initial_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
self.callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=ctx.scheduler.order,
|
||||
total_steps=len(ctx.timesteps),
|
||||
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
|
||||
latents=ctx.latents,
|
||||
)
|
||||
)
|
||||
|
||||
# do last so that all other changes shown
|
||||
@callback("post_step", order=1000)
|
||||
def step_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
if hasattr(ctx.step_output, "denoised"):
|
||||
predicted_original = ctx.step_output.denoised
|
||||
elif hasattr(ctx.step_output, "pred_original_sample"):
|
||||
predicted_original = ctx.step_output.pred_original_sample
|
||||
else:
|
||||
predicted_original = ctx.step_output.prev_sample
|
||||
|
||||
self.callback(
|
||||
PipelineIntermediateState(
|
||||
step=ctx.step_index,
|
||||
order=ctx.scheduler.order,
|
||||
total_steps=len(ctx.timesteps),
|
||||
timestep=int(ctx.timestep), # TODO: is there any code which uses it?
|
||||
latents=ctx.step_output.prev_sample,
|
||||
predicted_original=predicted_original, # TODO: is there any reason for additional field?
|
||||
)
|
||||
)
|
Loading…
Reference in New Issue
Block a user