diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 9c532cec3e..4a45679d95 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -325,17 +325,71 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
 
         try:
-            latents = self.generate_latents_from_embeddings(
-                latents,
-                timesteps,
-                conditioning_data,
-                scheduler_step_kwargs=scheduler_step_kwargs,
-                mask_guidance=mask_guidance,
-                control_data=control_data,
-                ip_adapter_data=ip_adapter_data,
-                t2i_adapter_data=t2i_adapter_data,
-                callback=callback,
+            self._adjust_memory_efficient_attention(latents)
+
+            batch_size = latents.shape[0]
+
+            if timesteps.shape[0] == 0:
+                return latents
+
+            use_ip_adapter = ip_adapter_data is not None
+            use_regional_prompting = (
+                conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
             )
+            unet_attention_patcher = None
+            attn_ctx = nullcontext()
+
+            if use_ip_adapter or use_regional_prompting:
+                ip_adapters: Optional[List[UNetIPAdapterData]] = (
+                    [
+                        {"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks}
+                        for ipa in ip_adapter_data
+                    ]
+                    if use_ip_adapter
+                    else None
+                )
+                unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+                attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
+
+            with attn_ctx:
+                callback(
+                    PipelineIntermediateState(
+                        step=-1,
+                        order=self.scheduler.order,
+                        total_steps=len(timesteps),
+                        timestep=self.scheduler.config.num_train_timesteps,
+                        latents=latents,
+                    )
+                )
+
+                for i, t in enumerate(self.progress_bar(timesteps)):
+                    batched_t = t.expand(batch_size)
+                    step_output = self.step(
+                        batched_t,
+                        latents,
+                        conditioning_data,
+                        step_index=i,
+                        total_step_count=len(timesteps),
+                        scheduler_step_kwargs=scheduler_step_kwargs,
+                        mask_guidance=mask_guidance,
+                        control_data=control_data,
+                        ip_adapter_data=ip_adapter_data,
+                        t2i_adapter_data=t2i_adapter_data,
+                    )
+                    latents = step_output.prev_sample
+                    predicted_original = getattr(step_output, "pred_original_sample", None)
+
+                    callback(
+                        PipelineIntermediateState(
+                            step=i,
+                            order=self.scheduler.order,
+                            total_steps=len(timesteps),
+                            timestep=int(t),
+                            latents=latents,
+                            predicted_original=predicted_original,
+                        )
+                    )
+
         finally:
             self.invokeai_diffuser.model_forward_callback = self._unet_forward
 
@@ -351,82 +405,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         return latents
 
-    def generate_latents_from_embeddings(
-        self,
-        latents: torch.Tensor,
-        timesteps: torch.Tensor,
-        conditioning_data: TextConditioningData,
-        scheduler_step_kwargs: dict[str, Any],
-        callback: Callable[[PipelineIntermediateState], None],
-        mask_guidance: AddsMaskGuidance | None = None,
-        control_data: list[ControlNetData] | None = None,
-        ip_adapter_data: Optional[list[IPAdapterData]] = None,
-        t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-    ) -> torch.Tensor:
-        self._adjust_memory_efficient_attention(latents)
-
-        batch_size = latents.shape[0]
-
-        if timesteps.shape[0] == 0:
-            return latents
-
-        use_ip_adapter = ip_adapter_data is not None
-        use_regional_prompting = (
-            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
-        )
-        unet_attention_patcher = None
-        attn_ctx = nullcontext()
-
-        if use_ip_adapter or use_regional_prompting:
-            ip_adapters: Optional[List[UNetIPAdapterData]] = (
-                [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
-                if use_ip_adapter
-                else None
-            )
-            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
-            attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
-
-        with attn_ctx:
-            callback(
-                PipelineIntermediateState(
-                    step=-1,
-                    order=self.scheduler.order,
-                    total_steps=len(timesteps),
-                    timestep=self.scheduler.config.num_train_timesteps,
-                    latents=latents,
-                )
-            )
-
-            for i, t in enumerate(self.progress_bar(timesteps)):
-                batched_t = t.expand(batch_size)
-                step_output = self.step(
-                    batched_t,
-                    latents,
-                    conditioning_data,
-                    step_index=i,
-                    total_step_count=len(timesteps),
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    mask_guidance=mask_guidance,
-                    control_data=control_data,
-                    ip_adapter_data=ip_adapter_data,
-                    t2i_adapter_data=t2i_adapter_data,
-                )
-                latents = step_output.prev_sample
-                predicted_original = getattr(step_output, "pred_original_sample", None)
-
-                callback(
-                    PipelineIntermediateState(
-                        step=i,
-                        order=self.scheduler.order,
-                        total_steps=len(timesteps),
-                        timestep=int(t),
-                        latents=latents,
-                        predicted_original=predicted_original,
-                    )
-                )
-
-        return latents
-
     @torch.inference_mode()
     def step(
         self,
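
Note on the callback contract the inlined loop preserves: one PipelineIntermediateState with step=-1 is emitted before denoising begins (reporting the initial latents at the scheduler's num_train_timesteps), then one state per scheduler step. Below is a minimal, self-contained sketch of that cadence, assuming nothing beyond the field names visible in the diff; denoise() and its toy latent update are hypothetical stand-ins, not InvokeAI code.

from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class PipelineIntermediateState:
    # Field names mirror the diff above; types are assumptions.
    step: int
    order: int
    total_steps: int
    timestep: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None


def denoise(
    latents: torch.Tensor,
    timesteps: torch.Tensor,
    callback: Callable[[PipelineIntermediateState], None],
    num_train_timesteps: int = 1000,
    order: int = 1,
) -> torch.Tensor:
    # Early-out mirrors `if timesteps.shape[0] == 0: return latents`.
    if timesteps.shape[0] == 0:
        return latents

    # step=-1 event: reports the initial latents before any denoising.
    callback(
        PipelineIntermediateState(
            step=-1,
            order=order,
            total_steps=len(timesteps),
            timestep=num_train_timesteps,
            latents=latents,
        )
    )

    for i, t in enumerate(timesteps):
        # Hypothetical stand-in for self.step(...): a real pipeline would run
        # the UNet and scheduler here; we just decay the latents.
        latents = latents * 0.95

        # One progress event per step, with the integer timestep value.
        callback(
            PipelineIntermediateState(
                step=i,
                order=order,
                total_steps=len(timesteps),
                timestep=int(t),
                latents=latents,
            )
        )
    return latents


if __name__ == "__main__":
    ts = torch.arange(999, -1, -250)  # e.g. tensor([999, 749, 499, 249])
    denoise(torch.randn(1, 4, 8, 8), ts, lambda s: print(s.step, s.timestep))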