mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Raise a clear error message if prompt-to-prompt cross-attention control is triggered when using multiple prompts.
This commit is contained in:
parent
e132afb705
commit
e7f7ae660d
@ -400,14 +400,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
cur_text_embedding_len = 0
|
||||
processed_masks = []
|
||||
embedding_ranges = []
|
||||
extra_conditioning = None
|
||||
|
||||
for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
|
||||
# HACK(ryand): Figure out the intended relationship with CAC. Probably want to raise if more than one text
|
||||
# embedding is passed in and CAC is being used.
|
||||
assert (
|
||||
text_embedding_info.extra_conditioning is None
|
||||
or not text_embedding_info.extra_conditioning.wants_cross_attention_control
|
||||
)
|
||||
if (
|
||||
text_embedding_info.extra_conditioning is not None
|
||||
and text_embedding_info.extra_conditioning.wants_cross_attention_control
|
||||
):
|
||||
extra_conditioning = text_embedding_info.extra_conditioning
|
||||
|
||||
if is_sdxl:
|
||||
# HACK(ryand): We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
|
||||
@ -441,18 +441,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if not all_masks_are_none:
|
||||
regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)
|
||||
|
||||
if extra_conditioning is not None and len(text_conditionings) > 1:
|
||||
raise ValueError(
|
||||
"Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
|
||||
"prompts."
|
||||
)
|
||||
|
||||
if is_sdxl:
|
||||
return SDXLConditioningInfo(
|
||||
embeds=text_embedding,
|
||||
# TODO(ryand): This should not be hard-coded to None.
|
||||
extra_conditioning=None,
|
||||
extra_conditioning=extra_conditioning,
|
||||
pooled_embeds=pooled_embedding,
|
||||
add_time_ids=add_time_ids,
|
||||
), regions
|
||||
return BasicConditioningInfo(
|
||||
embeds=text_embedding,
|
||||
# TODO(ryand): This should not be hard-coded to None.
|
||||
extra_conditioning=None,
|
||||
extra_conditioning=extra_conditioning,
|
||||
), regions
|
||||
|
||||
def get_conditioning_data(
|
||||
|
Loading…
Reference in New Issue
Block a user