InvokeAI/tests/backend/flux/test_sampling_utils.py

43 lines
1.8 KiB
Python

import pytest
import torch
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule
def float_lists_almost_equal(list1: list[float], list2: list[float], tol: float = 1e-6) -> bool:
return all(abs(a - b) < tol for a, b in zip(list1, list2, strict=True))
@pytest.mark.parametrize(
    ["denoising_start", "denoising_end", "expected_timesteps", "raises"],
    [
        pytest.param(0.0, 1.0, [1.0, 0.75, 0.5, 0.25, 0.0], False, id="default_full_range"),
        pytest.param(-0.1, 1.0, [], True, id="negative_start_raises"),
        pytest.param(0.0, 1.1, [], True, id="end_above_one_raises"),
        pytest.param(0.5, 0.0, [], True, id="start_after_end_raises"),
        pytest.param(0.0, 0.0, [1.0], False, id="end_equals_zero"),
        pytest.param(1.0, 1.0, [0.0], False, id="start_equals_one"),
        pytest.param(0.2, 0.8, [1.0, 0.75, 0.5, 0.25], False, id="middle_of_schedule"),
        # If we denoise from 0.0 to x, then from x to 1.0, it is important that denoising_end = x and
        # denoising_start = x map to the same timestep, so the two passes hand off cleanly. We test this first when x
        # is equal to a timestep, then when it falls between two timesteps.
        # x = 0.5 (coincides with a timestep)
        pytest.param(0.0, 0.5, [1.0, 0.75, 0.5], False, id="split_at_0.5_first_pass"),
        pytest.param(0.5, 1.0, [0.5, 0.25, 0.0], False, id="split_at_0.5_second_pass"),
        # x = 0.3 (between the 0.75 and 0.5 timesteps)
        pytest.param(0.0, 0.3, [1.0, 0.75], False, id="split_at_0.3_first_pass"),
        pytest.param(0.3, 1.0, [0.75, 0.5, 0.25, 0.0], False, id="split_at_0.3_second_pass"),
    ],
)
def test_clip_timestep_schedule(
    denoising_start: float, denoising_end: float, expected_timesteps: list[float], raises: bool
):
    """Check that clip_timestep_schedule keeps the expected slice of a 5-step schedule.

    Cases with ``raises=True`` expect the call to fail with AssertionError (their
    ``expected_timesteps`` is unused). All other cases compare the clipped schedule to
    ``expected_timesteps`` element-wise.
    """
    # 5-step schedule running from 1.0 down to 0.0: [1.0, 0.75, 0.5, 0.25, 0.0].
    timesteps = torch.linspace(1, 0, 5).tolist()

    if raises:
        # NOTE(review): AssertionError suggests the argument checks are `assert` statements, which are stripped
        # under `python -O`; these cases presumably rely on assertions being enabled — confirm.
        with pytest.raises(AssertionError):
            clip_timestep_schedule(timesteps, denoising_start, denoising_end)
        return

    clipped = clip_timestep_schedule(timesteps, denoising_start, denoising_end)
    failure_msg = f"expected {expected_timesteps}, got {clipped}"
    # Check the length first: float_lists_almost_equal raises ValueError on a length mismatch (strict zip),
    # which would hide the actual values behind an unrelated exception.
    assert len(clipped) == len(expected_timesteps), failure_msg
    assert float_lists_almost_equal(clipped, expected_timesteps), failure_msg