mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add magic to debug
This commit is contained in:
parent
409e5d01ba
commit
511da59793
@ -92,7 +92,7 @@ class AddsMaskGuidance:
|
|||||||
mask: torch.FloatTensor
|
mask: torch.FloatTensor
|
||||||
mask_latents: torch.FloatTensor
|
mask_latents: torch.FloatTensor
|
||||||
scheduler: SchedulerMixin
|
scheduler: SchedulerMixin
|
||||||
noise: torch.Tensor
|
noise: Optional[torch.Tensor]
|
||||||
|
|
||||||
def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
|
def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
|
||||||
output_class = step_output.__class__ # We'll create a new one with masked data.
|
output_class = step_output.__class__ # We'll create a new one with masked data.
|
||||||
@ -124,7 +124,10 @@ class AddsMaskGuidance:
|
|||||||
t = einops.repeat(t, "-> batch", batch=batch_size)
|
t = einops.repeat(t, "-> batch", batch=batch_size)
|
||||||
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
|
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
|
||||||
# get very confused about what is happening from step to step when we do that.
|
# get very confused about what is happening from step to step when we do that.
|
||||||
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
|
if self.noise is not None:
|
||||||
|
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
|
||||||
|
else:
|
||||||
|
mask_latents = self.mask_latents.clone()
|
||||||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||||
@ -368,19 +371,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
|
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
|
||||||
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(self._unet_forward, mask, orig_latents)
|
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(self._unet_forward, mask, orig_latents)
|
||||||
else:
|
else:
|
||||||
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
|
# TODO: debug better with or without Oo
|
||||||
if noise is None:
|
if False:
|
||||||
noise = torch.randn(
|
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
|
||||||
orig_latents.shape,
|
if noise is None:
|
||||||
dtype=torch.float32,
|
noise = torch.randn(
|
||||||
device="cpu",
|
orig_latents.shape,
|
||||||
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
|
dtype=torch.float32,
|
||||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
device="cpu",
|
||||||
|
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
|
||||||
|
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||||
|
|
||||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||||
latents = torch.lerp(
|
latents = torch.lerp(
|
||||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
|
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user