Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Pass IP-Adapter conditioning via cross_attention_kwargs instead of concatenating it to the text embedding. This avoids interference with other features that manipulate the text embedding (e.g. long prompts).

Commit b2d5b53b5f (parent ddc148b70b)
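
Note: a minimal sketch (illustration only, not code from this commit; the shapes and the 768-dim embedding size are assumptions) contrasting the old and new ways of delivering the IP-Adapter image tokens:

    import torch

    # Hypothetical shapes: a 77-token prompt embedding and a 4-token IP-Adapter image embedding.
    text_embeds = torch.randn(1, 77, 768)
    image_prompt_embeds = torch.randn(1, 4, 768)

    # Old approach: concatenate the image tokens onto the text embedding. Any feature that later
    # reshapes or re-weights the text embedding (e.g. long-prompt handling) also touches the image tokens.
    combined = torch.cat([text_embeds, image_prompt_embeds], dim=1)  # (1, 81, 768)

    # New approach: leave the text embedding untouched and hand the image tokens to the attention
    # processors through cross_attention_kwargs, e.g.
    # unet(latents, t, encoder_hidden_states=text_embeds,
    #      cross_attention_kwargs={"ip_adapter_image_prompt_embeds": image_prompt_embeds})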
@@ -19,12 +19,42 @@ class AttnProcessor(DiffusersAttnProcessor, nn.Module):
         DiffusersAttnProcessor.__init__(self)
         nn.Module.__init__(self)
 
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+        ip_adapter_image_prompt_embeds=None,
+    ):
+        """Re-definition of DiffusersAttnProcessor.__call__(...) that accepts and ignores the
+        ip_adapter_image_prompt_embeds parameter.
+        """
+        return DiffusersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
+
+
 class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
     def __init__(self):
         DiffusersAttnProcessor2_0.__init__(self)
         nn.Module.__init__(self)
 
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+        ip_adapter_image_prompt_embeds=None,
+    ):
+        """Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
+        ip_adapter_image_prompt_embeds parameter.
+        """
+        return DiffusersAttnProcessor2_0.__call__(
+            self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
+        )
+
+
 class IPAttnProcessor(nn.Module):
     r"""
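
Note: the two __call__ overrides above exist because everything placed in cross_attention_kwargs is forwarded to every attention processor in the UNet, so even the plain (non-IP-Adapter) processors must accept the new keyword without raising a TypeError. A minimal sketch, not InvokeAI code:

    import torch

    class PassThroughProcessor:
        # Stands in for the overrides above: accepts and ignores the IP-Adapter keyword.
        def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
                     temb=None, ip_adapter_image_prompt_embeds=None):
            return hidden_states  # the real processors compute attention here

    # The caller forwards cross_attention_kwargs to every processor, including layers that never use it.
    cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": torch.randn(1, 4, 768)}
    out = PassThroughProcessor()(None, torch.randn(1, 16, 320), **cross_attention_kwargs)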
@@ -32,21 +62,17 @@ class IPAttnProcessor(nn.Module):
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
-        image_embedding_len (`int`):
-            The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
-            the `encoder_hidden_states` are the IP-Adapter image embeddings.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
         super().__init__()
 
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
-        self.image_embedding_len = image_embedding_len
         self.scale = scale
 
         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@@ -59,7 +85,18 @@ class IPAttnProcessor(nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        ip_adapter_image_prompt_embeds=None,
     ):
+        if encoder_hidden_states is not None:
+            # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
+            # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
+            assert ip_adapter_image_prompt_embeds is not None
+            # The batch dimensions should match.
+            assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
+            # The channel dimensions should match.
+            assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
+            ip_hidden_states = ip_adapter_image_prompt_embeds
+
         residual = hidden_states
 
         if attn.spatial_norm is not None:
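
Note: the asserts above only pin down the batch and channel dimensions; the token dimension of the image embedding is free to differ from the prompt length, because the two sequences are attended to separately rather than concatenated. A small worked example with hypothetical (SD1.5-style) shapes:

    import torch

    encoder_hidden_states = torch.randn(2, 77, 768)           # (batch, text_tokens, channels)
    ip_adapter_image_prompt_embeds = torch.randn(2, 4, 768)   # (batch, image_tokens, channels)

    assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]  # batch matches
    assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]  # channels match
    # 77 vs. 4 tokens is fine: the two sequences are never concatenated.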
@@ -86,12 +123,6 @@ class IPAttnProcessor(nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        # Split text encoder hidden states and image encoder hidden state.
-        encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : -self.image_embedding_len, :],
-            encoder_hidden_states[:, -self.image_embedding_len :, :],
-        )
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
@@ -103,7 +134,7 @@ class IPAttnProcessor(nn.Module):
         hidden_states = torch.bmm(attention_probs, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        # for ip-adapter
+        if ip_hidden_states is not None:
             ip_key = self.to_k_ip(ip_hidden_states)
             ip_value = self.to_v_ip(ip_hidden_states)
 
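
Note: with the split removed, the image tokens arrive pre-separated, and the processor runs a second, decoupled cross-attention over them through to_k_ip / to_v_ip, weighted by `scale` and added to the text cross-attention output. A single-head, simplified sketch (assumed shapes; not the real attention plumbing):

    import torch
    import torch.nn.functional as F

    dim = 768
    hidden_states = torch.randn(1, 4096, dim)         # latent "query" tokens
    encoder_hidden_states = torch.randn(1, 77, dim)   # text tokens
    ip_hidden_states = torch.randn(1, 4, dim)         # IP-Adapter image tokens

    to_q = torch.nn.Linear(dim, dim, bias=False)
    to_k = torch.nn.Linear(dim, dim, bias=False)
    to_v = torch.nn.Linear(dim, dim, bias=False)
    to_k_ip = torch.nn.Linear(dim, dim, bias=False)
    to_v_ip = torch.nn.Linear(dim, dim, bias=False)
    scale = 1.0

    query = to_q(hidden_states)
    text_out = F.scaled_dot_product_attention(query, to_k(encoder_hidden_states), to_v(encoder_hidden_states))
    ip_out = F.scaled_dot_product_attention(query, to_k_ip(ip_hidden_states), to_v_ip(ip_hidden_states))
    out = text_out + scale * ip_out                    # decoupled image attention added on top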
@@ -138,16 +169,13 @@ class IPAttnProcessor2_0(torch.nn.Module):
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
-        image_embedding_len (`int`):
-            The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
-            the `encoder_hidden_states` are the IP-Adapter image embeddings.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -155,7 +183,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
 
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
-        self.text_context_len = text_context_len
         self.scale = scale
 
         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@@ -168,7 +195,18 @@ class IPAttnProcessor2_0(torch.nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        ip_adapter_image_prompt_embeds=None,
     ):
+        if encoder_hidden_states is not None:
+            # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
+            # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
+            assert ip_adapter_image_prompt_embeds is not None
+            # The batch dimensions should match.
+            assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
+            # The channel dimensions should match.
+            assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
+            ip_hidden_states = ip_adapter_image_prompt_embeds
+
         residual = hidden_states
 
         if attn.spatial_norm is not None:
@@ -200,12 +238,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        # Split text encoder hidden states and image encoder hidden state.
-        encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : -self.image_embedding_len, :],
-            encoder_hidden_states[:, -self.image_embedding_len :, :],
-        )
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
@@ -226,7 +258,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        # for ip-adapter
+        if ip_hidden_states:
             ip_key = self.to_k_ip(ip_hidden_states)
             ip_value = self.to_v_ip(ip_hidden_states)
 
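
Note for readers comparing the two processors: the SDP variant gates the image branch with `if ip_hidden_states:` rather than `if ip_hidden_states is not None:`. Evaluating a multi-element tensor as a boolean raises a RuntimeError in PyTorch, so the `is not None` form used by IPAttnProcessor is the unambiguous check:

    import torch

    t = torch.randn(1, 4, 768)
    print(t is not None)   # True: presence check, always safe
    try:
        bool(t)            # what `if t:` evaluates
    except RuntimeError as err:
        print("ambiguous truth value:", err)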
@@ -92,7 +92,6 @@ class IPAdapter:
             print("swapping in IPAttnProcessor for", name)
             attn_procs[name] = IPAttnProcessor(
                 hidden_size=hidden_size,
-                image_embedding_len=self.num_tokens,
                 cross_attention_dim=cross_attention_dim,
                 scale=1.0,
             ).to(self.device, dtype=torch.float16)
@@ -30,6 +30,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ConditioningData,
+    IPAdapterConditioningInfo,
 )
 
 from ..util import auto_detect_slice_size, normalize_device
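
Note: a rough sketch of the per-layer processor assignment this constructor change feeds into (assumptions, not InvokeAI code: the layer names, sizes, and the attn1/attn2 heuristic are illustrative; with a real diffusers UNet the dict would come from unet.attn_processors and be applied with unet.set_attn_processor):

    class AttnProcessor:       # placeholder for the pass-through processor defined earlier
        pass

    class IPAttnProcessor:     # placeholder for the IP-Adapter processor defined earlier
        def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
            self.scale = scale

    layer_names = [
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor",  # self-attention
        "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor",  # cross-attention
    ]

    attn_procs = {}
    for name in layer_names:
        if name.endswith("attn1.processor"):
            attn_procs[name] = AttnProcessor()   # never sees image tokens; just ignores the keyword
        else:
            attn_procs[name] = IPAttnProcessor(hidden_size=320, cross_attention_dim=768, scale=1.0)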
@@ -449,27 +450,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
             # Get image embeddings from CLIP and ImageProjModel.
             image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
+            conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
+                image_prompt_embeds, uncond_image_prompt_embeds
+            )
 
-            # The following commented block is kept for reference on how to repeat/reshape the image embeddings to
-            # generate a batch of multiple images:
-            # bs_embed, seq_len, _ = image_prompt_embeds.shape
-            # image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-            # image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-            # uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-            # uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
-            text_prompt_embeds = conditioning_data.text_embeddings.embeds
-            uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
-            print("text embeds shape:", text_prompt_embeds.shape)
-            concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1)
-            concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-            print("concat embeds shape:", concat_prompt_embeds.shape)
-            conditioning_data.text_embeddings.embeds = concat_prompt_embeds
-            conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds
-        else:
-            image_prompt_embeds = None
-            uncond_image_prompt_embeds = None
+            # TODO(ryand): Apply IP-Adapter or custom attention control
 
         extra_conditioning_info = conditioning_data.extra
         with self.invokeai_diffuser.custom_attention_context(
             self.invokeai_diffuser.model,
@@ -51,6 +51,18 @@ class PostprocessingSettings:
     v_symmetry_time_pct: Optional[float]
 
 
+@dataclass
+class IPAdapterConditioningInfo:
+    cond_image_prompt_embeds: torch.Tensor
+    """IP-Adapter image encoder conditioning embeddings.
+    Shape: (batch_size, num_tokens, encoding_dim). Typically: (1, 4, 1024) TODO(ryand): confirm
+    """
+    uncond_image_prompt_embeds: torch.Tensor
+    """IP-Adapter image encoding embeddings to use for unconditional generation.
+    Shape: (batch_size, num_tokens, encoding_dim). Typically: (1, 4, 1024) TODO(ryand): confirm
+    """
+
+
 @dataclass
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
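
Note: a minimal, self-contained sketch of how the new dataclass is used (simplified stand-ins and hypothetical shapes; the real classes carry more fields):

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class IPAdapterConditioningInfo:
        cond_image_prompt_embeds: torch.Tensor
        uncond_image_prompt_embeds: torch.Tensor

    @dataclass
    class ConditioningData:
        ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None

    # The pipeline stores the CLIP/ImageProjModel outputs here instead of concatenating them
    # onto the text embeddings; the diffuser component later turns this into cross_attention_kwargs.
    conditioning_data = ConditioningData()
    conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
        cond_image_prompt_embeds=torch.randn(1, 4, 1024),
        uncond_image_prompt_embeds=torch.randn(1, 4, 1024),
    )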
@@ -69,6 +81,8 @@ class ConditioningData:
     """
     postprocessing_settings: Optional[PostprocessingSettings] = None
 
+    ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
+
     @property
     def dtype(self):
         return self.text_embeddings.dtype
@@ -10,6 +10,7 @@ from typing_extensions import TypeAlias
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    ConditioningData,
     ExtraConditioningInfo,
     PostprocessingSettings,
     SDXLConditioningInfo,
@@ -232,6 +233,8 @@ class InvokeAIDiffuserComponent:
         total_step_count: int,
         **kwargs,
     ):
+        # TODO(ryand): Raise here if both cross attention control and ip-adapter are enabled?
+
         cross_attention_control_types_to_do = []
         context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
@@ -339,11 +342,24 @@ class InvokeAIDiffuserComponent:
 
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
-    def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
-        # fast batched path
+    def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
+        """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
+        the cost of higher memory usage.
+        """
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
+        cross_attention_kwargs = None
+        if conditioning_data.ip_adapter_conditioning is not None:
+            cross_attention_kwargs = {
+                "ip_adapter_image_prompt_embeds": torch.cat(
+                    [
+                        conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
+                        conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
+                    ]
+                )
+            }
+
         added_cond_kwargs = None
         if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
             added_cond_kwargs = {
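
Note: in this batched classifier-free-guidance path the latents are duplicated and the unconditioned and conditioned prompts run as one batch, so the image embeddings are concatenated along the batch dimension in the same (unconditioned-first) order to stay aligned with the duplicated halves. A tiny illustration with hypothetical shapes:

    import torch

    uncond_image_prompt_embeds = torch.zeros(1, 4, 768)
    cond_image_prompt_embeds = torch.ones(1, 4, 768)

    # Stack [unconditioned, conditioned] along the batch dim, matching the order of x_twice.
    ip_batch = torch.cat([uncond_image_prompt_embeds, cond_image_prompt_embeds])
    assert ip_batch.shape == (2, 4, 768)
    # Row 0 pairs with the unconditioned half of the batch; row 1 with the conditioned half.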
@@ -371,6 +387,7 @@ class InvokeAIDiffuserComponent:
             x_twice,
             sigma_twice,
             both_conditionings,
+            cross_attention_kwargs=cross_attention_kwargs,
             encoder_attention_mask=encoder_attention_mask,
             added_cond_kwargs=added_cond_kwargs,
             **kwargs,
@@ -382,9 +399,12 @@ class InvokeAIDiffuserComponent:
         self,
         x: torch.Tensor,
         sigma,
-        conditioning_data,
+        conditioning_data: ConditioningData,
         **kwargs,
     ):
+        """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
+        slower execution speed.
+        """
         # low-memory sequential path
         uncond_down_block, cond_down_block = None, None
         down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
@@ -400,6 +420,13 @@ class InvokeAIDiffuserComponent:
         if mid_block_additional_residual is not None:
             uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
 
+        # Run unconditional UNet denoising.
+        cross_attention_kwargs = None
+        if conditioning_data.ip_adapter_conditioning is not None:
+            cross_attention_kwargs = {
+                "ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
+            }
+
         added_cond_kwargs = None
         is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
@@ -412,12 +439,21 @@ class InvokeAIDiffuserComponent:
             x,
             sigma,
             conditioning_data.unconditioned_embeddings.embeds,
+            cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
             added_cond_kwargs=added_cond_kwargs,
             **kwargs,
         )
 
+        # Run conditional UNet denoising.
+        cross_attention_kwargs = None
+        if conditioning_data.ip_adapter_conditioning is not None:
+            cross_attention_kwargs = {
+                "ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
+            }
+
+        added_cond_kwargs = None
         if is_sdxl:
             added_cond_kwargs = {
                 "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
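
Note: in the sequential (low-memory) path each UNet pass receives only its own half of the image conditioning: the unconditional pass gets uncond_image_prompt_embeds and the conditional pass gets cond_image_prompt_embeds. A minimal sketch with a hypothetical stand-in for the UNet call:

    import torch

    def unet_forward(latents, sigma, encoder_hidden_states, cross_attention_kwargs=None):
        # Hypothetical stand-in; in the real pipeline the contents of cross_attention_kwargs
        # are forwarded to the attention processors.
        ip = None if cross_attention_kwargs is None else cross_attention_kwargs["ip_adapter_image_prompt_embeds"]
        print("ip embeds shape:", None if ip is None else tuple(ip.shape))
        return latents

    latents = torch.randn(1, 4, 64, 64)
    cond_ip, uncond_ip = torch.randn(1, 4, 768), torch.randn(1, 4, 768)

    # Unconditional pass gets only the unconditional image embedding ...
    unet_forward(latents, 1.0, torch.randn(1, 77, 768),
                 cross_attention_kwargs={"ip_adapter_image_prompt_embeds": uncond_ip})
    # ... and the conditional pass gets only the conditional one.
    unet_forward(latents, 1.0, torch.randn(1, 77, 768),
                 cross_attention_kwargs={"ip_adapter_image_prompt_embeds": cond_ip})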
@@ -428,6 +464,7 @@ class InvokeAIDiffuserComponent:
             x,
             sigma,
             conditioning_data.text_embeddings.embeds,
+            cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,
             added_cond_kwargs=added_cond_kwargs,