mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(nodes): doubly-noised latents
When using refiner with a mask (i.e. inpainting), we don't have noise provided as an input to the node.
This situation uniquely hits a code path that wasn't reviewed when gradient denoising was implemented.
That code path does two things wrong:
- It lerp'd the input latents. This was fixed in 5a1f4cb1ce
.
- It added noise to the latents an extra time. This is fixed in this change.
We don't need to add noise in `latents_from_embeddings` because we do it just a lines later in `AddsMaskGuidance`.
- Remove the extraneous call to `add_noise`
- Make `seed` a required arg. We never call the function without seed anyways. If we refactor this in the future, it will be clearer that we need to look at how seed is handled.
- Move the call to create the noise to a deeper conditional, just before we call `AddsMaskGuidance`. The created noise tensor is now only used in that function, no need to create it every time.
Note: Whether or not having both noise and latents as inputs on the node is correct is a separate conversation. This change just fixes the issue with the current setup.
This commit is contained in:
parent
026d095afe
commit
7bc77ddb40
@ -301,7 +301,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
gradient_mask: Optional[bool] = False,
|
||||
seed: Optional[int] = None,
|
||||
seed: int,
|
||||
) -> torch.Tensor:
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents
|
||||
@ -319,17 +319,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
if mask is not None:
|
||||
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
if is_inpainting_model(self.unet):
|
||||
if masked_latents is None:
|
||||
raise Exception("Source image required for inpaint mask when inpaint model used!")
|
||||
@ -338,6 +327,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self._unet_forward, mask, masked_latents
|
||||
)
|
||||
else:
|
||||
# if no noise provided, noisify unmasked area based on seed
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user