diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 5312ba7fcd..2fbbc549fe 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -23,6 +23,7 @@ from invokeai.backend.flux.denoise import denoise from invokeai.backend.flux.inpaint_extension import InpaintExtension from invokeai.backend.flux.model import Flux from invokeai.backend.flux.sampling_utils import ( + clip_timestep_schedule, generate_img_ids, get_noise, get_schedule, @@ -62,6 +63,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): le=1, description=FieldDescriptions.denoising_start, ) + denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) transformer: TransformerField = InputField( description=FieldDescriptions.flux_model, input=Input.Connection, @@ -130,6 +132,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): shift=not is_schnell, ) + # Clip the timesteps schedule based on denoising_start and denoising_end. + timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end) + # Prepare input latent image. if init_latents is not None: # If init_latents is provided, we are doing image-to-image. @@ -140,11 +145,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): "to be poor. Consider using a FLUX dev model instead." ) - # Clip the timesteps schedule based on denoising_start. - # TODO(ryand): Should we apply denoising_start in timestep-space rather than timestep-index-space? - start_idx = int(self.denoising_start * len(timesteps)) - timesteps = timesteps[start_idx:] - # Noise the orig_latents by the appropriate amount for the first timestep. 
t_0 = timesteps[0] x = t_0 * noise + (1.0 - t_0) * init_latents @@ -155,6 +155,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): x = noise + # If len(timesteps) <= 1, then short-circuit. We are just noising the input latents, but not taking any + # denoising steps. + if len(timesteps) <= 1: + return x + inpaint_mask = self._prep_inpaint_mask(context, x) b, _c, h, w = x.shape diff --git a/invokeai/backend/flux/sampling_utils.py b/invokeai/backend/flux/sampling_utils.py index 4be15c491f..92c78de01e 100644 --- a/invokeai/backend/flux/sampling_utils.py +++ b/invokeai/backend/flux/sampling_utils.py @@ -59,6 +59,44 @@ def get_schedule( return timesteps.tolist() +def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int: + """Find the last index in timesteps that is >= val. + + We use epsilon-close equality to avoid potential floating point errors. + """ + idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1 + assert idx >= 0 + return idx + + +def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]: + """Clip the timestep schedule to the denoising range. + + Args: + timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0]. + denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2 + would mean that the denoising process starts at the last timestep in the schedule >= 0.8. + denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would + mean that the denoising process ends at the last timestep in the schedule >= 0.2. + + Returns: + list[float]: The clipped timestep schedule. 
+ """ + assert 0.0 <= denoising_start <= 1.0 + assert 0.0 <= denoising_end <= 1.0 + assert denoising_start <= denoising_end + + t_start_val = 1.0 - denoising_start + t_end_val = 1.0 - denoising_end + + t_start_idx = _find_last_index_ge_val(timesteps, t_start_val) + t_end_idx = _find_last_index_ge_val(timesteps, t_end_val) + + clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1] + + return clipped_timesteps + + def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor: """Unpack flat array of patch embeddings to latent image.""" return rearrange( diff --git a/tests/backend/flux/test_sampling_utils.py b/tests/backend/flux/test_sampling_utils.py new file mode 100644 index 0000000000..f7264492d5 --- /dev/null +++ b/tests/backend/flux/test_sampling_utils.py @@ -0,0 +1,42 @@ +import pytest +import torch + +from invokeai.backend.flux.sampling_utils import clip_timestep_schedule + + +def float_lists_almost_equal(list1: list[float], list2: list[float], tol: float = 1e-6) -> bool: + return all(abs(a - b) < tol for a, b in zip(list1, list2, strict=True)) + + +@pytest.mark.parametrize( + ["denoising_start", "denoising_end", "expected_timesteps", "raises"], + [ + (0.0, 1.0, [1.0, 0.75, 0.5, 0.25, 0.0], False), # Default case. + (-0.1, 1.0, [], True), # Negative denoising_start should raise. + (0.0, 1.1, [], True), # denoising_end > 1 should raise. + (0.5, 0.0, [], True), # denoising_start > denoising_end should raise. + (0.0, 0.0, [1.0], False), # denoising_end == 0. + (1.0, 1.0, [0.0], False), # denoising_start == 1. + (0.2, 0.8, [1.0, 0.75, 0.5, 0.25], False), # Middle of the schedule. + # If we denoise from 0.0 to x, then from x to 1.0, it is important that denoise_end = x and denoise_start = x + # map to the same timestep. We test this first when x is equal to a timestep, then when it falls between two + # timesteps. 
+ # x = 0.5 + (0.0, 0.5, [1.0, 0.75, 0.5], False), + (0.5, 1.0, [0.5, 0.25, 0.0], False), + # x = 0.3 + (0.0, 0.3, [1.0, 0.75], False), + (0.3, 1.0, [0.75, 0.5, 0.25, 0.0], False), + ], +) +def test_clip_timestep_schedule( + denoising_start: float, denoising_end: float, expected_timesteps: list[float], raises: bool +): + timesteps = torch.linspace(1, 0, 5).tolist() + if raises: + with pytest.raises(AssertionError): + clip_timestep_schedule(timesteps, denoising_start, denoising_end) + else: + assert float_lists_almost_equal( + clip_timestep_schedule(timesteps, denoising_start, denoising_end), expected_timesteps + )