Simplify the logic in prepare_noise_and_latents(...).

This commit is contained in:
Ryan Dick 2024-06-06 15:16:34 -04:00 committed by Kent Keirsey
parent 610a1fd611
commit 8004a0d5f5

View File

@ -663,27 +663,27 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prepare_noise_and_latents( def prepare_noise_and_latents(
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
) -> Tuple[float, torch.Tensor | None, torch.Tensor]: ) -> Tuple[float, torch.Tensor | None, torch.Tensor]:
seed = None
noise = None noise = None
if noise_field is not None: if noise_field is not None:
noise = context.tensors.load(noise_field.latents_name) noise = context.tensors.load(noise_field.latents_name)
seed = noise_field.seed
if latents_field is not None: if latents_field is not None:
latents = context.tensors.load(latents_field.latents_name) latents = context.tensors.load(latents_field.latents_name)
if seed is None:
seed = latents_field.seed
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
elif noise is not None: elif noise is not None:
latents = torch.zeros_like(noise) latents = torch.zeros_like(noise)
else: else:
raise Exception("'latents' or 'noise' must be provided!") raise ValueError("'latents' or 'noise' must be provided!")
if seed is None: if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise ValueError(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
seed = 0
if noise_field is not None and noise_field.seed is not None:
seed = noise_field.seed
elif latents_field is not None and latents_field.seed is not None:
seed = latents_field.seed
else:
seed = 0 seed = 0
return seed, noise, latents return seed, noise, latents