Revert "Remove the redundant init_timestep parameter that was being passed around. It is simply the first element of the timesteps array."

This reverts commit fa40061eca.
This commit is contained in:
Ryan Dick
2024-06-25 18:30:59 -04:00
parent dc23bebebf
commit bd74b84cc5
4 changed files with 20 additions and 11 deletions

View File

@ -273,6 +273,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: Optional[torch.Tensor],
seed: int,
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None],
control_data: list[ControlNetData] | None = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
@ -298,6 +299,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
same noise used earlier in the pipeline. This should really be handled in a clearer way.
timesteps: The timestep schedule for the denoising process.
init_timestep: The first timestep in the schedule.
TODO(ryand): I'm pretty sure this should always be the same as timesteps[0:1]. Confirm that that is the
case, and remove this duplicate param.
callback: A callback function that is called to report progress during the denoising process.
control_data: ControlNet data.
ip_adapter_data: IP-Adapter data.
@ -312,17 +316,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
SD UNet model.
is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
"""
if timesteps.shape[0] == 0:
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where densoisings_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
return latents
orig_latents = latents.clone()
batch_size = latents.shape[0]
batched_init_timestep = init_timestep.expand(batch_size)
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
if noise is not None:
# batched_init_timestep should have shape (batch_size, 1).
batched_init_timestep = timesteps[0:1].expand(batch_size)
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers