mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat[web]: use the predicted denoised image for previews (#2915)
Some schedulers report not only the noisy latents at the current timestep, but also their estimate so far of what the de-noised latents will be. That estimate makes for a more legible preview than the noisy latents do.

I think this is a huge improvement, but there are a few considerations:

- Need to not spook @JPPhoto by changing how previews look.
- Some schedulers (most notably **DPM Solver++**) don't provide this data, so previews fall back to the current behavior there. That's not terrible, but seeing such a big difference in how _previews_ look from one scheduler to the next might mislead people into thinking there's a bigger difference in their overall effectiveness than there really is.

My fear of configuration-option-overwhelm leaves me inclined to _not_ add a configuration option for this, but we could.
This commit is contained in:
commit
076fac07eb
@ -1,17 +1,13 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from functools import partial
|
||||||
from typing import Any, Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from torch import Tensor
|
|
||||||
from PIL import Image
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from skimage.exposure.histogram_matching import match_histograms
|
|
||||||
|
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_storage import ImageType
|
||||||
from ..services.invocation_services import InvocationServices
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageField, ImageOutput
|
||||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
|
||||||
@ -45,24 +41,27 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self, context: InvocationContext, sample: Tensor, step: int
|
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: only output a preview image when requested
|
step = intermediate_state.step
|
||||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
if intermediate_state.predicted_original is not None:
|
||||||
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
|
# but also their estimate so far of what the de-noised latents will be.
|
||||||
|
|
||||||
|
sample = intermediate_state.predicted_original
|
||||||
|
else:
|
||||||
|
sample = intermediate_state.latents
|
||||||
|
|
||||||
|
image = Generator(context.services.model_manager.get_model()).sample_to_image(sample)
|
||||||
(width, height) = image.size
|
(width, height) = image.size
|
||||||
width *= 8
|
|
||||||
height *= 8
|
|
||||||
|
|
||||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||||
|
|
||||||
context.services.events.emit_generator_progress(
|
context.services.events.emit_generator_progress(
|
||||||
context.graph_execution_state_id,
|
context.graph_execution_state_id,
|
||||||
self.id,
|
self.id,
|
||||||
{
|
{
|
||||||
"width": width,
|
"width" : width,
|
||||||
"height": height,
|
"height": height,
|
||||||
"dataURL": dataURL
|
"dataURL": dataURL,
|
||||||
},
|
},
|
||||||
step,
|
step,
|
||||||
self.steps,
|
self.steps,
|
||||||
@ -79,7 +78,7 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
model= context.services.model_manager.get_model()
|
model= context.services.model_manager.get_model()
|
||||||
outputs = Txt2Img(model).generate(
|
outputs = Txt2Img(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
step_callback=step_callback,
|
step_callback=partial(self.dispatch_progress, context),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt"}
|
exclude={"prompt"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
@ -126,9 +125,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
)
|
)
|
||||||
mask = None
|
mask = None
|
||||||
|
|
||||||
def step_callback(sample, step=0):
|
|
||||||
self.dispatch_progress(context, sample, step)
|
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
@ -138,7 +134,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_image=image,
|
init_image=image,
|
||||||
init_mask=mask,
|
init_mask=mask,
|
||||||
step_callback=step_callback,
|
step_callback=partial(self.dispatch_progress, context),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt", "image", "mask"}
|
exclude={"prompt", "image", "mask"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
@ -187,19 +183,16 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
else context.services.images.get(self.mask.image_type, self.mask.image_name)
|
else context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
def step_callback(sample, step=0):
|
|
||||||
self.dispatch_progress(context, sample, step)
|
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
manager = context.services.model_manager.get_model()
|
model = context.services.model_manager.get_model()
|
||||||
generator_output = next(
|
generator_output = next(
|
||||||
Inpaint(model).generate(
|
Inpaint(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_image=image,
|
init_img=image,
|
||||||
mask_image=mask,
|
init_mask=mask,
|
||||||
step_callback=step_callback,
|
step_callback=partial(self.dispatch_progress, context),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt", "image", "mask"}
|
exclude={"prompt", "image", "mask"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
|
@ -1022,7 +1022,7 @@ class InvokeAIWebServer:
|
|||||||
"RGB"
|
"RGB"
|
||||||
)
|
)
|
||||||
|
|
||||||
def image_progress(sample, step):
|
def image_progress(intermediate_state: PipelineIntermediateState):
|
||||||
if self.canceled.is_set():
|
if self.canceled.is_set():
|
||||||
raise CanceledException
|
raise CanceledException
|
||||||
|
|
||||||
@ -1030,6 +1030,14 @@ class InvokeAIWebServer:
|
|||||||
nonlocal generation_parameters
|
nonlocal generation_parameters
|
||||||
nonlocal progress
|
nonlocal progress
|
||||||
|
|
||||||
|
step = intermediate_state.step
|
||||||
|
if intermediate_state.predicted_original is not None:
|
||||||
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
|
# but also their estimate so far of what the de-noised latents will be.
|
||||||
|
sample = intermediate_state.predicted_original
|
||||||
|
else:
|
||||||
|
sample = intermediate_state.latents
|
||||||
|
|
||||||
generation_messages = {
|
generation_messages = {
|
||||||
"txt2img": "common.statusGeneratingTextToImage",
|
"txt2img": "common.statusGeneratingTextToImage",
|
||||||
"img2img": "common.statusGeneratingImageToImage",
|
"img2img": "common.statusGeneratingImageToImage",
|
||||||
@ -1302,16 +1310,9 @@ class InvokeAIWebServer:
|
|||||||
|
|
||||||
progress.set_current_iteration(progress.current_iteration + 1)
|
progress.set_current_iteration(progress.current_iteration + 1)
|
||||||
|
|
||||||
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
|
||||||
if isinstance(cb_args[0], PipelineIntermediateState):
|
|
||||||
progress_state: PipelineIntermediateState = cb_args[0]
|
|
||||||
return image_progress(progress_state.latents, progress_state.step)
|
|
||||||
else:
|
|
||||||
return image_progress(*cb_args, **kwargs)
|
|
||||||
|
|
||||||
self.generate.prompt2image(
|
self.generate.prompt2image(
|
||||||
**generation_parameters,
|
**generation_parameters,
|
||||||
step_callback=diffusers_step_callback_adapter,
|
step_callback=image_progress,
|
||||||
image_callback=image_done,
|
image_callback=image_done,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user