diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 485f23d7b1..755b1a2b08 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -73,7 +73,7 @@ class IPAdapterConditioningInfo: @dataclass class ConditioningData: - unconditioned_embeddings: BasicConditioningInfo + unconditioned_embeddings: Union[BasicConditioningInfo, SDXLConditioningInfo] text_embeddings: list[TextConditioningInfoWithMask] """ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 10c76c43a4..63cd2c65c5 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -5,13 +5,16 @@ from contextlib import contextmanager from typing import Any, Callable, Optional, Union import torch +import torchvision from diffusers import UNet2DConditionModel from typing_extensions import TypeAlias from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + BasicConditioningInfo, ConditioningData, ExtraConditioningInfo, + IPAdapterConditioningInfo, PostprocessingSettings, SDXLConditioningInfo, ) @@ -217,35 +220,87 @@ class InvokeAIDiffuserComponent: ) wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 - if wants_cross_attention_control or self.sequential_guidance: - # If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention - # control is currently only supported in sequential mode. - ( - unconditioned_next_x, - conditioned_next_x, - ) = self._apply_standard_conditioning_sequentially( - x=sample, - sigma=timestep, - conditioning_data=conditioning_data, - cross_attention_control_types_to_do=cross_attention_control_types_to_do, - down_block_additional_residuals=down_block_additional_residuals, - mid_block_additional_residual=mid_block_additional_residual, - down_intrablock_additional_residuals=down_intrablock_additional_residuals, - ) - else: - ( - unconditioned_next_x, - conditioned_next_x, - ) = self._apply_standard_conditioning( - x=sample, - sigma=timestep, - conditioning_data=conditioning_data, - down_block_additional_residuals=down_block_additional_residuals, - mid_block_additional_residual=mid_block_additional_residual, - down_intrablock_additional_residuals=down_intrablock_additional_residuals, - ) + cond_next_xs = [] + uncond_next_x = None + for text_conditioning in conditioning_data.text_embeddings: + if wants_cross_attention_control or self.sequential_guidance: + raise NotImplementedError( + "Sequential conditioning has not yet been updated to work with multiple text embeddings." + ) + # If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention + # control is currently only supported in sequential mode. + # ( + # unconditioned_next_x, + # conditioned_next_x, + # ) = self._apply_standard_conditioning_sequentially( + # x=sample, + # sigma=timestep, + # conditioning_data=conditioning_data, + # cross_attention_control_types_to_do=cross_attention_control_types_to_do, + # down_block_additional_residuals=down_block_additional_residuals, + # mid_block_additional_residual=mid_block_additional_residual, + # down_intrablock_additional_residuals=down_intrablock_additional_residuals, + # ) + else: + ( + unconditioned_next_x, + conditioned_next_x, + ) = self._apply_standard_conditioning( + x=sample, + sigma=timestep, + cond_text_embedding=text_conditioning.text_conditioning_info, + uncond_text_embedding=conditioning_data.unconditioned_embeddings, + ip_adapter_conditioning=conditioning_data.ip_adapter_conditioning, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, + ) + cond_next_xs.append(conditioned_next_x) + # HACK(ryand): We re-run unconditioned denoising for each text embedding, but we should only need to do it + # once. + uncond_next_x = unconditioned_next_x - return unconditioned_next_x, conditioned_next_x + # TODO(ryand): Think about how to handle the batch dimension here. Should this be torch.stack()? It probably + # doesn't matter, as I'm sure there are many other places where we don't properly support batching. + cond_out = torch.concat(cond_next_xs, dim=0) + # Initialize count to 1e-9 to avoid division by zero. + cond_count = torch.ones_like(cond_out[0, ...]) * 1e-9 + + _, _, height, width = cond_out.shape + for te_idx, te in enumerate(conditioning_data.text_embeddings): + mask = te.mask + if mask is not None: + # Resize if necessary. + tf = torchvision.transforms.Resize( + (height, width), interpolation=torchvision.transforms.InterpolationMode.NEAREST + ) + mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w) + mask = tf(mask) + + # TODO(ryand): We are converting from uint8 to float here. Should we just be storing a float mask to + # begin with? + mask = mask.to(cond_out.device, cond_out.dtype) + + # Make sure that all mask values are either 0.0 or 1.0. + # HACK(ryand): This is not the right place to be doing this. Just be clear about the expected format of + # the mask in the passed data structures. + mask[mask < 0.5] = 0.0 + mask[mask >= 0.5] = 1.0 + + mask *= te.mask_strength + else: + # mask is None, so treat as a mask of all 1.0s (by taking advantage of torch's treatment of scalar + # values). + mask = 1.0 + + # Apply the mask and update the count. + cond_out[te_idx, ...] *= mask[0] + cond_count += mask[0] + + # Combine the masked conditionings. + cond_out = cond_out.sum(dim=0, keepdim=True) / cond_count + + return uncond_next_x, cond_out def do_latent_postprocessing( self, @@ -313,7 +368,9 @@ class InvokeAIDiffuserComponent: self, x, sigma, - conditioning_data: ConditioningData, + cond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo], + uncond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo], + ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter @@ -324,43 +381,40 @@ class InvokeAIDiffuserComponent: x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) - assert len(conditioning_data.text_embeddings) == 1 - text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info - cross_attention_kwargs = None - if conditioning_data.ip_adapter_conditioning is not None: + if ip_adapter_conditioning is not None: # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len). cross_attention_kwargs = { "ip_adapter_image_prompt_embeds": [ torch.stack( [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds] ) - for ipa_conditioning in conditioning_data.ip_adapter_conditioning + for ipa_conditioning in ip_adapter_conditioning ] } added_cond_kwargs = None - if type(text_embeddings) is SDXLConditioningInfo: + if type(cond_text_embedding) is SDXLConditioningInfo: added_cond_kwargs = { "text_embeds": torch.cat( [ # TODO: how to pad? just by zeros? or even truncate? - conditioning_data.unconditioned_embeddings.pooled_embeds, - text_embeddings.pooled_embeds, + uncond_text_embedding.pooled_embeds, + cond_text_embedding.pooled_embeds, ], dim=0, ), "time_ids": torch.cat( [ - conditioning_data.unconditioned_embeddings.add_time_ids, - text_embeddings.add_time_ids, + uncond_text_embedding.add_time_ids, + cond_text_embedding.add_time_ids, ], dim=0, ), } both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( - conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds + uncond_text_embedding.embeds, cond_text_embedding.embeds ) both_results = self.model_forward_callback( x_twice, @@ -385,7 +439,7 @@ class InvokeAIDiffuserComponent: down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter - ): + ) -> tuple[torch.Tensor, torch.Tensor]: """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of slower execution speed. """