Add symmetric support for regional negative text prompts.

commit 54971afe44
parent cfba51aed5
Author: Ryan Dick
Date: 2024-02-27 20:05:02 -05:00

4 changed files with 65 additions and 61 deletions


@@ -44,6 +44,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningData,
     IPAdapterConditioningInfo,
+    SDXLConditioningInfo,
 )
 from ...backend.model_management.lora import ModelPatcher
@@ -233,8 +234,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
     positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
     )
-    negative_conditioning: ConditioningField = InputField(
-        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
+    negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
+        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=0
     )
     noise: Optional[LatentsField] = InputField(
         default=None,
@@ -327,6 +328,31 @@ class DenoiseLatentsInvocation(BaseInvocation):
             base_model=base_model,
         )

+    def _get_text_embeddings_and_masks(
+        self,
+        cond_field: Union[ConditioningField, list[ConditioningField]],
+        context: InvocationContext,
+        device: torch.device,
+        dtype: torch.dtype,
+    ):
+        # Normalize cond_field to a list.
+        cond_list = cond_field
+        if not isinstance(cond_list, list):
+            cond_list = [cond_list]
+
+        text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
+        text_embeddings_masks: list[Optional[torch.Tensor]] = []
+        for cond in cond_list:
+            cond_data = context.services.latents.get(cond.conditioning_name)
+            text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
+
+            mask = cond.mask
+            if mask is not None:
+                mask = context.services.latents.get(mask.mask_name)
+            text_embeddings_masks.append(mask)
+
+        return text_embeddings, text_embeddings_masks
+
     def get_conditioning_data(
         self,
         context: InvocationContext,
@@ -334,29 +360,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet,
         seed,
     ) -> ConditioningData:
-        # self.positive_conditioning could be a list or a single ConditioningField. Normalize to a list here.
-        positive_conditioning_list = self.positive_conditioning
-        if not isinstance(positive_conditioning_list, list):
-            positive_conditioning_list = [positive_conditioning_list]
-
-        text_embeddings: list[BasicConditioningInfo] = []
-        text_embeddings_masks: list[Optional[torch.Tensor]] = []
-        for positive_conditioning in positive_conditioning_list:
-            positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
-            text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
-
-            mask = positive_conditioning.mask
-            if mask is not None:
-                mask = context.services.latents.get(mask.mask_name)
-            text_embeddings_masks.append(mask)
-
-        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+        cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
+            self.positive_conditioning, context, unet.device, unet.dtype
+        )
+        uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
+            self.negative_conditioning, context, unet.device, unet.dtype
+        )

         conditioning_data = ConditioningData(
-            unconditioned_embeddings=uc,
-            text_embeddings=text_embeddings,
-            text_embedding_masks=text_embeddings_masks,
+            uncond_text_embeddings=uncond_text_embeddings,
+            uncond_text_embedding_masks=uncond_text_embedding_masks,
+            cond_text_embeddings=cond_text_embeddings,
+            cond_text_embedding_masks=cond_text_embedding_masks,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
             postprocessing_settings=PostprocessingSettings(
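
Both conditioning inputs are now normalized through the same helper, so a workflow can pass either a single conditioning field or a list for either polarity. A minimal sketch of that normalization contract, using a hypothetical stand-in for ConditioningField (the real class carries more state):

from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class FakeConditioningField:
    # Hypothetical stand-in for InvokeAI's ConditioningField.
    conditioning_name: str
    mask_name: Optional[str] = None


def normalize(field: Union[FakeConditioningField, list[FakeConditioningField]]) -> list[FakeConditioningField]:
    # Mirrors the normalization in _get_text_embeddings_and_masks: a bare
    # field becomes a one-element list, so downstream code only handles lists.
    return field if isinstance(field, list) else [field]


# A single global negative prompt and two regional positive prompts both
# normalize to lists, so the cond and uncond paths share one code path.
negative = FakeConditioningField("global_negative")
positives = [FakeConditioningField("sky", "sky_mask"), FakeConditioningField("ground", "ground_mask")]
assert normalize(negative) == [negative]
assert normalize(positives) == positives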


@@ -404,12 +404,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents

-        extra_conditioning_info = conditioning_data.text_embeddings[0].extra_conditioning
+        extra_conditioning_info = conditioning_data.cond_text_embeddings[0].extra_conditioning
         use_cross_attention_control = (
             extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
         )
         use_ip_adapter = ip_adapter_data is not None
-        use_regional_prompting = len(conditioning_data.text_embeddings) > 1
+        # HACK(ryand): Fix this logic.
+        use_regional_prompting = len(conditioning_data.cond_text_embeddings) > 1
         if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
             raise Exception(
                 "Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."


@@ -65,10 +65,10 @@ class IPAdapterConditioningInfo:

 @dataclass
 class ConditioningData:
-    # TODO(ryand): Support masks for unconditioned_embeddings.
-    unconditioned_embeddings: BasicConditioningInfo
-    text_embeddings: list[BasicConditioningInfo]
-    text_embedding_masks: list[Optional[torch.Tensor]]
+    uncond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
+    uncond_text_embedding_masks: list[Optional[torch.Tensor]]
+    cond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
+    cond_text_embedding_masks: list[Optional[torch.Tensor]]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
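
The dataclass is now symmetric: each polarity carries a list of embeddings plus a parallel list of optional masks, where None means the prompt applies to the whole image. A sketch of the implied invariant (the mask shape here is an assumption for illustration):

from typing import Optional

import torch


def validate_embeddings_and_masks(
    embeddings: list,  # BasicConditioningInfo or SDXLConditioningInfo entries
    masks: list[Optional[torch.Tensor]],
) -> None:
    # The lists are parallel: masks[i] restricts embeddings[i] to a region,
    # and a None mask means that prompt conditions the entire image.
    assert len(embeddings) == len(masks)


# Example: two regional positive prompts, one global negative prompt.
cond_masks = [torch.ones(1, 1, 64, 64), torch.zeros(1, 1, 64, 64)]  # assumed mask layout
validate_embeddings_and_masks(["sky_embed", "ground_embed"], cond_masks)
validate_embeddings_and_masks(["negative_embed"], [None])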


@@ -234,7 +234,8 @@ class InvokeAIDiffuserComponent:
         down_block_res_samples, mid_block_res_sample = None, None
         # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
         # concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
-        text_embeddings = conditioning_data.text_embeddings[0]
+        uncond_text_embeddings = conditioning_data.uncond_text_embeddings[0]
+        cond_text_embeddings = conditioning_data.cond_text_embeddings[0]

         # control_data should be type List[ControlNetData]
         # this loop covers both ControlNet (one ControlNetData in list)
@@ -265,38 +266,25 @@ class InvokeAIDiffuserComponent:
             added_cond_kwargs = None
             if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-                if type(text_embeddings) is SDXLConditioningInfo:
+                if type(cond_text_embeddings) is SDXLConditioningInfo:
                     added_cond_kwargs = {
-                        "text_embeds": text_embeddings.pooled_embeds,
-                        "time_ids": text_embeddings.add_time_ids,
+                        "text_embeds": cond_text_embeddings.pooled_embeds,
+                        "time_ids": cond_text_embeddings.add_time_ids,
                     }
-                encoder_hidden_states = text_embeddings.embeds
+                encoder_hidden_states = cond_text_embeddings.embeds
                 encoder_attention_mask = None
             else:
-                if type(text_embeddings) is SDXLConditioningInfo:
+                if type(cond_text_embeddings) is SDXLConditioningInfo:
                     added_cond_kwargs = {
                         "text_embeds": torch.cat(
-                            [
-                                # TODO: how to pad? just by zeros? or even truncate?
-                                conditioning_data.unconditioned_embeddings.pooled_embeds,
-                                text_embeddings.pooled_embeds,
-                            ],
-                            dim=0,
+                            [uncond_text_embeddings.pooled_embeds, cond_text_embeddings.pooled_embeds], dim=0
                         ),
                         "time_ids": torch.cat(
-                            [
-                                conditioning_data.unconditioned_embeddings.add_time_ids,
-                                text_embeddings.add_time_ids,
-                            ],
-                            dim=0,
+                            [uncond_text_embeddings.add_time_ids, cond_text_embeddings.add_time_ids], dim=0
                         ),
                     }
-                (
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                ) = self._concat_conditionings_for_batch(
-                    conditioning_data.unconditioned_embeddings.embeds,
-                    text_embeddings.embeds,
+                (encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
+                    uncond_text_embeddings.embeds, cond_text_embeddings.embeds
                 )

             if isinstance(control_datum.weight, list):
                 # if controlnet has multiple weights, use the weight for the current step
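
In the batched path above, unconditioned and conditioned embeddings are concatenated along the batch dimension so a single UNet call serves both halves of classifier-free guidance, and the output is split back apart afterwards. A self-contained illustration with dummy tensors (shapes are assumptions, roughly SD-like):

import torch

# Dummy text embeddings: (batch, tokens, channels); shapes are illustrative.
uncond_embeds = torch.randn(1, 77, 768)
cond_embeds = torch.randn(1, 77, 768)

# Concatenate along dim 0, as in _concat_conditionings_for_batch (the real
# helper also pads to a common token length and builds an attention mask).
encoder_hidden_states = torch.cat([uncond_embeds, cond_embeds], dim=0)
assert encoder_hidden_states.shape == (2, 77, 768)

# After the UNet runs on the doubled batch, the two halves are recovered:
fake_unet_output = torch.randn(2, 4, 64, 64)
uncond_next_x, cond_next_x = fake_unet_output.chunk(2, dim=0)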
@@ -487,14 +475,14 @@ class InvokeAIDiffuserComponent:
         cross_attention_kwargs = None
         _, _, h, w = x.shape
         cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
-            text_conditionings=conditioning_data.text_embeddings,
-            masks=conditioning_data.text_embedding_masks,
+            text_conditionings=conditioning_data.cond_text_embeddings,
+            masks=conditioning_data.cond_text_embedding_masks,
             latent_height=h,
             latent_width=w,
         )
         uncond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
-            text_conditionings=[conditioning_data.unconditioned_embeddings],
-            masks=[None],
+            text_conditionings=conditioning_data.uncond_text_embeddings,
+            masks=conditioning_data.uncond_text_embedding_masks,
             latent_height=h,
             latent_width=w,
         )
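
Both calls receive masks at latent resolution (h and w come from the latent tensor x). A sketch of building a complementary left/right mask pair; the dtype and (1, 1, h, w) layout the real class expects are assumptions here:

import torch

# Latent-resolution dimensions, e.g. 64x64 for a 512x512 image (SD downsamples by 8).
h, w = 64, 64

# Left half for prompt 0, right half for prompt 1; assumed (1, 1, h, w) layout.
left_mask = torch.zeros(1, 1, h, w, dtype=torch.bool)
left_mask[..., :, : w // 2] = True
right_mask = ~left_mask

masks = [left_mask, right_mask]  # parallel to cond_text_embeddings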
@@ -579,8 +567,8 @@ class InvokeAIDiffuserComponent:
         slower execution speed.
         """

-        assert len(conditioning_data.text_embeddings) == 1
-        text_embeddings = conditioning_data.text_embeddings[0]
+        assert len(conditioning_data.cond_text_embeddings) == 1
+        text_embeddings = conditioning_data.cond_text_embeddings[0]

         # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
         # and T2I-Adapter residuals into two chunks.
@@ -642,15 +630,15 @@ class InvokeAIDiffuserComponent:
         is_sdxl = type(text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
             added_cond_kwargs = {
-                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.uncond_text_embeddings.pooled_embeds,
+                "time_ids": conditioning_data.uncond_text_embeddings.add_time_ids,
             }

         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.unconditioned_embeddings.embeds,
+            conditioning_data.uncond_text_embeddings.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
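
After the sequential unconditioned and conditioned passes, classifier-free guidance combines the two predictions downstream of this hunk. The standard formula from the paper cited in ConditioningData, as a sketch with dummy tensors:

import torch

guidance_scale = 7.5
unconditioned_next_x = torch.randn(1, 4, 64, 64)  # dummy UNet output (negative prompt)
conditioned_next_x = torch.randn(1, 4, 64, 64)    # dummy UNet output (positive prompt)

# Classifier-free guidance: extrapolate from the unconditioned prediction
# toward the conditioned one, scaled by the guidance strength.
combined = unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)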