diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 3a758839ea..f5b02889e9 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -125,125 +125,85 @@ class TextConditioningData:
         return isinstance(self.cond_text, SDXLConditioningInfo)
 
     def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
+        _, _, h, w = unet_kwargs.sample.shape
+        device = unet_kwargs.sample.device
+        dtype = unet_kwargs.sample.dtype
+
         if conditioning_mode == "both":
-            encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
-                self.uncond_text.embeds, self.cond_text.embeds
-            )
+            conditionings = [self.uncond_text, self.cond_text]
+            c_regions = [self.uncond_regions, self.cond_regions]
         elif conditioning_mode == "positive":
-            encoder_hidden_states = self.cond_text.embeds
-            encoder_attention_mask = None
-        else:  # elif conditioning_mode == "negative":
-            encoder_hidden_states = self.uncond_text.embeds
-            encoder_attention_mask = None
+            conditionings = [self.cond_text]
+            c_regions = [self.cond_regions]
+        else:  # conditioning_mode == "negative"
+            conditionings = [self.uncond_text]
+            c_regions = [self.uncond_regions]
+
+        encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
+            [c.embeds for c in conditionings]
+        )
 
         unet_kwargs.encoder_hidden_states = encoder_hidden_states
         unet_kwargs.encoder_attention_mask = encoder_attention_mask
 
         if self.is_sdxl():
-            if conditioning_mode == "negative":
-                added_cond_kwargs = dict(  # noqa: C408
-                    text_embeds=self.cond_text.pooled_embeds,
-                    time_ids=self.cond_text.add_time_ids,
-                )
-            elif conditioning_mode == "positive":
-                added_cond_kwargs = dict(  # noqa: C408
-                    text_embeds=self.uncond_text.pooled_embeds,
-                    time_ids=self.uncond_text.add_time_ids,
-                )
-            else:  # elif conditioning_mode == "both":
-                added_cond_kwargs = dict(  # noqa: C408
-                    text_embeds=torch.cat(
-                        [
-                            # TODO: how to pad? just by zeros? or even truncate?
-                            self.uncond_text.pooled_embeds,
-                            self.cond_text.pooled_embeds,
-                        ],
-                    ),
-                    time_ids=torch.cat(
-                        [
-                            self.uncond_text.add_time_ids,
-                            self.cond_text.add_time_ids,
-                        ],
-                    ),
-                )
+            added_cond_kwargs = dict(  # noqa: C408
+                text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
+                time_ids=torch.cat([c.add_time_ids for c in conditionings]),
+            )
 
             unet_kwargs.added_cond_kwargs = added_cond_kwargs
 
-        if self.cond_regions is not None or self.uncond_regions is not None:
-            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
-            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
-            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
-            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
-            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
-
-            _tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions
-            _, _, h, w = _tmp_regions.masks.shape
-            dtype = self.cond_text.embeds.dtype
-            device = self.cond_text.embeds.device
-
-            regions = []
-            for c, r in [
-                (self.uncond_text, self.uncond_regions),
-                (self.cond_text, self.cond_regions),
-            ]:
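+        # NOTE: RegionalPromptData is rebuilt on every denoising step even though
+        # the conditionings and masks do not change between steps; the cost is
+        # negligible compared to the UNet forward pass.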
+        if any(r is not None for r in c_regions):
+            tmp_regions = []
+            for c, r in zip(conditionings, c_regions, strict=True):
                 if r is None:
                     # Create a dummy mask and range for text conditioning that doesn't have region masks.
                     r = TextConditioningRegions(
                         masks=torch.ones((1, 1, h, w), dtype=dtype),
                         ranges=[Range(start=0, end=c.embeds.shape[1])],
                     )
-                regions.append(r)
+                tmp_regions.append(r)
 
             if unet_kwargs.cross_attention_kwargs is None:
                 unet_kwargs.cross_attention_kwargs = {}
 
             unet_kwargs.cross_attention_kwargs.update(
-                regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype),
+                regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
             )
 
-    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+    def _concat_conditionings_for_batch(self, conditionings):
+        def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
+            # Append a zero tensor of `pad_shape` to `t` along `dim`.
+            return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
+
         def _pad_conditioning(cond, target_len, encoder_attention_mask):
             conditioning_attention_mask = torch.ones(
                 (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
             )
 
             if cond.shape[1] < max_len:
-                conditioning_attention_mask = torch.cat(
-                    [
-                        conditioning_attention_mask,
-                        torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
-                    ],
+                conditioning_attention_mask = _pad_zeros(
+                    conditioning_attention_mask,
+                    pad_shape=(cond.shape[0], max_len - cond.shape[1]),
                     dim=1,
                 )
 
-                cond = torch.cat(
-                    [
-                        cond,
-                        torch.zeros(
-                            (cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
-                            device=cond.device,
-                            dtype=cond.dtype,
-                        ),
-                    ],
+                cond = _pad_zeros(
+                    cond,
+                    pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
                     dim=1,
                 )
 
             if encoder_attention_mask is None:
                 encoder_attention_mask = conditioning_attention_mask
             else:
-                encoder_attention_mask = torch.cat(
-                    [
-                        encoder_attention_mask,
-                        conditioning_attention_mask,
-                    ]
-                )
+                encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
 
             return cond, encoder_attention_mask
 
         encoder_attention_mask = None
-        if unconditioning.shape[1] != conditioning.shape[1]:
-            max_len = max(unconditioning.shape[1], conditioning.shape[1])
-            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
-            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+        max_len = max([c.shape[1] for c in conditionings])
+        if any(c.shape[1] != max_len for c in conditionings):
+            for i in range(len(conditionings)):
+                conditionings[i], encoder_attention_mask = _pad_conditioning(
+                    conditionings[i], max_len, encoder_attention_mask
+                )
 
-        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
+        return torch.cat(conditionings), encoder_attention_mask
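As review context (not part of the patch): a minimal, self-contained sketch of the pad-and-mask scheme that the new `_concat_conditionings_for_batch` implements. The function and variable names here are illustrative only, and the shapes assume SD-style text embeddings of `(batch, seq_len, dim)`:

```python
import torch


def concat_conditionings(conditionings: list[torch.Tensor]):
    """Zero-pad conditionings to a common length and build an attention mask."""
    max_len = max(c.shape[1] for c in conditionings)
    if all(c.shape[1] == max_len for c in conditionings):
        # Same length everywhere: no padding, no mask needed.
        return torch.cat(conditionings), None

    padded, masks = [], []
    for c in conditionings:
        pad = max_len - c.shape[1]
        # 1 marks real tokens, 0 marks padding to be ignored by attention.
        mask = torch.ones(c.shape[0], c.shape[1], device=c.device, dtype=c.dtype)
        if pad > 0:
            c = torch.cat(
                [c, torch.zeros(c.shape[0], pad, c.shape[2], device=c.device, dtype=c.dtype)], dim=1
            )
            mask = torch.cat(
                [mask, torch.zeros(mask.shape[0], pad, device=mask.device, dtype=mask.dtype)], dim=1
            )
        padded.append(c)
        masks.append(mask)
    return torch.cat(padded), torch.cat(masks)


# E.g. a 77-token negative prompt batched with a 154-token positive prompt:
uncond, cond = torch.randn(1, 77, 768), torch.randn(1, 154, 768)
states, mask = concat_conditionings([uncond, cond])
assert states.shape == (2, 154, 768)
assert mask.shape == (2, 154)
assert mask[0, 77:].sum() == 0  # uncond padding is masked out
```

The design choice is to pad the shorter conditioning with zeros and exclude those positions from cross-attention via the mask, rather than truncating the longer one.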