mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Raise a clear error message if prompt-to-prompt cross-attention control is triggered when using multiple prompts.
This commit is contained in:
parent
e132afb705
commit
e7f7ae660d
@ -400,14 +400,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
cur_text_embedding_len = 0
|
cur_text_embedding_len = 0
|
||||||
processed_masks = []
|
processed_masks = []
|
||||||
embedding_ranges = []
|
embedding_ranges = []
|
||||||
|
extra_conditioning = None
|
||||||
|
|
||||||
for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
|
for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
|
||||||
# HACK(ryand): Figure out the intended relationship with CAC. Probably want to raise if more than one text
|
if (
|
||||||
# embedding is passed in and CAC is being used.
|
text_embedding_info.extra_conditioning is not None
|
||||||
assert (
|
and text_embedding_info.extra_conditioning.wants_cross_attention_control
|
||||||
text_embedding_info.extra_conditioning is None
|
):
|
||||||
or not text_embedding_info.extra_conditioning.wants_cross_attention_control
|
extra_conditioning = text_embedding_info.extra_conditioning
|
||||||
)
|
|
||||||
|
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
# HACK(ryand): We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
|
# HACK(ryand): We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
|
||||||
@ -441,18 +441,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
if not all_masks_are_none:
|
if not all_masks_are_none:
|
||||||
regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)
|
regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)
|
||||||
|
|
||||||
|
if extra_conditioning is not None and len(text_conditionings) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
|
||||||
|
"prompts."
|
||||||
|
)
|
||||||
|
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
return SDXLConditioningInfo(
|
return SDXLConditioningInfo(
|
||||||
embeds=text_embedding,
|
embeds=text_embedding,
|
||||||
# TODO(ryand): This should not be hard-coded to None.
|
extra_conditioning=extra_conditioning,
|
||||||
extra_conditioning=None,
|
|
||||||
pooled_embeds=pooled_embedding,
|
pooled_embeds=pooled_embedding,
|
||||||
add_time_ids=add_time_ids,
|
add_time_ids=add_time_ids,
|
||||||
), regions
|
), regions
|
||||||
return BasicConditioningInfo(
|
return BasicConditioningInfo(
|
||||||
embeds=text_embedding,
|
embeds=text_embedding,
|
||||||
# TODO(ryand): This should not be hard-coded to None.
|
extra_conditioning=extra_conditioning,
|
||||||
extra_conditioning=None,
|
|
||||||
), regions
|
), regions
|
||||||
|
|
||||||
def get_conditioning_data(
|
def get_conditioning_data(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user