feat: use the predicted denoised image for previews

Some schedulers report not only the noisy latents at the current timestep,
but also their estimate so far of what the de-noised latents will be.

It makes for a more legible preview than the noisy latents do.
This commit is contained in:
Kevin Turner 2023-03-09 20:28:06 -08:00
parent 12c7db3a16
commit fe6858f2d9
2 changed files with 23 additions and 22 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
import numpy as np import numpy as np
@ -12,6 +13,7 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageField, ImageOutput
from ...backend.stable_diffusion import PipelineIntermediateState
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[
"ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun" "ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
@ -41,8 +43,15 @@ class TextToImageInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching? # TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress( def dispatch_progress(
self, context: InvocationContext, sample: Any = None, step: int = 0 self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None: ) -> None:
step = intermediate_state.step
# if intermediate_state.predicted_original is not None:
# # Some schedulers report not only the noisy latents at the current timestep,
# # but also their estimate so far of what the de-noised latents will be.
# sample = intermediate_state.predicted_original
# else:
# sample = intermediate_state.latents
context.services.events.emit_generator_progress( context.services.events.emit_generator_progress(
context.graph_execution_state_id, context.graph_execution_state_id,
self.id, self.id,
@ -51,9 +60,6 @@ class TextToImageInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter # Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now? # TODO: How to get the default model name now?
@ -65,7 +71,7 @@ class TextToImageInvocation(BaseInvocation):
results = context.services.generate.prompt2image( results = context.services.generate.prompt2image(
prompt=self.prompt, prompt=self.prompt,
step_callback=step_callback, step_callback=partial(self.dispatch_progress, context),
**self.dict( **self.dict(
exclude={"prompt"} exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually ), # Shorthand for passing all of the parameters above manually
@ -109,9 +115,6 @@ class ImageToImageInvocation(TextToImageInvocation):
) )
mask = None mask = None
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter # Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now? # TODO: How to get the default model name now?
@ -125,7 +128,7 @@ class ImageToImageInvocation(TextToImageInvocation):
prompt=self.prompt, prompt=self.prompt,
init_img=image, init_img=image,
init_mask=mask, init_mask=mask,
step_callback=step_callback, step_callback=partial(self.dispatch_progress, context),
**self.dict( **self.dict(
exclude={"prompt", "image", "mask"} exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually ), # Shorthand for passing all of the parameters above manually
@ -174,9 +177,6 @@ class InpaintInvocation(ImageToImageInvocation):
else context.services.images.get(self.mask.image_type, self.mask.image_name) else context.services.images.get(self.mask.image_type, self.mask.image_name)
) )
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter # Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now? # TODO: How to get the default model name now?
@ -190,7 +190,7 @@ class InpaintInvocation(ImageToImageInvocation):
prompt=self.prompt, prompt=self.prompt,
init_img=image, init_img=image,
init_mask=mask, init_mask=mask,
step_callback=step_callback, step_callback=partial(self.dispatch_progress, context),
**self.dict( **self.dict(
exclude={"prompt", "image", "mask"} exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually ), # Shorthand for passing all of the parameters above manually

View File

@ -1022,7 +1022,7 @@ class InvokeAIWebServer:
"RGB" "RGB"
) )
def image_progress(sample, step): def image_progress(intermediate_state: PipelineIntermediateState):
if self.canceled.is_set(): if self.canceled.is_set():
raise CanceledException raise CanceledException
@ -1030,6 +1030,14 @@ class InvokeAIWebServer:
nonlocal generation_parameters nonlocal generation_parameters
nonlocal progress nonlocal progress
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
generation_messages = { generation_messages = {
"txt2img": "common.statusGeneratingTextToImage", "txt2img": "common.statusGeneratingTextToImage",
"img2img": "common.statusGeneratingImageToImage", "img2img": "common.statusGeneratingImageToImage",
@ -1302,16 +1310,9 @@ class InvokeAIWebServer:
progress.set_current_iteration(progress.current_iteration + 1) progress.set_current_iteration(progress.current_iteration + 1)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return image_progress(progress_state.latents, progress_state.step)
else:
return image_progress(*cb_args, **kwargs)
self.generate.prompt2image( self.generate.prompt2image(
**generation_parameters, **generation_parameters,
step_callback=diffusers_step_callback_adapter, step_callback=image_progress,
image_callback=image_done, image_callback=image_done,
) )