Pass IP-Adapter conditioning via cross_attention_kwargs instead of concatenating it to the text embedding. This avoids interference with other features that manipulate the text embedding (e.g. long prompts).

Ryan Dick
2023-09-08 11:47:36 -04:00
parent ddc148b70b
commit b2d5b53b5f
5 changed files with 135 additions and 68 deletions
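For context, a minimal sketch of the two approaches. The tensor shapes and the `cross_attention_kwargs` key below are illustrative only, not the exact InvokeAI call sites:

```python
import torch

# Illustrative shapes: a CLIP text embedding and an IP-Adapter image-prompt embedding.
text_embeds = torch.randn(1, 77, 768)
image_prompt_embeds = torch.randn(1, 4, 768)

# Old approach: concatenate the image-prompt tokens onto the text embedding.
# Any other feature that manipulates the text embedding (e.g. long-prompt
# handling) then has to account for the extra tokens.
concat_embeds = torch.cat([text_embeds, image_prompt_embeds], dim=1)  # shape (1, 81, 768)

# New approach: leave the text embedding untouched and hand the image-prompt
# embeddings to the attention processors separately, e.g. (hypothetical key):
#   unet(latents, t,
#        encoder_hidden_states=text_embeds,
#        cross_attention_kwargs={"ip_adapter_image_prompt_embeds": image_prompt_embeds})
```

In the diff below, the image-prompt embeddings are carried on `conditioning_data.ip_adapter_conditioning` (an `IPAdapterConditioningInfo`) rather than being concatenated into `conditioning_data.text_embeddings.embeds`.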


@@ -30,6 +30,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningData,
    IPAdapterConditioningInfo,
)
from ..util import auto_detect_slice_size, normalize_device
@@ -449,27 +450,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
    image_prompt_embeds, uncond_image_prompt_embeds
)
# The following commented block is kept for reference on how to repeat/reshape the image embeddings to
# generate a batch of multiple images:
# bs_embed, seq_len, _ = image_prompt_embeds.shape
# image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
# image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
text_prompt_embeds = conditioning_data.text_embeddings.embeds
uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
print("text embeds shape:", text_prompt_embeds.shape)
concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1)
concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1)
print("concat embeds shape:", concat_prompt_embeds.shape)
conditioning_data.text_embeddings.embeds = concat_prompt_embeds
conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds
else:
    image_prompt_embeds = None
    uncond_image_prompt_embeds = None
# TODO(ryand): Apply IP-Adapter or custom attention control
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,