diffusers(AddsMaskGuidance): partial fix for k-schedulers

Prevents them from crashing, but results are still hot garbage.
Kevin Turner 2022-12-10 21:19:32 -08:00
parent 4fa26e82e8
commit 520c17ab86


@@ -90,7 +90,7 @@ def are_like_tensors(a: torch.Tensor, b: object) -> bool:
 class AddsMaskGuidance:
     mask: torch.FloatTensor
     mask_latents: torch.FloatTensor
-    _scheduler: SchedulerMixin
+    scheduler: SchedulerMixin
     noise: torch.Tensor
     _debug: Optional[Callable] = None
@@ -117,10 +117,15 @@ class AddsMaskGuidance:
     def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
         batch_size = latents.size(0)
         mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        if t.dim() == 0:
+            # some schedulers expect t to be one-dimensional.
+            # TODO: file diffusers bug about inconsistency?
+            t = einops.repeat(t, '-> batch', batch=batch_size)
         # Noise shouldn't be re-randomized between steps here. The multistep schedulers
         # get very confused about what is happening from step to step when we do that.
-        mask_latents = self._scheduler.add_noise(self.mask_latents, self.noise, t)
+        mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
         # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
+        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
         mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
         if self._debug:
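
For context, the operation this hunk patches can be sketched outside the repository: the latents of the original image are re-noised to the current timestep with the scheduler's add_noise and then linearly mixed with the in-progress latents using the mask. The snippet below is a minimal, self-contained illustration, not InvokeAI's code; the choice of EulerDiscreteScheduler, the tensor shapes, and all variable names are assumptions made for the demo. It shows why a 0-dim timestep has to be broadcast before a k-style scheduler's add_noise will accept it.

# Minimal sketch of the masked-guidance blend (illustrative only).
import torch
from diffusers import EulerDiscreteScheduler  # one of the k-style schedulers

scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(20)
t = scheduler.timesteps[5]                 # 0-dim tensor, as seen mid-sampling

batch, ch, h, w = 2, 4, 8, 8
mask = torch.zeros(1, 1, h, w)
mask[..., :4] = 1.0                        # 1 = region being denoised, 0 = keep original
orig_latents = torch.randn(1, ch, h, w)    # latents of the original image
noise = torch.randn(1, ch, h, w)           # fixed noise, not re-drawn per step
latents = torch.randn(batch, ch, h, w)     # in-progress latents from the sampler

# k-style schedulers look up sigma per sample, so a scalar timestep must be
# expanded to one entry per batch item before calling add_noise.
if t.dim() == 0:
    t = t.repeat(batch)

noised = scheduler.add_noise(
    orig_latents.expand(batch, -1, -1, -1),
    noise.expand(batch, -1, -1, -1),
    t,
)
# Where mask == 1 keep the in-progress latents; where mask == 0 restore the
# re-noised original latents. The weight broadcasts over batch and channels.
blended = torch.lerp(noised, latents, mask.to(latents.dtype))
print(blended.shape)                       # torch.Size([2, 4, 8, 8])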