mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Simplify the logic in prepare_noise_and_latents(...).
This commit is contained in:
parent
610a1fd611
commit
8004a0d5f5
@ -663,27 +663,27 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
def prepare_noise_and_latents(
|
def prepare_noise_and_latents(
|
||||||
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
|
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
|
||||||
) -> Tuple[float, torch.Tensor | None, torch.Tensor]:
|
) -> Tuple[float, torch.Tensor | None, torch.Tensor]:
|
||||||
seed = None
|
|
||||||
noise = None
|
noise = None
|
||||||
|
|
||||||
if noise_field is not None:
|
if noise_field is not None:
|
||||||
noise = context.tensors.load(noise_field.latents_name)
|
noise = context.tensors.load(noise_field.latents_name)
|
||||||
seed = noise_field.seed
|
|
||||||
|
|
||||||
if latents_field is not None:
|
if latents_field is not None:
|
||||||
latents = context.tensors.load(latents_field.latents_name)
|
latents = context.tensors.load(latents_field.latents_name)
|
||||||
if seed is None:
|
|
||||||
seed = latents_field.seed
|
|
||||||
|
|
||||||
if noise is not None and noise.shape[1:] != latents.shape[1:]:
|
|
||||||
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
|
|
||||||
|
|
||||||
elif noise is not None:
|
elif noise is not None:
|
||||||
latents = torch.zeros_like(noise)
|
latents = torch.zeros_like(noise)
|
||||||
else:
|
else:
|
||||||
raise Exception("'latents' or 'noise' must be provided!")
|
raise ValueError("'latents' or 'noise' must be provided!")
|
||||||
|
|
||||||
if seed is None:
|
if noise is not None and noise.shape[1:] != latents.shape[1:]:
|
||||||
|
raise ValueError(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
|
||||||
|
|
||||||
|
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
|
||||||
|
seed = 0
|
||||||
|
if noise_field is not None and noise_field.seed is not None:
|
||||||
|
seed = noise_field.seed
|
||||||
|
elif latents_field is not None and latents_field.seed is not None:
|
||||||
|
seed = latents_field.seed
|
||||||
|
else:
|
||||||
seed = 0
|
seed = 0
|
||||||
|
|
||||||
return seed, noise, latents
|
return seed, noise, latents
|
||||||
|
Loading…
Reference in New Issue
Block a user