Add concatenation of multiple text conditioning tensors, and patching of RegionalPromptAttnProcessor2_0 into the UNet.

Ryan Dick 2024-02-16 17:09:06 -05:00
parent 38248b988f
commit caa690e24d
3 changed files with 74 additions and 5 deletions
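For orientation before the per-file diffs: the conditioning half of this change concatenates per-region prompt embeddings along the sequence (token) axis so that a single UNet call sees all of them, while the attention half temporarily swaps the UNet's attention processors for RegionalPromptAttnProcessor2_0. A minimal, self-contained sketch of the concatenation idea (tensor shapes are illustrative, not taken from the commit):

import torch

# Two hypothetical per-region prompt embeddings, shaped (batch, seq_len, embed_dim).
region_a = torch.randn(1, 77, 768)
region_b = torch.randn(1, 77, 768)

# _prepare_text_embeddings (added below) concatenates along dim=1, yielding a single
# (batch, combined_seq_len, embed_dim) tensor that is passed to the UNet as usual.
combined = torch.cat([region_a, region_b], dim=1)
print(combined.shape)  # torch.Size([1, 154, 768])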

View File

@@ -25,6 +25,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import apply_regional_prompt_attn
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..util import auto_detect_slice_size, normalize_device
@@ -415,20 +416,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        )
        ip_adapter_unet_patcher = None
        self.use_ip_adapter = use_ip_adapter
        if use_cross_attention_control:
            attn_ctx = self.invokeai_diffuser.custom_attention_context(
                self.invokeai_diffuser.model,
                extra_conditioning_info=extra_conditioning_info,
            )
            self.use_ip_adapter = False
        elif use_ip_adapter:
            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
            # As it is now, the IP-Adapter will silently be skipped.
            ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
            self.use_ip_adapter = True
        elif use_regional_prompting:
            raise NotImplementedError("Regional prompting is not yet supported.")
            attn_ctx = apply_regional_prompt_attn(self.invokeai_diffuser.model)
        else:
            attn_ctx = nullcontext()
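Whichever branch is taken above, `attn_ctx` is a context manager that is later entered around the denoising loop, so the patched attention processors are active only while noise predictions are computed and the original processors are restored afterwards. A hedged sketch of that pattern (`timesteps` and the loop body are placeholders, not code from this commit):

with attn_ctx:
    for t in timesteps:
        ...  # predict noise with the (possibly patched) UNet and step the scheduler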

View File

@@ -1,7 +1,9 @@
from contextlib import contextmanager
from typing import Optional
import torch
import torch.nn.functional as F
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import USE_PEFT_BACKEND
@@ -17,7 +19,10 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        regional_prompt_data=None,
    ) -> torch.FloatTensor:
        # regional_prompt_data is not consumed yet; for now this processor must behave
        # exactly like AttnProcessor2_0, so anything else is rejected up front.
        assert regional_prompt_data is None

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -83,3 +88,16 @@
        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


@contextmanager
def apply_regional_prompt_attn(unet: UNet2DConditionModel):
    """A context manager that patches `unet` with RegionalPromptAttnProcessor2_0 attention processors."""
    orig_attn_processors = unet.attn_processors

    try:
        unet.set_attn_processor(RegionalPromptAttnProcessor2_0())
        yield None
    finally:
        unet.set_attn_processor(orig_attn_processors)
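A hedged usage sketch of the new context manager, assuming the import path shown at the top of this commit, a diffusers/PyTorch 2.x environment, and a deliberately tiny, illustrative UNet2DConditionModel configuration (none of these values come from the commit itself):

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import (
    RegionalPromptAttnProcessor2_0,
    apply_regional_prompt_attn,
)

# A very small UNet so the example is cheap to run; the config values are arbitrary.
unet = UNet2DConditionModel(
    sample_size=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)

with apply_regional_prompt_attn(unet):
    # Inside the context, every attention layer uses RegionalPromptAttnProcessor2_0.
    assert all(
        type(p) is RegionalPromptAttnProcessor2_0 for p in unet.attn_processors.values()
    )
    noise_pred = unet(
        torch.randn(1, 4, 8, 8),                        # latents
        timestep=0,
        encoder_hidden_states=torch.randn(1, 154, 32),  # concatenated text conditioning
    ).sample

# On exit the original processors are restored, even if an exception was raised.
assert not any(
    type(p) is RegionalPromptAttnProcessor2_0 for p in unet.attn_processors.values()
)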

View File

@@ -10,6 +10,7 @@ from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    ConditioningData,
    ExtraConditioningInfo,
    PostprocessingSettings,
@@ -309,6 +310,55 @@
    # methods below are called from do_diffusion_step and should be considered private to this class.

    def _prepare_text_embeddings(
        self, text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]]
    ) -> Union[BasicConditioningInfo, SDXLConditioningInfo]:
        if len(text_embeddings) == 1:
            # If there is only one text embedding, we can just return it.
            # We short-circuit here, because there are some features that are only supported when a single
            # text_embedding is provided.
            return text_embeddings[0]

        is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo

        text_embedding = []
        pooled_embedding = None
        add_time_ids = None
        for text_embedding_info in text_embeddings:
            # TODO(ryand): Having to check this feels super hacky.
            # Extra conditioning is not supported when there are multiple text embeddings.
            assert (
                text_embedding_info.extra_conditioning is None
                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
            )

            if is_sdxl:
                # We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
                # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
                # the conditioning info, then we shouldn't allow it to be passed in.
                if pooled_embedding is None:
                    pooled_embedding = text_embedding_info.pooled_embeds
                if add_time_ids is None:
                    add_time_ids = text_embedding_info.add_time_ids

            text_embedding.append(text_embedding_info.embeds)

        text_embedding = torch.cat(text_embedding, dim=1)
        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len

        if is_sdxl:
            return SDXLConditioningInfo(
                embeds=text_embedding,
                extra_conditioning=None,
                pooled_embeds=pooled_embedding,
                add_time_ids=add_time_ids,
            )
        return BasicConditioningInfo(
            embeds=text_embedding,
            extra_conditioning=None,
        )
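A hedged sketch of how the new helper might be exercised with two basic (non-SDXL) conditioning objects. `component` stands in for an already-constructed InvokeAIDiffuserComponent, and the embedding shapes are illustrative; only the BasicConditioningInfo keyword arguments used in the diff above are assumed:

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo

cond_a = BasicConditioningInfo(embeds=torch.randn(1, 77, 768), extra_conditioning=None)
cond_b = BasicConditioningInfo(embeds=torch.randn(1, 77, 768), extra_conditioning=None)

# `component` is a placeholder for an InvokeAIDiffuserComponent instance.
merged = component._prepare_text_embeddings([cond_a, cond_b])
assert merged.embeds.shape == (1, 154, 768)

For SDXL conditioning, only the first entry's pooled_embeds and add_time_ids are kept, as the TODO in the method above notes.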
    def _apply_standard_conditioning(
        self,
        x,
@@ -324,8 +374,9 @@
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)

        assert len(conditioning_data.text_embeddings) == 1
        text_embeddings = conditioning_data.text_embeddings[0]
        text_embeddings = self._prepare_text_embeddings(conditioning_data.text_embeddings)
        if len(conditioning_data.text_embeddings) > 1:
            cross_attention_kwargs = {"regional_prompt_data": None}

        cross_attention_kwargs = None
        if conditioning_data.ip_adapter_conditioning is not None: