Raise a clear error message if prompt-to-prompt cross-attention control is triggered when using multiple prompts.

Ryan Dick 2024-02-28 21:38:25 -05:00
parent e132afb705
commit e7f7ae660d
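
For illustration, here is a minimal, self-contained sketch of the check this commit introduces. The `ExtraConditioning` and `TextEmbeddingInfo` classes and the `check_cross_attention_control` helper are hypothetical stand-ins, not InvokeAI's real types; only the `wants_cross_attention_control` flag, the accumulation of `extra_conditioning` in the loop, and the final `ValueError` mirror the diff below.

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class ExtraConditioning:
        # Hypothetical stand-in for the extra-conditioning info carried by a prompt.
        wants_cross_attention_control: bool = False


    @dataclass
    class TextEmbeddingInfo:
        # Hypothetical stand-in for one prompt's conditioning info.
        extra_conditioning: Optional[ExtraConditioning] = None


    def check_cross_attention_control(
        text_conditionings: List[TextEmbeddingInfo],
    ) -> Optional[ExtraConditioning]:
        # Remember any cross-attention-control request seen while iterating over
        # the prompts (mirrors the first hunk of the diff).
        extra_conditioning = None
        for text_embedding_info in text_conditionings:
            if (
                text_embedding_info.extra_conditioning is not None
                and text_embedding_info.extra_conditioning.wants_cross_attention_control
            ):
                extra_conditioning = text_embedding_info.extra_conditioning

        # Reject the unsupported combination with a clear error (second hunk).
        if extra_conditioning is not None and len(text_conditionings) > 1:
            raise ValueError(
                "Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
                "prompts."
            )
        return extra_conditioning


    if __name__ == "__main__":
        cac = ExtraConditioning(wants_cross_attention_control=True)
        # A single prompt requesting cross-attention control is still allowed.
        assert check_cross_attention_control([TextEmbeddingInfo(extra_conditioning=cac)]) is cac
        # Multiple prompts plus cross-attention control now fail loudly.
        try:
            check_cross_attention_control([TextEmbeddingInfo(extra_conditioning=cac), TextEmbeddingInfo()])
        except ValueError as err:
            print(f"Got expected error: {err}")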


@@ -400,14 +400,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
         cur_text_embedding_len = 0
         processed_masks = []
         embedding_ranges = []
+        extra_conditioning = None
 
         for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
-            # HACK(ryand): Figure out the intended relationship with CAC. Probably want to raise if more than one text
-            # embedding is passed in and CAC is being used.
-            assert (
-                text_embedding_info.extra_conditioning is None
-                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
-            )
+            if (
+                text_embedding_info.extra_conditioning is not None
+                and text_embedding_info.extra_conditioning.wants_cross_attention_control
+            ):
+                extra_conditioning = text_embedding_info.extra_conditioning
 
             if is_sdxl:
                 # HACK(ryand): We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
@@ -441,18 +441,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if not all_masks_are_none:
             regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)
 
+        if extra_conditioning is not None and len(text_conditionings) > 1:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
+                "prompts."
+            )
+
         if is_sdxl:
             return SDXLConditioningInfo(
                 embeds=text_embedding,
-                # TODO(ryand): This should not be hard-coded to None.
-                extra_conditioning=None,
+                extra_conditioning=extra_conditioning,
                 pooled_embeds=pooled_embedding,
                 add_time_ids=add_time_ids,
             ), regions
         return BasicConditioningInfo(
             embeds=text_embedding,
-            # TODO(ryand): This should not be hard-coded to None.
-            extra_conditioning=None,
+            extra_conditioning=extra_conditioning,
         ), regions
 
     def get_conditioning_data(
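
Design note: the diff replaces an `assert` with an explicit `ValueError`. This is presumably because the assert raised a bare `AssertionError` with no user-facing message and would be skipped entirely when Python runs with the `-O` flag, whereas the `ValueError` states exactly which feature combination is unsupported. The second hunk also resolves the two `TODO(ryand)` comments: instead of hard-coding `extra_conditioning=None`, the value captured in the loop is now threaded into both `SDXLConditioningInfo` and `BasicConditioningInfo`, so cross-attention control continues to work in the single-prompt case.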