mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add denoising_end param to FluxDenoiseInvocation.
This commit is contained in:
parent
661c9db7ac
commit
6675aaba4c
@ -23,6 +23,7 @@ from invokeai.backend.flux.denoise import denoise
|
|||||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||||
from invokeai.backend.flux.model import Flux
|
from invokeai.backend.flux.model import Flux
|
||||||
from invokeai.backend.flux.sampling_utils import (
|
from invokeai.backend.flux.sampling_utils import (
|
||||||
|
clip_timestep_schedule,
|
||||||
generate_img_ids,
|
generate_img_ids,
|
||||||
get_noise,
|
get_noise,
|
||||||
get_schedule,
|
get_schedule,
|
||||||
@ -62,6 +63,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
le=1,
|
le=1,
|
||||||
description=FieldDescriptions.denoising_start,
|
description=FieldDescriptions.denoising_start,
|
||||||
)
|
)
|
||||||
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||||
transformer: TransformerField = InputField(
|
transformer: TransformerField = InputField(
|
||||||
description=FieldDescriptions.flux_model,
|
description=FieldDescriptions.flux_model,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
@ -130,6 +132,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
shift=not is_schnell,
|
shift=not is_schnell,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Clip the timesteps schedule based on denoising_start and denoising_end.
|
||||||
|
timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)
|
||||||
|
|
||||||
# Prepare input latent image.
|
# Prepare input latent image.
|
||||||
if init_latents is not None:
|
if init_latents is not None:
|
||||||
# If init_latents is provided, we are doing image-to-image.
|
# If init_latents is provided, we are doing image-to-image.
|
||||||
@ -140,11 +145,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
"to be poor. Consider using a FLUX dev model instead."
|
"to be poor. Consider using a FLUX dev model instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clip the timesteps schedule based on denoising_start.
|
|
||||||
# TODO(ryand): Should we apply denoising_start in timestep-space rather than timestep-index-space?
|
|
||||||
start_idx = int(self.denoising_start * len(timesteps))
|
|
||||||
timesteps = timesteps[start_idx:]
|
|
||||||
|
|
||||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||||
t_0 = timesteps[0]
|
t_0 = timesteps[0]
|
||||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||||
@ -155,6 +155,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
x = noise
|
x = noise
|
||||||
|
|
||||||
|
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
|
||||||
|
# denoising steps.
|
||||||
|
if len(timesteps) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||||
|
|
||||||
b, _c, h, w = x.shape
|
b, _c, h, w = x.shape
|
||||||
|
@ -59,6 +59,44 @@ def get_schedule(
|
|||||||
return timesteps.tolist()
|
return timesteps.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
|
||||||
|
"""Find the last index in timesteps that is >= val.
|
||||||
|
|
||||||
|
We use epsilon-close equality to avoid potential floating point errors.
|
||||||
|
"""
|
||||||
|
idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
|
||||||
|
assert idx >= 0
|
||||||
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|
def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
|
||||||
|
"""Clip the timestep schedule to the denoising range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
|
||||||
|
denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
|
||||||
|
would mean that the denoising process start at the last timestep in the schedule >= 0.8.
|
||||||
|
denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would
|
||||||
|
mean that the denoising process end at the last timestep in the schedule >= 0.2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[float]: The clipped timestep schedule.
|
||||||
|
"""
|
||||||
|
assert 0.0 <= denoising_start <= 1.0
|
||||||
|
assert 0.0 <= denoising_end <= 1.0
|
||||||
|
assert denoising_start <= denoising_end
|
||||||
|
|
||||||
|
t_start_val = 1.0 - denoising_start
|
||||||
|
t_end_val = 1.0 - denoising_end
|
||||||
|
|
||||||
|
t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
|
||||||
|
t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)
|
||||||
|
|
||||||
|
clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
|
||||||
|
|
||||||
|
return clipped_timesteps
|
||||||
|
|
||||||
|
|
||||||
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
"""Unpack flat array of patch embeddings to latent image."""
|
"""Unpack flat array of patch embeddings to latent image."""
|
||||||
return rearrange(
|
return rearrange(
|
||||||
|
42
tests/backend/flux/test_sampling_utils.py
Normal file
42
tests/backend/flux/test_sampling_utils.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule
|
||||||
|
|
||||||
|
|
||||||
|
def float_lists_almost_equal(list1: list[float], list2: list[float], tol: float = 1e-6) -> bool:
|
||||||
|
return all(abs(a - b) < tol for a, b in zip(list1, list2, strict=True))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ["denoising_start", "denoising_end", "expected_timesteps", "raises"],
    [
        (0.0, 1.0, [1.0, 0.75, 0.5, 0.25, 0.0], False),  # Default case.
        (-0.1, 1.0, [], True),  # Negative denoising_start should raise.
        (0.0, 1.1, [], True),  # denoising_end > 1 should raise.
        (0.5, 0.0, [], True),  # denoising_start > denoising_end should raise.
        (0.0, 0.0, [1.0], False),  # denoising_end == 0.
        (1.0, 1.0, [0.0], False),  # denoising_start == 1.
        (0.2, 0.8, [1.0, 0.75, 0.5, 0.25], False),  # Middle of the schedule.
        # If we denoise from 0.0 to x, then from x to 1.0, it is important that denoise_end = x and denoise_start = x
        # map to the same timestep. We test this first when x is equal to a timestep, then when it falls between two
        # timesteps.
        # x = 0.5
        (0.0, 0.5, [1.0, 0.75, 0.5], False),
        (0.5, 1.0, [0.5, 0.25, 0.0], False),
        # x = 0.3
        (0.0, 0.3, [1.0, 0.75], False),
        (0.3, 1.0, [0.75, 0.5, 0.25, 0.0], False),
    ],
)
def test_clip_timestep_schedule(
    denoising_start: float, denoising_end: float, expected_timesteps: list[float], raises: bool
):
    """Check clip_timestep_schedule() against the expected clipped schedule, or AssertionError for invalid ranges."""
    timesteps = torch.linspace(1, 0, 5).tolist()

    if raises:
        # Invalid denoising ranges are rejected via assertions inside clip_timestep_schedule().
        with pytest.raises(AssertionError):
            clip_timestep_schedule(timesteps, denoising_start, denoising_end)
        return

    clipped = clip_timestep_schedule(timesteps, denoising_start, denoising_end)
    assert float_lists_almost_equal(clipped, expected_timesteps)
|
Loading…
Reference in New Issue
Block a user