mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Add concatenation of multiple text conditioning tensors, and patching of RegionalPromptAttnProcessor2_0 into the UNet.
This commit is contained in:
parent 38248b988f
commit caa690e24d
@@ -25,6 +25,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import apply_regional_prompt_attn
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

from ..util import auto_detect_slice_size, normalize_device
@@ -415,20 +416,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        )

        ip_adapter_unet_patcher = None
        self.use_ip_adapter = use_ip_adapter
        if use_cross_attention_control:
            attn_ctx = self.invokeai_diffuser.custom_attention_context(
                self.invokeai_diffuser.model,
                extra_conditioning_info=extra_conditioning_info,
            )
            self.use_ip_adapter = False
        elif use_ip_adapter:
            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
            # As it is now, the IP-Adapter will silently be skipped.
            ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
            self.use_ip_adapter = True
        elif use_regional_prompting:
            raise NotImplementedError("Regional prompting is not yet supported.")
            attn_ctx = apply_regional_prompt_attn(self.invokeai_diffuser.model)
        else:
            attn_ctx = nullcontext()
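Note (added for context, not part of the commit): each branch above produces a context manager, so the caller can enter `attn_ctx` uniformly regardless of which attention patch, if any, is active. A minimal sketch of that pattern, with illustrative names that are not InvokeAI's actual call site:

```python
# Illustrative sketch only; the function and variable names are assumptions.
from contextlib import nullcontext


def run_denoising(attn_ctx, steps):
    # The UNet's attention processors are only patched while the `with` block is
    # active; on exit (even after an exception) the original processors are restored.
    with attn_ctx:
        for _ in steps:
            ...  # one UNet forward / scheduler step per iteration


run_denoising(nullcontext(), steps=range(3))
```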
@@ -1,7 +1,9 @@
from contextlib import contextmanager
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import USE_PEFT_BACKEND

@@ -17,7 +19,10 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        regional_prompt_data=None,
    ) -> torch.FloatTensor:
        assert regional_prompt_data is None

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
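The new `regional_prompt_data` keyword above would arrive through diffusers' usual forwarding of extra attention keyword arguments (the `cross_attention_kwargs` mechanism). A rough, hedged sketch of that wiring; the toy `Attention` dimensions are made up, and the processor currently asserts the argument is None:

```python
# Sketch under the assumption that extra attention kwargs are forwarded to the
# installed processor's __call__, as diffusers does for cross_attention_kwargs.
import torch
from diffusers.models.attention_processor import Attention

from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import (
    RegionalPromptAttnProcessor2_0,
)

# A toy self-attention block with made-up dimensions.
attn = Attention(query_dim=64, heads=4, dim_head=16)
attn.set_processor(RegionalPromptAttnProcessor2_0())

hidden_states = torch.randn(1, 16, 64)
# Extra keyword arguments are handed through to the processor, which is how
# regional_prompt_data would reach it once it is actually populated.
out = attn(hidden_states, regional_prompt_data=None)
assert out.shape == hidden_states.shape
```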
@@ -83,3 +88,16 @@
        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


@contextmanager
def apply_regional_prompt_attn(unet: UNet2DConditionModel):
    """A context manager that patches `unet` with RegionalPromptAttnProcessor2_0 attention processors."""

    orig_attn_processors = unet.attn_processors

    try:
        unet.set_attn_processor(RegionalPromptAttnProcessor2_0())
        yield None
    finally:
        unet.set_attn_processor(orig_attn_processors)
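A hedged usage sketch of the context manager above. The tiny randomly initialized UNet configuration is only for illustration; a real pipeline already holds a loaded model:

```python
# Sketch only: a minimal UNet2DConditionModel stands in for the pipeline's UNet.
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import (
    RegionalPromptAttnProcessor2_0,
    apply_regional_prompt_attn,
)

unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    cross_attention_dim=32,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
)

with apply_regional_prompt_attn(unet):
    # Inside the block, every attention module uses the regional prompt processor.
    assert all(isinstance(p, RegionalPromptAttnProcessor2_0) for p in unet.attn_processors.values())

# On exit the original processors are restored, even if an exception was raised.
assert not any(isinstance(p, RegionalPromptAttnProcessor2_0) for p in unet.attn_processors.values())
```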
@@ -10,6 +10,7 @@ from typing_extensions import TypeAlias

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    ConditioningData,
    ExtraConditioningInfo,
    PostprocessingSettings,
@@ -309,6 +310,55 @@ class InvokeAIDiffuserComponent:

    # methods below are called from do_diffusion_step and should be considered private to this class.

    def _prepare_text_embeddings(
        self, text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]]
    ) -> Union[BasicConditioningInfo, SDXLConditioningInfo]:
        if len(text_embeddings) == 1:
            # If there is only one text embedding, we can just return it.
            # We short-circuit here, because there are some features that are only supported when there is a single
            # text_embedding provided.
            return text_embeddings[0]

        is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo

        text_embedding = []
        pooled_embedding = None
        add_time_ids = None

        for text_embedding_info in text_embeddings:
            # TODO(ryand): Having to check this feels super hacky.
            # Extra conditioning is not supported when there are multiple text embeddings.
            assert (
                text_embedding_info.extra_conditioning is None
                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
            )

            if is_sdxl:
                # We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
                # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
                # of the conditioning infos, then we shouldn't allow them to be passed in.
                if pooled_embedding is None:
                    pooled_embedding = text_embedding_info.pooled_embeds
                if add_time_ids is None:
                    add_time_ids = text_embedding_info.add_time_ids

            text_embedding.append(text_embedding_info.embeds)

        text_embedding = torch.cat(text_embedding, dim=1)
        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len

        if is_sdxl:
            return SDXLConditioningInfo(
                embeds=text_embedding,
                extra_conditioning=None,
                pooled_embeds=pooled_embedding,
                add_time_ids=add_time_ids,
            )
        return BasicConditioningInfo(
            embeds=text_embedding,
            extra_conditioning=None,
        )

    def _apply_standard_conditioning(
        self,
        x,
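For intuition (not part of the diff): the concatenation in `_prepare_text_embeddings` above joins the per-prompt embeddings along the token/sequence dimension, so cross-attention sees one longer conditioning sequence. A standalone sketch with made-up shapes:

```python
import torch

# Two prompts encoded separately: (batch_size, seq_len, embed_dim).
prompt_a = torch.randn(1, 77, 768)
prompt_b = torch.randn(1, 77, 768)

# Concatenating along dim=1 (the sequence dimension) yields a single conditioning
# tensor of shape (1, 154, 768) that the UNet's cross-attention attends over.
combined = torch.cat([prompt_a, prompt_b], dim=1)
assert combined.shape == (1, 154, 768)
```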
@@ -324,8 +374,9 @@ class InvokeAIDiffuserComponent:
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)

        assert len(conditioning_data.text_embeddings) == 1
        text_embeddings = conditioning_data.text_embeddings[0]
        text_embeddings = self._prepare_text_embeddings(conditioning_data.text_embeddings)
        if len(conditioning_data.text_embeddings) > 1:
            cross_attention_kwargs = {"regional_prompt_data": None}

        cross_attention_kwargs = None
        if conditioning_data.ip_adapter_conditioning is not None:
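Background note (not part of the commit): the `x_twice` / `sigma_twice` lines above are the usual classifier-free guidance batching, where latents are duplicated so the unconditioned and conditioned predictions come out of a single batched UNet call. A generic sketch of that combination step, with a random tensor standing in for the UNet output:

```python
import torch

latents = torch.randn(1, 4, 64, 64)
x_twice = torch.cat([latents] * 2)         # batch of 2: [unconditioned, conditioned]

unet_output = torch.randn_like(x_twice)    # stand-in for one batched UNet forward pass
noise_pred_uncond, noise_pred_cond = unet_output.chunk(2)

guidance_scale = 7.5
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
assert noise_pred.shape == latents.shape
```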