mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add concatenation of multiple text conditioning tensors, and patching of RegionalPromptAttnProcessor2_0 into the UNet.
This commit is contained in:
parent
38248b988f
commit
caa690e24d
@ -25,6 +25,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
|
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import apply_regional_prompt_attn
|
||||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
|
|
||||||
from ..util import auto_detect_slice_size, normalize_device
|
from ..util import auto_detect_slice_size, normalize_device
|
||||||
@ -415,20 +416,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
ip_adapter_unet_patcher = None
|
ip_adapter_unet_patcher = None
|
||||||
|
self.use_ip_adapter = use_ip_adapter
|
||||||
if use_cross_attention_control:
|
if use_cross_attention_control:
|
||||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||||
self.invokeai_diffuser.model,
|
self.invokeai_diffuser.model,
|
||||||
extra_conditioning_info=extra_conditioning_info,
|
extra_conditioning_info=extra_conditioning_info,
|
||||||
)
|
)
|
||||||
self.use_ip_adapter = False
|
|
||||||
elif use_ip_adapter:
|
elif use_ip_adapter:
|
||||||
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
||||||
# As it is now, the IP-Adapter will silently be skipped.
|
# As it is now, the IP-Adapter will silently be skipped.
|
||||||
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
||||||
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||||
self.use_ip_adapter = True
|
|
||||||
elif use_regional_prompting:
|
elif use_regional_prompting:
|
||||||
raise NotImplementedError("Regional prompting is not yet supported.")
|
attn_ctx = apply_regional_prompt_attn(self.invokeai_diffuser.model)
|
||||||
else:
|
else:
|
||||||
attn_ctx = nullcontext()
|
attn_ctx = nullcontext()
|
||||||
|
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
|
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
|
||||||
from diffusers.utils import USE_PEFT_BACKEND
|
from diffusers.utils import USE_PEFT_BACKEND
|
||||||
|
|
||||||
@ -17,7 +19,10 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
temb: Optional[torch.FloatTensor] = None,
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
scale: float = 1.0,
|
scale: float = 1.0,
|
||||||
|
regional_prompt_data=None,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
|
assert regional_prompt_data is None
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
if attn.spatial_norm is not None:
|
if attn.spatial_norm is not None:
|
||||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
@ -83,3 +88,16 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def apply_regional_prompt_attn(unet: UNet2DConditionModel):
    """Context manager that installs a RegionalPromptAttnProcessor2_0 on `unet` for the duration of the block.

    The value of `unet.attn_processors` is captured on entry and handed back to `unet.set_attn_processor(...)` on
    exit, whether the managed block completes normally or raises. Nothing useful is yielded (the `with` target is
    `None`).
    """
    # Capture the currently-installed processors before touching the UNet, so that they can be restored later.
    saved_attn_processors = unet.attn_processors

    try:
        regional_processor = RegionalPromptAttnProcessor2_0()
        unet.set_attn_processor(regional_processor)
        yield None
    finally:
        # Runs on both normal exit and on error, so the UNet is never left patched.
        unet.set_attn_processor(saved_attn_processors)
|
||||||
|
@ -10,6 +10,7 @@ from typing_extensions import TypeAlias
|
|||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
|
BasicConditioningInfo,
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
ExtraConditioningInfo,
|
ExtraConditioningInfo,
|
||||||
PostprocessingSettings,
|
PostprocessingSettings,
|
||||||
@ -309,6 +310,55 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
|
def _prepare_text_embeddings(
|
||||||
|
self, text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]]
|
||||||
|
) -> Union[BasicConditioningInfo, SDXLConditioningInfo]:
|
||||||
|
if len(text_embeddings) == 1:
|
||||||
|
# If there is only one text embedding, we can just return it.
|
||||||
|
# We short-circuit here, because there are some features that are only supported when there is a single
|
||||||
|
# text_embedding provided.
|
||||||
|
return text_embeddings[0]
|
||||||
|
|
||||||
|
is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
|
||||||
|
|
||||||
|
text_embedding = []
|
||||||
|
pooled_embedding = None
|
||||||
|
add_time_ids = None
|
||||||
|
|
||||||
|
for text_embedding_info in text_embeddings:
|
||||||
|
# TODO(ryand): Having to check this feels super hacky.
|
||||||
|
# Extra conditioning is not supported when there are multiple text embeddings.
|
||||||
|
assert (
|
||||||
|
text_embedding_info.extra_conditioning is None
|
||||||
|
or not text_embedding_info.extra_conditioning.wants_cross_attention_control
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sdxl:
|
||||||
|
# We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
|
||||||
|
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
|
||||||
|
# the conditioning info, then we shouldn't allow it to be passed in.
|
||||||
|
if pooled_embedding is None:
|
||||||
|
pooled_embedding = text_embedding_info.pooled_embeds
|
||||||
|
if add_time_ids is None:
|
||||||
|
add_time_ids = text_embedding_info.add_time_ids
|
||||||
|
|
||||||
|
text_embedding.append(text_embedding_info.embeds)
|
||||||
|
|
||||||
|
text_embedding = torch.cat(text_embedding, dim=1)
|
||||||
|
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
|
||||||
|
|
||||||
|
if is_sdxl:
|
||||||
|
return SDXLConditioningInfo(
|
||||||
|
embeds=text_embedding,
|
||||||
|
extra_conditioning=None,
|
||||||
|
pooled_embeds=pooled_embedding,
|
||||||
|
add_time_ids=add_time_ids,
|
||||||
|
)
|
||||||
|
return BasicConditioningInfo(
|
||||||
|
embeds=text_embedding,
|
||||||
|
extra_conditioning=None,
|
||||||
|
)
|
||||||
|
|
||||||
def _apply_standard_conditioning(
|
def _apply_standard_conditioning(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
@ -324,8 +374,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
assert len(conditioning_data.text_embeddings) == 1
|
text_embeddings = self._prepare_text_embeddings(conditioning_data.text_embeddings)
|
||||||
text_embeddings = conditioning_data.text_embeddings[0]
|
if len(conditioning_data.text_embeddings) > 1:
|
||||||
|
cross_attention_kwargs = {"regional_prompt_data": None}
|
||||||
|
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
if conditioning_data.ip_adapter_conditioning is not None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user