mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Pass IP-Adapter conditioning via cross_attention_kwargs instead of concatenating to the text embedding. This avoids interference with other features that manipulate the text embedding (e.g. long prompts).
This commit is contained in:
@ -30,6 +30,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningData,
|
||||
IPAdapterConditioningInfo,
|
||||
)
|
||||
|
||||
from ..util import auto_detect_slice_size, normalize_device
|
||||
@ -449,27 +450,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
|
||||
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
|
||||
image_prompt_embeds, uncond_image_prompt_embeds
|
||||
)
|
||||
|
||||
# The following commented block is kept for reference on how to repeat/reshape the image embeddings to
|
||||
# generate a batch of multiple images:
|
||||
# bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||
# image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
# image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
|
||||
text_prompt_embeds = conditioning_data.text_embeddings.embeds
|
||||
uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
|
||||
print("text embeds shape:", text_prompt_embeds.shape)
|
||||
concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1)
|
||||
concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
print("concat embeds shape:", concat_prompt_embeds.shape)
|
||||
conditioning_data.text_embeddings.embeds = concat_prompt_embeds
|
||||
conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds
|
||||
else:
|
||||
image_prompt_embeds = None
|
||||
uncond_image_prompt_embeds = None
|
||||
|
||||
# TODO(ryand): Apply IP-Adapter or custom attention control
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
with self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
|
Reference in New Issue
Block a user