feat: use the predicted denoised image for previews

Some schedulers report not only the noisy latents at the current timestep,
but also their estimate so far of what the de-noised latents will be.

It makes for a more legible preview than the noisy latents do.
This commit is contained in:
Kevin Turner
2023-03-09 20:28:06 -08:00
parent 12c7db3a16
commit fe6858f2d9
2 changed files with 23 additions and 22 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from functools import partial
from typing import Any, Literal, Optional, Union
import numpy as np
@ -12,6 +13,7 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.stable_diffusion import PipelineIntermediateState
SAMPLER_NAME_VALUES = Literal[
"ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
@ -41,8 +43,15 @@ class TextToImageInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Any = None, step: int = 0
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
step = intermediate_state.step
# if intermediate_state.predicted_original is not None:
# # Some schedulers report not only the noisy latents at the current timestep,
# # but also their estimate so far of what the de-noised latents will be.
# sample = intermediate_state.predicted_original
# else:
# sample = intermediate_state.latents
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
@ -51,9 +60,6 @@ class TextToImageInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
@ -65,7 +71,7 @@ class TextToImageInvocation(BaseInvocation):
results = context.services.generate.prompt2image(
prompt=self.prompt,
step_callback=step_callback,
step_callback=partial(self.dispatch_progress, context),
**self.dict(
exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually
@ -109,9 +115,6 @@ class ImageToImageInvocation(TextToImageInvocation):
)
mask = None
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
@ -125,7 +128,7 @@ class ImageToImageInvocation(TextToImageInvocation):
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
step_callback=partial(self.dispatch_progress, context),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
@ -174,9 +177,6 @@ class InpaintInvocation(ImageToImageInvocation):
else context.services.images.get(self.mask.image_type, self.mask.image_name)
)
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
@ -190,7 +190,7 @@ class InpaintInvocation(ImageToImageInvocation):
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
step_callback=partial(self.dispatch_progress, context),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually