mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add preview extension to check logic
This commit is contained in:
parent
e961dd1dec
commit
499e4d4fde
@ -57,6 +57,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||||
|
from invokeai.backend.stable_diffusion.extensions import PreviewExt
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||||
@ -777,6 +778,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### preview
|
||||||
|
def step_callback(state: PipelineIntermediateState) -> None:
|
||||||
|
context.util.sd_step_callback(state, unet_config.base)
|
||||||
|
|
||||||
|
ext_manager.add_extension(PreviewExt(step_callback))
|
||||||
|
|
||||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||||
unet_config = context.models.get_config(self.unet.unet.key)
|
unet_config = context.models.get_config(self.unet.unet.key)
|
||||||
|
|
||||||
|
@ -23,19 +23,19 @@ from invokeai.app.services.config.config_default import get_config
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||||
|
from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState
|
||||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||||
|
|
||||||
|
# @dataclass
|
||||||
@dataclass
|
# class PipelineIntermediateState:
|
||||||
class PipelineIntermediateState:
|
# step: int
|
||||||
step: int
|
# order: int
|
||||||
order: int
|
# total_steps: int
|
||||||
total_steps: int
|
# timestep: int
|
||||||
timestep: int
|
# latents: torch.Tensor
|
||||||
latents: torch.Tensor
|
# predicted_original: Optional[torch.Tensor] = None
|
||||||
predicted_original: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -3,7 +3,10 @@ Initialization file for the invokeai.backend.stable_diffusion.extensions package
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState, PreviewExt
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ExtensionBase",
|
"ExtensionBase",
|
||||||
|
"PipelineIntermediateState",
|
||||||
|
"PreviewExt",
|
||||||
]
|
]
|
||||||
|
63
invokeai/backend/stable_diffusion/extensions/preview.py
Normal file
63
invokeai/backend/stable_diffusion/extensions/preview.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: change event to accept image instead of latents
|
||||||
|
@dataclass
|
||||||
|
class PipelineIntermediateState:
|
||||||
|
step: int
|
||||||
|
order: int
|
||||||
|
total_steps: int
|
||||||
|
timestep: int
|
||||||
|
latents: torch.Tensor
|
||||||
|
predicted_original: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewExt(ExtensionBase):
|
||||||
|
def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
|
||||||
|
super().__init__()
|
||||||
|
self.callback = callback
|
||||||
|
|
||||||
|
# do last so that all other changes shown
|
||||||
|
@callback("pre_denoise_loop", order=1000)
|
||||||
|
def initial_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||||
|
self.callback(
|
||||||
|
PipelineIntermediateState(
|
||||||
|
step=-1,
|
||||||
|
order=ctx.scheduler.order,
|
||||||
|
total_steps=len(ctx.timesteps),
|
||||||
|
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
|
||||||
|
latents=ctx.latents,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# do last so that all other changes shown
|
||||||
|
@callback("post_step", order=1000)
|
||||||
|
def step_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||||
|
if hasattr(ctx.step_output, "denoised"):
|
||||||
|
predicted_original = ctx.step_output.denoised
|
||||||
|
elif hasattr(ctx.step_output, "pred_original_sample"):
|
||||||
|
predicted_original = ctx.step_output.pred_original_sample
|
||||||
|
else:
|
||||||
|
predicted_original = ctx.step_output.prev_sample
|
||||||
|
|
||||||
|
self.callback(
|
||||||
|
PipelineIntermediateState(
|
||||||
|
step=ctx.step_index,
|
||||||
|
order=ctx.scheduler.order,
|
||||||
|
total_steps=len(ctx.timesteps),
|
||||||
|
timestep=int(ctx.timestep), # TODO: is there any code which uses it?
|
||||||
|
latents=ctx.step_output.prev_sample,
|
||||||
|
predicted_original=predicted_original, # TODO: is there any reason for additional field?
|
||||||
|
)
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user