Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
feat: use the predicted denoised image for previews
Some schedulers report not only the noisy latents at the current timestep, but also their estimate so far of what the de-noised latents will be. It makes for a more legible preview than the noisy latents do.
parent 12c7db3a16
commit fe6858f2d9
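To make the idea concrete before the diff: a progress callback can prefer the scheduler's running estimate of the de-noised latents (exposed here as PipelineIntermediateState.predicted_original) and fall back to the noisy latents only when a scheduler does not report one. In the diffusers library, many schedulers expose such an estimate as pred_original_sample on their step output, which is presumably what populates predicted_original here. The sketch below is a minimal, self-contained illustration of that selection logic; the IntermediateState dataclass and the print-based "preview" are stand-ins, not the repository's actual types.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class IntermediateState:
    # Mirrors the three fields of PipelineIntermediateState that this commit uses.
    step: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None


def preview_callback(state: IntermediateState) -> None:
    # Prefer the scheduler's estimate of the final, de-noised latents;
    # fall back to the noisy latents for schedulers that do not provide one.
    sample = (
        state.predicted_original
        if state.predicted_original is not None
        else state.latents
    )
    print(f"step {state.step}: preview tensor shape {tuple(sample.shape)}")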
@@ -1,6 +1,7 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 from datetime import datetime, timezone
+from functools import partial
 from typing import Any, Literal, Optional, Union
 
 import numpy as np
@@ -12,6 +13,7 @@ from ..services.image_storage import ImageType
 from ..services.invocation_services import InvocationServices
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
+from ...backend.stable_diffusion import PipelineIntermediateState
 
 SAMPLER_NAME_VALUES = Literal[
     "ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
@@ -41,8 +43,15 @@ class TextToImageInvocation(BaseInvocation):
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-        self, context: InvocationContext, sample: Any = None, step: int = 0
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
     ) -> None:
+        step = intermediate_state.step
+        # if intermediate_state.predicted_original is not None:
+        #     # Some schedulers report not only the noisy latents at the current timestep,
+        #     # but also their estimate so far of what the de-noised latents will be.
+        #     sample = intermediate_state.predicted_original
+        # else:
+        #     sample = intermediate_state.latents
         context.services.events.emit_generator_progress(
             context.graph_execution_state_id,
             self.id,
@@ -51,9 +60,6 @@ class TextToImageInvocation(BaseInvocation):
         )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
-
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
@@ -65,7 +71,7 @@ class TextToImageInvocation(BaseInvocation):
 
         results = context.services.generate.prompt2image(
             prompt=self.prompt,
-            step_callback=step_callback,
+            step_callback=partial(self.dispatch_progress, context),
             **self.dict(
                 exclude={"prompt"}
             ),  # Shorthand for passing all of the parameters above manually
@@ -109,9 +115,6 @@ class ImageToImageInvocation(TextToImageInvocation):
         )
         mask = None
 
-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
-
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
@@ -125,7 +128,7 @@ class ImageToImageInvocation(TextToImageInvocation):
             prompt=self.prompt,
             init_img=image,
             init_mask=mask,
-            step_callback=step_callback,
+            step_callback=partial(self.dispatch_progress, context),
             **self.dict(
                 exclude={"prompt", "image", "mask"}
             ),  # Shorthand for passing all of the parameters above manually
@@ -174,9 +177,6 @@ class InpaintInvocation(ImageToImageInvocation):
             else context.services.images.get(self.mask.image_type, self.mask.image_name)
         )
 
-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
-
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
@@ -190,7 +190,7 @@ class InpaintInvocation(ImageToImageInvocation):
             prompt=self.prompt,
             init_img=image,
             init_mask=mask,
-            step_callback=step_callback,
+            step_callback=partial(self.dispatch_progress, context),
             **self.dict(
                 exclude={"prompt", "image", "mask"}
             ),  # Shorthand for passing all of the parameters above manually
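The hunks above replace each invocation's local step_callback closure with functools.partial(self.dispatch_progress, context), which captures the bound method and pre-fills context so the resulting one-argument callable receives the PipelineIntermediateState directly from the pipeline. A minimal sketch of that binding pattern, with an illustrative class standing in for the real invocation types:

from functools import partial


class ExampleInvocation:
    # Illustrative stand-in for the invocation classes in the diff above.
    def dispatch_progress(self, context: str, intermediate_state: int) -> None:
        # The real method receives a PipelineIntermediateState; an int stands in here.
        print(f"[{context}] progress at step {intermediate_state}")


invocation = ExampleInvocation()
# partial() pre-fills the context argument of the bound method, leaving a
# one-argument callable that matches what the pipeline passes to step_callback.
step_callback = partial(invocation.dispatch_progress, "graph-execution-123")
step_callback(5)  # prints: [graph-execution-123] progress at step 5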
@@ -1022,7 +1022,7 @@ class InvokeAIWebServer:
                 "RGB"
             )
 
-            def image_progress(sample, step):
+            def image_progress(intermediate_state: PipelineIntermediateState):
                 if self.canceled.is_set():
                     raise CanceledException
 
@@ -1030,6 +1030,14 @@ class InvokeAIWebServer:
                 nonlocal generation_parameters
                 nonlocal progress
 
+                step = intermediate_state.step
+                if intermediate_state.predicted_original is not None:
+                    # Some schedulers report not only the noisy latents at the current timestep,
+                    # but also their estimate so far of what the de-noised latents will be.
+                    sample = intermediate_state.predicted_original
+                else:
+                    sample = intermediate_state.latents
+
                 generation_messages = {
                     "txt2img": "common.statusGeneratingTextToImage",
                     "img2img": "common.statusGeneratingImageToImage",
@@ -1302,16 +1310,9 @@ class InvokeAIWebServer:
 
             progress.set_current_iteration(progress.current_iteration + 1)
 
-            def diffusers_step_callback_adapter(*cb_args, **kwargs):
-                if isinstance(cb_args[0], PipelineIntermediateState):
-                    progress_state: PipelineIntermediateState = cb_args[0]
-                    return image_progress(progress_state.latents, progress_state.step)
-                else:
-                    return image_progress(*cb_args, **kwargs)
-
             self.generate.prompt2image(
                 **generation_parameters,
-                step_callback=diffusers_step_callback_adapter,
+                step_callback=image_progress,
                 image_callback=image_done,
             )
 
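On the web-server side, image_progress now accepts the PipelineIntermediateState itself, so the diffusers_step_callback_adapter that previously unpacked it into (sample, step) is removed and the callback can choose the better preview source on its own. A compact sketch of the new callback shape, using a stand-in state class rather than the real PipelineIntermediateState:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class IntermediateState:
    # Stand-in for PipelineIntermediateState.
    step: int
    latents: List[float]
    predicted_original: Optional[List[float]] = None


def image_progress(state: IntermediateState) -> None:
    # Use the predicted de-noised latents for the preview when available.
    sample = state.predicted_original if state.predicted_original is not None else state.latents
    print(f"step {state.step}: rendering preview from {len(sample)} latent values")


image_progress(IntermediateState(step=3, latents=[0.4, 0.6], predicted_original=[0.0, 0.1]))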