Add denoise_end param to FluxDenoiseInvocation.

Ryan Dick 2024-08-30 19:13:20 +00:00
parent 661c9db7ac
commit 6675aaba4c
3 changed files with 90 additions and 5 deletions


@@ -23,6 +23,7 @@ from invokeai.backend.flux.denoise import denoise
 from invokeai.backend.flux.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.sampling_utils import (
+    clip_timestep_schedule,
     generate_img_ids,
     get_noise,
     get_schedule,
@ -62,6 +63,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
le=1, le=1,
description=FieldDescriptions.denoising_start, description=FieldDescriptions.denoising_start,
) )
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
transformer: TransformerField = InputField( transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model, description=FieldDescriptions.flux_model,
input=Input.Connection, input=Input.Connection,
@@ -130,6 +132,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             shift=not is_schnell,
         )

+        # Clip the timesteps schedule based on denoising_start and denoising_end.
+        timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)
+
         # Prepare input latent image.
         if init_latents is not None:
             # If init_latents is provided, we are doing image-to-image.
@@ -140,11 +145,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                     "to be poor. Consider using a FLUX dev model instead."
                 )

-            # Clip the timesteps schedule based on denoising_start.
-            # TODO(ryand): Should we apply denoising_start in timestep-space rather than timestep-index-space?
-            start_idx = int(self.denoising_start * len(timesteps))
-            timesteps = timesteps[start_idx:]
-
             # Noise the orig_latents by the appropriate amount for the first timestep.
             t_0 = timesteps[0]
             x = t_0 * noise + (1.0 - t_0) * init_latents
@@ -155,6 +155,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             x = noise

+        # If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
+        # denoising steps.
+        if len(timesteps) <= 1:
+            return x
+
         inpaint_mask = self._prep_inpaint_mask(context, x)

         b, _c, h, w = x.shape
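
For orientation, the sketch below (not InvokeAI code) walks through the control flow these hunks add: clip the schedule to the requested denoising range, noise the init latents to the first remaining timestep, and return early when only one timestep is left. It assumes the invokeai package is importable; the latent shape is made up for illustration.

import torch

from invokeai.backend.flux.sampling_utils import clip_timestep_schedule

schedule = [1.0, 0.75, 0.5, 0.25, 0.0]

# Image-to-image starting halfway through the schedule.
timesteps = clip_timestep_schedule(schedule, denoising_start=0.5, denoising_end=1.0)  # [0.5, 0.25, 0.0]
init_latents = torch.zeros(1, 16, 32, 32)  # hypothetical packed-latent shape
noise = torch.randn_like(init_latents)
t_0 = timesteps[0]
x = t_0 * noise + (1.0 - t_0) * init_latents  # 50/50 mix at t_0 == 0.5; the denoise loop would run from here

# denoising_start == denoising_end leaves a single timestep, so the invocation
# short-circuits: it returns the noised latents without taking any denoising steps.
timesteps = clip_timestep_schedule(schedule, denoising_start=0.0, denoising_end=0.0)  # [1.0]
assert len(timesteps) <= 1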


@@ -59,6 +59,44 @@ def get_schedule(
     return timesteps.tolist()


+def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
+    """Find the last index in timesteps that is >= val.
+
+    We use epsilon-close equality to avoid potential floating point errors.
+    """
+    idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
+    assert idx >= 0
+    return idx
+
+
+def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
+    """Clip the timestep schedule to the denoising range.
+
+    Args:
+        timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
+        denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
+            would mean that the denoising process starts at the last timestep in the schedule that is >= 0.8.
+        denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8
+            would mean that the denoising process ends at the last timestep in the schedule that is >= 0.2.
+
+    Returns:
+        list[float]: The clipped timestep schedule.
+    """
+    assert 0.0 <= denoising_start <= 1.0
+    assert 0.0 <= denoising_end <= 1.0
+    assert denoising_start <= denoising_end
+
+    t_start_val = 1.0 - denoising_start
+    t_end_val = 1.0 - denoising_end
+
+    t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
+    t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)
+
+    clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
+    return clipped_timesteps
+
+
 def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
     """Unpack flat array of patch embeddings to latent image."""
     return rearrange(
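
A quick sanity check of the semantics (the example values come straight from the new tests below); this assumes the invokeai package is importable:

from invokeai.backend.flux.sampling_utils import clip_timestep_schedule

schedule = [1.0, 0.75, 0.5, 0.25, 0.0]

clip_timestep_schedule(schedule, denoising_start=0.0, denoising_end=1.0)  # [1.0, 0.75, 0.5, 0.25, 0.0] (full range)
clip_timestep_schedule(schedule, denoising_start=0.2, denoising_end=0.8)  # [1.0, 0.75, 0.5, 0.25]
clip_timestep_schedule(schedule, denoising_start=1.0, denoising_end=1.0)  # [0.0]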


@@ -0,0 +1,42 @@
+import pytest
+import torch
+
+from invokeai.backend.flux.sampling_utils import clip_timestep_schedule
+
+
+def float_lists_almost_equal(list1: list[float], list2: list[float], tol: float = 1e-6) -> bool:
+    return all(abs(a - b) < tol for a, b in zip(list1, list2, strict=True))
+
+
+@pytest.mark.parametrize(
+    ["denoising_start", "denoising_end", "expected_timesteps", "raises"],
+    [
+        (0.0, 1.0, [1.0, 0.75, 0.5, 0.25, 0.0], False),  # Default case.
+        (-0.1, 1.0, [], True),  # Negative denoising_start should raise.
+        (0.0, 1.1, [], True),  # denoising_end > 1 should raise.
+        (0.5, 0.0, [], True),  # denoising_start > denoising_end should raise.
+        (0.0, 0.0, [1.0], False),  # denoising_end == 0.
+        (1.0, 1.0, [0.0], False),  # denoising_start == 1.
+        (0.2, 0.8, [1.0, 0.75, 0.5, 0.25], False),  # Middle of the schedule.
+        # If we denoise from 0.0 to x, then from x to 1.0, it is important that denoise_end = x and denoise_start = x
+        # map to the same timestep. We test this first when x is equal to a timestep, then when it falls between two
+        # timesteps.
+        # x = 0.5
+        (0.0, 0.5, [1.0, 0.75, 0.5], False),
+        (0.5, 1.0, [0.5, 0.25, 0.0], False),
+        # x = 0.3
+        (0.0, 0.3, [1.0, 0.75], False),
+        (0.3, 1.0, [0.75, 0.5, 0.25, 0.0], False),
+    ],
+)
+def test_clip_timestep_schedule(
+    denoising_start: float, denoising_end: float, expected_timesteps: list[float], raises: bool
+):
+    timesteps = torch.linspace(1, 0, 5).tolist()
+    if raises:
+        with pytest.raises(AssertionError):
+            clip_timestep_schedule(timesteps, denoising_start, denoising_end)
+    else:
+        assert float_lists_almost_equal(
+            clip_timestep_schedule(timesteps, denoising_start, denoising_end), expected_timesteps
+        )
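
The split-at-x cases above are the important ones: they check that a run divided into two passes at x shares a boundary timestep, so the second pass resumes exactly where the first stopped. A minimal sketch of that property (again assuming invokeai is importable), using the same 5-step schedule:

from invokeai.backend.flux.sampling_utils import clip_timestep_schedule

schedule = [1.0, 0.75, 0.5, 0.25, 0.0]
x = 0.3  # split point that falls between two timesteps

first = clip_timestep_schedule(schedule, denoising_start=0.0, denoising_end=x)  # [1.0, 0.75]
second = clip_timestep_schedule(schedule, denoising_start=x, denoising_end=1.0)  # [0.75, 0.5, 0.25, 0.0]

# The first pass ends on the timestep the second pass starts from, so chaining the two
# passes covers the full schedule with no gap and no repeated step.
assert first[-1] == second[0]
assert first[:-1] + second == schedule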