From 7d24ad8ccdadb0d27b40597746125af9da67cb2f Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 12 Jun 2024 14:27:49 -0400
Subject: [PATCH] Simplify handling of AddsMaskGuidance, and fix some related
 type errors.

---
 .../stable_diffusion/diffusers_pipeline.py | 46 ++++++++-----------
 1 file changed, 18 insertions(+), 28 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index fcdcffe10b..6847333ab0 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -78,8 +78,8 @@ def are_like_tensors(a: torch.Tensor, b: object) -> bool:
 
 @dataclass
 class AddsMaskGuidance:
-    mask: torch.FloatTensor
-    mask_latents: torch.FloatTensor
+    mask: torch.Tensor
+    mask_latents: torch.Tensor
     scheduler: SchedulerMixin
     noise: torch.Tensor
     gradient_mask: bool
@@ -87,7 +87,7 @@ class AddsMaskGuidance:
     def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
         return self.apply_mask(latents, t)
 
-    def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
+    def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
         batch_size = latents.size(0)
         mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
         if t.dim() == 0:
@@ -285,11 +285,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latents: torch.Tensor,
         scheduler_step_kwargs: dict[str, Any],
         conditioning_data: TextConditioningData,
-        *,
         noise: Optional[torch.Tensor],
         timesteps: torch.Tensor,
         init_timestep: torch.Tensor,
-        additional_guidance: List[Callable] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
@@ -302,9 +300,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if init_timestep.shape[0] == 0:
             return latents
 
-        if additional_guidance is None:
-            additional_guidance = []
-
         orig_latents = latents.clone()
 
         batch_size = latents.shape[0]
@@ -314,6 +309,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
             latents = self.scheduler.add_noise(latents, noise, batched_t)
 
+        mask_guidance: AddsMaskGuidance | None = None
         if mask is not None:
             if is_inpainting_model(self.unet):
                 if masked_latents is None:
@@ -332,7 +328,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                         generator=torch.Generator(device="cpu").manual_seed(seed),
                     ).to(device=orig_latents.device, dtype=orig_latents.dtype)
 
-            additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
+            mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask)
 
         try:
             latents = self.generate_latents_from_embeddings(
@@ -340,7 +336,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 timesteps,
                 conditioning_data,
                 scheduler_step_kwargs=scheduler_step_kwargs,
-                additional_guidance=additional_guidance,
+                mask_guidance=mask_guidance,
                 control_data=control_data,
                 ip_adapter_data=ip_adapter_data,
                 t2i_adapter_data=t2i_adapter_data,
@@ -367,16 +363,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         timesteps,
         conditioning_data: TextConditioningData,
         scheduler_step_kwargs: dict[str, Any],
-        *,
-        additional_guidance: List[Callable] = None,
+        mask_guidance: AddsMaskGuidance | None = None,
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
     ) -> torch.Tensor:
         self._adjust_memory_efficient_attention(latents)
-        if additional_guidance is None:
-            additional_guidance = []
 
         batch_size = latents.shape[0]
 
@@ -412,7 +405,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
             )
 
-        # print("timesteps:", timesteps)
         for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t = t.expand(batch_size)
             step_output = self.step(
@@ -422,7 +414,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 step_index=i,
                 total_step_count=len(timesteps),
                 scheduler_step_kwargs=scheduler_step_kwargs,
-                additional_guidance=additional_guidance,
+                mask_guidance=mask_guidance,
                 control_data=control_data,
                 ip_adapter_data=ip_adapter_data,
                 t2i_adapter_data=t2i_adapter_data,
@@ -453,19 +445,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         step_index: int,
         total_step_count: int,
         scheduler_step_kwargs: dict[str, Any],
-        additional_guidance: List[Callable] = None,
+        mask_guidance: AddsMaskGuidance | None = None,
        control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
-        if additional_guidance is None:
-            additional_guidance = []
 
-        # one day we will expand this extension point, but for now it just does denoise masking
-        for guidance in additional_guidance:
-            latents = guidance(latents, timestep)
+        if mask_guidance is not None:
+            latents = mask_guidance(latents, timestep)
 
         # TODO: should this scaling happen here or inside self._unet_forward?
         #     i.e. before or after passing it to InvokeAIDiffuserComponent
@@ -541,17 +530,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # compute the previous noisy sample x_t -> x_t-1
         step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
 
-        # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
-        for guidance in additional_guidance:
-            # apply the mask to any "denoised" or "pred_original_sample" fields
+        # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
+        # again.
+        if mask_guidance is not None:
+            # Apply the mask to any "denoised" or "pred_original_sample" fields.
             if hasattr(step_output, "denoised"):
-                step_output.pred_original_sample = guidance(step_output.denoised, self.scheduler.timesteps[-1])
+                step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
             elif hasattr(step_output, "pred_original_sample"):
-                step_output.pred_original_sample = guidance(
+                step_output.pred_original_sample = mask_guidance(
                     step_output.pred_original_sample, self.scheduler.timesteps[-1]
                 )
             else:
-                step_output.pred_original_sample = guidance(latents, self.scheduler.timesteps[-1])
+                step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])
 
         return step_output
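
Note for reviewers: the diff above only changes the plumbing around AddsMaskGuidance
(one optional, typed `mask_guidance` object instead of a generic `additional_guidance`
list of callables). For context, here is a minimal sketch of the denoise-masking
technique that the class implements. It is an illustration under assumptions, not the
code from diffusers_pipeline.py: the function name `apply_mask_guidance` is invented
for the example, and the real class additionally handles batch expansion with einops
and a `gradient_mask` mode.

import torch


def apply_mask_guidance(
    latents: torch.Tensor,       # in-progress latents at timestep t, shape (B, C, H, W)
    orig_latents: torch.Tensor,  # clean latents of the original image
    mask: torch.Tensor,          # 1.0 where the original image should be preserved
    noise: torch.Tensor,         # the same noise tensor passed to scheduler.add_noise
    t: torch.Tensor,             # current timestep(s)
    scheduler,                   # any diffusers scheduler exposing add_noise()
) -> torch.Tensor:
    # Re-noise the original latents to the current timestep so they are
    # statistically comparable to the partially denoised latents...
    noised_orig = scheduler.add_noise(orig_latents, noise, t)
    # ...then keep the original content where mask == 1 and the freshly
    # generated content where mask == 0.
    return torch.lerp(latents, noised_orig, mask.to(dtype=latents.dtype))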
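
A hypothetical wiring of the sketch (the shapes, timestep value, and DDPMScheduler
choice are assumptions for the example, not taken from the patch):

from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
latents = torch.randn(1, 4, 64, 64)
orig_latents = torch.randn(1, 4, 64, 64)
noise = torch.randn(1, 4, 64, 64)
mask = torch.zeros(1, 1, 64, 64)
mask[..., :32, :] = 1.0  # preserve the top half of the original image
t = torch.tensor([999])
latents = apply_mask_guidance(latents, orig_latents, mask, noise, t, scheduler)

Replacing `additional_guidance: List[Callable]` with a single
`mask_guidance: AddsMaskGuidance | None` removes the repeated
`if additional_guidance is None: additional_guidance = []` boilerplate and gives type
checkers a concrete type to verify, which is what addresses the "related type errors"
mentioned in the subject line.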