feat[web]: use the predicted denoised image for previews (#2915)

Some schedulers report not only the noisy latents at the current
timestep, but also their estimate so far of what the de-noised latents
will be.

That estimate makes for a far more legible preview than the noisy latents do.
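For context, this is where the estimate comes from in the diffusers-style schedulers InvokeAI runs on: `step()` returns not just the next noisy sample but, for many schedulers, a `pred_original_sample` field. A minimal sketch against the `diffusers` API, with a random tensor standing in for the UNet's noise prediction (illustrative, not code from this PR):

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.set_timesteps(50)

latents = torch.randn(1, 4, 64, 64)     # current noisy latents
noise_pred = torch.randn(1, 4, 64, 64)  # stand-in for the UNet's output

out = scheduler.step(noise_pred, scheduler.timesteps[0], latents)
print(out.prev_sample.shape)           # the still-noisy latents for the next step
print(out.pred_original_sample.shape)  # the scheduler's estimate of the fully
                                       # de-noised latents -- what the new preview renders
```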

I think this is a huge improvement, but there are a few considerations:
- Need to not spook @JPPhoto by changing how previews look.
- Some schedulers (most notably **DPM Solver++**) don't provide this
data, so the preview falls back to the current behavior for them (see
the probe sketched after this list). That's not terrible, but seeing
such a big difference in how _previews_ look from one scheduler to the
next might mislead people into thinking there's a bigger difference in
the schedulers' overall effectiveness than there really is.
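To make that scheduler-support caveat concrete, here is one way to probe whether a scheduler reports the estimate. A hedged sketch against the `diffusers` API; the helper name `reports_estimate` is mine, not from this PR:

```python
import torch
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler

def reports_estimate(scheduler) -> bool:
    """Run one scheduler step on dummy tensors and check for the estimate."""
    scheduler.set_timesteps(10)
    sample = torch.randn(1, 4, 8, 8)
    noise_pred = torch.randn_like(sample)
    out = scheduler.step(noise_pred, scheduler.timesteps[0], sample)
    return getattr(out, "pred_original_sample", None) is not None

print(reports_estimate(DDIMScheduler()))               # True
print(reports_estimate(DPMSolverMultistepScheduler())) # False -> preview falls back
```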

My fear of configuration-option-overwhelm leaves me inclined to _not_
add a configuration option for this, but we could.
commit 076fac07eb
Lincoln Stein, 2023-03-26 00:29:00 -04:00, committed by GitHub
2 changed files with 32 additions and 38 deletions


@@ -1,17 +1,13 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 from datetime import datetime, timezone
-from typing import Any, Literal, Optional, Union
+from functools import partial
+from typing import Literal, Optional, Union
 import numpy as np
-from torch import Tensor
 from PIL import Image
 from pydantic import Field
-from skimage.exposure.histogram_matching import match_histograms
 from ..services.image_storage import ImageType
 from ..services.invocation_services import InvocationServices
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
 from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
+from ...backend.stable_diffusion import PipelineIntermediateState
@@ -45,24 +41,27 @@ class TextToImageInvocation(BaseInvocation):
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-        self, context: InvocationContext, sample: Tensor, step: int
-    ) -> None:
-        # TODO: only output a preview image when requested
-        image = Generator.sample_to_lowres_estimated_image(sample)
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
+    ) -> None:
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents
+        image = Generator(context.services.model_manager.get_model()).sample_to_image(sample)
         (width, height) = image.size
         width *= 8
         height *= 8
         dataURL = image_to_dataURL(image, image_format="JPEG")
         context.services.events.emit_generator_progress(
             context.graph_execution_state_id,
             self.id,
             {
-                "width": width,
+                "width" : width,
                 "height": height,
-                "dataURL": dataURL
+                "dataURL": dataURL,
             },
             step,
             self.steps,
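A note on the unchanged `width *= 8` / `height *= 8` lines above: a preview estimated directly from latents is latent-sized, and Stable Diffusion's VAE downsamples each side by a factor of 8, so the latent dimensions are multiplied by 8 to report the pixel dimensions of the final image. In the same spirit, the cheap projection that `sample_to_lowres_estimated_image` performs can be sketched roughly as follows (the coefficients are the commonly circulated SD v1 latent-to-RGB approximation; treat them and the normalization as illustrative):

```python
import torch

# Fixed linear map from the 4 SD latent channels to RGB -- a cheap preview
# that avoids running the full VAE decoder.
LATENT_RGB = torch.tensor([
    [ 0.298,  0.207,  0.208],
    [ 0.187,  0.286,  0.173],
    [-0.158,  0.189,  0.264],
    [-0.184, -0.271, -0.473],
])

def lowres_preview(latents: torch.Tensor) -> torch.Tensor:
    # latents: (1, 4, H/8, W/8) -> (H/8, W/8, 3) RGB in [0, 1]
    rgb = torch.einsum("chw,cr->rhw", latents[0], LATENT_RGB)
    return ((rgb + 1) / 2).clamp(0, 1).permute(1, 2, 0)
```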
@@ -79,7 +78,7 @@ class TextToImageInvocation(BaseInvocation):
         model= context.services.model_manager.get_model()
         outputs = Txt2Img(model).generate(
             prompt=self.prompt,
-            step_callback=step_callback,
+            step_callback=partial(self.dispatch_progress, context),
             **self.dict(
                 exclude={"prompt"}
             ),  # Shorthand for passing all of the parameters above manually
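The `partial(self.dispatch_progress, context)` pattern above replaces the small `step_callback` closures each invocation used to define. `functools.partial` freezes the leading arguments, leaving a one-argument callable that the generator can invoke with just the intermediate state. A self-contained sketch (the `"session-123"` context value is a hypothetical stand-in):

```python
from functools import partial
from types import SimpleNamespace

class Invocation:
    def dispatch_progress(self, context, intermediate_state):
        print(f"[{context}] step {intermediate_state.step}")

inv = Invocation()
# The bound method already carries `self`; partial() additionally binds
# `context`, so the pipeline only supplies the state when it fires:
step_callback = partial(inv.dispatch_progress, "session-123")
step_callback(SimpleNamespace(step=7))  # prints: [session-123] step 7
```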
@@ -126,9 +125,6 @@ class ImageToImageInvocation(TextToImageInvocation):
         )
         mask = None

-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)

         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
@@ -138,7 +134,7 @@ class ImageToImageInvocation(TextToImageInvocation):
             prompt=self.prompt,
             init_image=image,
             init_mask=mask,
-            step_callback=step_callback,
+            step_callback=partial(self.dispatch_progress, context),
             **self.dict(
                 exclude={"prompt", "image", "mask"}
             ),  # Shorthand for passing all of the parameters above manually
@@ -187,19 +183,16 @@ class InpaintInvocation(ImageToImageInvocation):
             else context.services.images.get(self.mask.image_type, self.mask.image_name)
         )

-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)

         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
-        manager = context.services.model_manager.get_model()
+        model = context.services.model_manager.get_model()
         generator_output = next(
             Inpaint(model).generate(
                 prompt=self.prompt,
-                init_image=image,
-                mask_image=mask,
-                step_callback=step_callback,
+                init_img=image,
+                init_mask=mask,
+                step_callback=partial(self.dispatch_progress, context),
                 **self.dict(
                     exclude={"prompt", "image", "mask"}
                 ),  # Shorthand for passing all of the parameters above manually


@@ -1022,7 +1022,7 @@ class InvokeAIWebServer:
                 "RGB"
             )

-            def image_progress(sample, step):
+            def image_progress(intermediate_state: PipelineIntermediateState):
                 if self.canceled.is_set():
                     raise CanceledException
@@ -1030,6 +1030,14 @@ class InvokeAIWebServer:
                 nonlocal generation_parameters
                 nonlocal progress

+                step = intermediate_state.step
+                if intermediate_state.predicted_original is not None:
+                    # Some schedulers report not only the noisy latents at the current timestep,
+                    # but also their estimate so far of what the de-noised latents will be.
+                    sample = intermediate_state.predicted_original
+                else:
+                    sample = intermediate_state.latents
+
                 generation_messages = {
                     "txt2img": "common.statusGeneratingTextToImage",
                     "img2img": "common.statusGeneratingImageToImage",
@@ -1302,16 +1310,9 @@ class InvokeAIWebServer:
                 progress.set_current_iteration(progress.current_iteration + 1)

-                def diffusers_step_callback_adapter(*cb_args, **kwargs):
-                    if isinstance(cb_args[0], PipelineIntermediateState):
-                        progress_state: PipelineIntermediateState = cb_args[0]
-                        return image_progress(progress_state.latents, progress_state.step)
-                    else:
-                        return image_progress(*cb_args, **kwargs)

                 self.generate.prompt2image(
                     **generation_parameters,
-                    step_callback=diffusers_step_callback_adapter,
+                    step_callback=image_progress,
                     image_callback=image_done,
                 )
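With every generator now delivering a `PipelineIntermediateState`, the `isinstance`-sniffing `diffusers_step_callback_adapter` above becomes dead code and the callback can be wired up directly. For reference, a minimal sketch of the state object as this diff uses it; the real class in InvokeAI's backend presumably carries more fields (e.g. a run id and the raw timestep):

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class IntermediateState:  # stand-in for PipelineIntermediateState
    step: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None  # None for e.g. DPM Solver++

def image_progress(state: IntermediateState) -> None:
    # Prefer the de-noised estimate; fall back to the noisy latents.
    sample = state.predicted_original if state.predicted_original is not None else state.latents
    ...  # decode `sample` to a preview image and emit a progress event
```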