Fixup logic around compatibility of prompt-to-prompt, IP-Adapter, regional prompting.

This commit is contained in:
Ryan Dick
2024-02-29 12:47:23 -05:00
parent bdf3691ad0
commit 1bbd4f751d

View File

@ -416,29 +416,34 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
)
use_ip_adapter = ip_adapter_data is not None
# HACK(ryand): Fix this logic.
use_regional_prompting = conditioning_data.cond_regions is not None
if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
raise Exception(
"Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."
use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
)
if use_cross_attention_control and use_ip_adapter:
raise ValueError(
"Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
)
if use_cross_attention_control and use_regional_prompting:
raise ValueError(
"Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
)
if use_ip_adapter and use_regional_prompting:
# TODO(ryand): Implement this.
raise NotImplementedError("Coming soon.")
ip_adapter_unet_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
)
elif use_ip_adapter:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
if use_ip_adapter:
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
elif use_regional_prompting:
if use_regional_prompting:
attn_ctx = apply_regional_prompt_attn(self.invokeai_diffuser.model)
else:
attn_ctx = nullcontext()
with attn_ctx:
if callback is not None: