Simplify handling of AddsMaskGuidance, and fix some related type errors.

Ryan Dick 2024-06-12 14:27:49 -04:00
parent cb389063b2
commit 7d24ad8ccd


@@ -78,8 +78,8 @@ def are_like_tensors(a: torch.Tensor, b: object) -> bool:
@dataclass
class AddsMaskGuidance:
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
mask: torch.Tensor
mask_latents: torch.Tensor
scheduler: SchedulerMixin
noise: torch.Tensor
gradient_mask: bool
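torch.FloatTensor is a legacy, dtype-specific alias, and annotating the fields with plain torch.Tensor matches what callers actually pass, which is what the type errors in the commit title refer to. For orientation, a minimal self-contained sketch of the dataclass as it reads after this hunk (imports added here for context; the field comments are my reading of the surrounding code, not part of the change):

```python
from dataclasses import dataclass

import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin


@dataclass
class AddsMaskGuidance:
    mask: torch.Tensor          # inpaint mask in latent space, (b, c, h, w) per the einops pattern below
    mask_latents: torch.Tensor  # reference latents to blend back where the mask is set
    scheduler: SchedulerMixin
    noise: torch.Tensor         # noise used to re-noise mask_latents to the current timestep
    gradient_mask: bool         # whether the mask is soft (gradient) rather than binary
```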
@@ -87,7 +87,7 @@ class AddsMaskGuidance:
def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return self.apply_mask(latents, t)
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
batch_size = latents.size(0)
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
if t.dim() == 0:
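The hunk above only adds the missing type hint on t; the rest of apply_mask is cut off at the hunk boundary. As a hedged sketch of what a mask-guidance callable like this does (the body past the if is an assumption, not a copy of the shipped code): the reference latents are re-noised to the current timestep and kept wherever the mask is set, so the masked region stays pinned to the original image while the rest keeps denoising.

```python
import einops
import torch


def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    batch_size = latents.size(0)
    # Tile the single mask across the batch so it broadcasts against `latents`.
    mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
    if t.dim() == 0:
        # Some schedulers hand over a scalar timestep; add_noise wants one per sample.
        t = t.expand(batch_size)
    # Assumed continuation: re-noise the reference latents to timestep t, then keep
    # them where the mask is set and the freshly denoised latents elsewhere.
    # (`gradient_mask`, i.e. soft vs. binary masks, is ignored in this sketch.)
    mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
    return torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
```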
@@ -285,11 +285,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents: torch.Tensor,
scheduler_step_kwargs: dict[str, Any],
conditioning_data: TextConditioningData,
*,
noise: Optional[torch.Tensor],
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
@@ -302,9 +300,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if init_timestep.shape[0] == 0:
return latents
if additional_guidance is None:
additional_guidance = []
orig_latents = latents.clone()
batch_size = latents.shape[0]
@@ -314,6 +309,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_t)
mask_guidance: AddsMaskGuidance | None = None
if mask is not None:
if is_inpainting_model(self.unet):
if masked_latents is None:
@@ -332,7 +328,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask)
try:
latents = self.generate_latents_from_embeddings(
@@ -340,7 +336,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timesteps,
conditioning_data,
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
mask_guidance=mask_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
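The noise handed to AddsMaskGuidance above is created a couple of hunks back with a CPU generator; the function producing it is cut off by the hunk boundary. A hedged sketch of that pattern, assuming torch.randn and the names visible in the surrounding code: seeding on the CPU keeps the noise reproducible no matter which device the latents live on, and at most one mask-guidance object is then threaded through instead of a list of ad-hoc callables.

```python
import torch

# Assumed shape and function; the hunk only shows the generator and the .to() move.
noise = torch.randn(
    orig_latents.shape,
    generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)

# One optional guidance object instead of additional_guidance.append(...).
mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask)
```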
@@ -367,16 +363,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timesteps,
conditioning_data: TextConditioningData,
scheduler_step_kwargs: dict[str, Any],
*,
additional_guidance: List[Callable] = None,
mask_guidance: AddsMaskGuidance | None = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
) -> torch.Tensor:
self._adjust_memory_efficient_attention(latents)
if additional_guidance is None:
additional_guidance = []
batch_size = latents.shape[0]
@@ -412,7 +405,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
)
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
step_output = self.step(
@@ -422,7 +414,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
mask_guidance=mask_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
@@ -453,19 +445,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
additional_guidance: List[Callable] = None,
mask_guidance: AddsMaskGuidance | None = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
if additional_guidance is None:
additional_guidance = []
# one day we will expand this extension point, but for now it just does denoise masking
for guidance in additional_guidance:
latents = guidance(latents, timestep)
if mask_guidance is not None:
latents = mask_guidance(latents, timestep)
# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
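The TODO above is about diffusers' per-timestep input scaling, which has to happen somewhere before the UNet forward pass. Whichever side of the InvokeAIDiffuserComponent boundary it ends up on, the call itself is the standard scheduler API:

```python
# Standard diffusers call: Euler-style schedulers rescale the sample per timestep;
# for schedulers that don't need it, this returns the input unchanged.
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
```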
@@ -541,17 +530,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
for guidance in additional_guidance:
# apply the mask to any "denoised" or "pred_original_sample" fields
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
# again.
if mask_guidance is not None:
# Apply the mask to any "denoised" or "pred_original_sample" fields.
if hasattr(step_output, "denoised"):
step_output.pred_original_sample = guidance(step_output.denoised, self.scheduler.timesteps[-1])
step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
elif hasattr(step_output, "pred_original_sample"):
step_output.pred_original_sample = guidance(
step_output.pred_original_sample = mask_guidance(
step_output.pred_original_sample, self.scheduler.timesteps[-1]
)
else:
step_output.pred_original_sample = guidance(latents, self.scheduler.timesteps[-1])
step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])
return step_output
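For context on the branching above: different diffusers schedulers name their clean-image estimate differently (denoised on some, pred_original_sample on most, nothing on a few), and downstream progress callbacks read pred_original_sample as the preview image, so the mask has to be re-applied there as well. A compact, hedged restatement of that normalization; the getattr form is an illustration, not the shipped code:

```python
# Pick whichever clean-image estimate this scheduler produced (falling back to the
# current latents), mask it, and publish it as pred_original_sample for previews.
if mask_guidance is not None:
    predicted = getattr(step_output, "denoised", None)
    if predicted is None:
        predicted = getattr(step_output, "pred_original_sample", latents)
    step_output.pred_original_sample = mask_guidance(predicted, self.scheduler.timesteps[-1])
```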